Mxnet图片分类(4)利用训练好的模型进行测试

xiaoxiao2021-02-28  8

利用训练好的模型测试只需要把模型和数据准备好。

系统: ubuntu14.04 Mxnet: 0.904

1.模型和数据准备

2.模型加载测试

import mxnet as mx sym,arg_params,aux_params = mx.model.load_checkpoint('vggnew',40) mod = mx.mod.Module(symbol=sym,context=mx.gpu(),data_names=['data'],label_names=['softmax_label']) mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) mod.set_params(arg_params,aux_params)

需要输出标签还要准备一个synset.txt文件,格式如图:

with open('synset.txt','r') as f: labels = [l.rstrip() for l in f]

对图片进行处理

%matplotlib inline import matplotlib.pyplot as plt import cv2 import numpy as np # define a simple data batch from collections import namedtuple Batch = namedtuple('Batch', ['data']) def get_image(url, show=False): #url:图片路径 #show:是否显示图片 img = cv2.cvtColor(cv2.imread(url), cv2.COLOR_BGR2RGB) if img is None: return None if show: plt.imshow(img) plt.axis('off') # convert into format (batch, RGB, width, height) img = cv2.resize(img, (224, 224)) img = np.swapaxes(img, 0, 2) img = np.swapaxes(img, 1, 2) img = img[np.newaxis, :] return img def predict(url): img = get_image(url, show=True) # compute the predict probabilities mod.forward(Batch([mx.nd.array(img)])) prob = mod.get_outputs()[0].asnumpy() # print the top-5 prob = np.squeeze(prob) prob = np.argsort(prob)[::-1] top1=prob[0]#取概率最高的一类 print top1 #输入类别 print labels[top1] #输出标签 #批量测试 path = '/mxnet/tools/train-cat/2' import os for lists in os.listdir(path): image = os.path.join(path,lists) predict(image)

参考文献:

[1]http://mxnet.io/api/python/model.html?highlight=predict#mxnet.model.FeedForward.predict [2]http://mxnet.io/tutorials/python/predict_image.html?highlight=predict

模型的训练可以参考

Mxnet图片分类(2)训练模型 Mxnet图片分类(3)fine-tune
转载请注明原文地址: https://www.6miu.com/read-750231.html

最新回复(0)