用SVM处理MNIST手写数字数据集

xiaoxiao  2021-02-28  125

# -*- coding: utf-8 -*-
# Context from the article: in the 2017 RoboMaster competition, the "big rune"
# (大神符) stage used ... -- the sentence is cut off in the original post;
# judging by the title it presumably refers to MNIST-style handwritten digits.
# The script below trains an OpenCV SVM classifier and measures its accuracy
# on the MNIST test split. Written to run on both Python 2 and Python 3.
import codecs
import gzip
import os
import sys

import cv2
import numpy as np
from cv2.ml import VAR_ORDERED

from canny import *
from find_contours import *

if sys.version_info[0] >= 3:
    import pickle
else:
    import cPickle as pickle

# MNIST images are 28x28 pixels, flattened to 784 values per sample.
IMAGE_SIDE = 28
IMAGE_SIZE = IMAGE_SIDE * IMAGE_SIDE


def vectorized_result(j):
    """Return a (10, 1) one-hot column vector with 1.0 at index ``j``."""
    e = np.zeros((10, 1))
    e[j] = 1.0
    return e


def load_data():
    """Load the pickled MNIST dataset from ``data/mnist.pkl.gz``.

    Returns:
        ``(training_data, validation_data, test_data)`` exactly as stored in
        the pickle. ``wrap_data`` assumes each part is an ``(images, labels)``
        pair -- presumably the common ``mnist.pkl.gz`` layout; confirm.

    Raises:
        ValueError: if the pickle does not hold exactly three parts.
    """
    path = os.path.join('data', 'mnist.pkl.gz')
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # trusted local file here.
    with gzip.open(path, 'rb') as mnist:
        if sys.version_info[0] >= 3:
            # The file was pickled under Python 2; 'latin1' is needed so the
            # NumPy array buffers inside it decode correctly on Python 3.
            data = pickle.load(mnist, encoding='latin1')
        else:
            data = pickle.load(mnist)
    training_data, validation_data, test_data = data
    return training_data, validation_data, test_data


def _to_column_vectors(images):
    """Reshape each flat 784-value image into a (784, 1) column vector."""
    return [np.reshape(x, (IMAGE_SIZE, 1)) for x in images]


def wrap_data():
    """Return the MNIST splits as lists of ``(image, target)`` pairs.

    Images are (784, 1) column vectors. Training targets are one-hot (10, 1)
    vectors from ``vectorized_result``; validation and test targets are the
    raw labels. Lists (not lazy ``zip`` iterators) are returned so callers
    can take ``len()`` and iterate more than once on Python 3 as well.
    """
    tr_d, va_d, te_d = load_data()
    training_data = list(zip(_to_column_vectors(tr_d[0]),
                             [vectorized_result(y) for y in tr_d[1]]))
    validation_data = list(zip(_to_column_vectors(va_d[0]), va_d[1]))
    test_data = list(zip(_to_column_vectors(te_d[0]), te_d[1]))
    return training_data, validation_data, test_data


def train_svm(train_file='train_data.txt', test_file='train_result.txt',
              degree=3):
    """Train a polynomial-kernel C-SVC from two whitespace-separated text files.

    Args:
        train_file: one sample per row (``cv2.ml.ROW_SAMPLE``), read as float32.
        test_file: despite its name, the *training labels* for the rows of
            ``train_file``, read as int32.
        degree: polynomial kernel degree. OpenCV's default degree is 0, which
            the POLY kernel rejects at train time, so a positive value must
            be set.

    Returns:
        The trained ``cv2.ml.SVM``.
    """
    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    # Tune the SVM parameters (kernel, degree, C, gamma, ...) for your data.
    svm.setKernel(cv2.ml.SVM_POLY)
    svm.setDegree(degree)
    samples = np.loadtxt(train_file, np.float32)
    labels = np.loadtxt(test_file, np.int32)
    train_data = cv2.ml.TrainData_create(samples, cv2.ml.ROW_SAMPLE, labels)
    svm.train(train_data)
    return svm


def svm_test(svm, test_data, ignore_labels=(0,)):
    """Print and return the accuracy of ``svm`` on ``test_data``.

    Args:
        svm: a trained ``cv2.ml.SVM``.
        test_data: iterable of ``(image, label)`` pairs, e.g. the test split
            returned by ``wrap_data``.
        ignore_labels: labels left out of the evaluation entirely -- they are
            neither counted as correct nor included in the denominator. The
            default skips digit 0 (presumably absent from the training set;
            confirm). Pass ``()`` to score every sample.

    Returns:
        Accuracy in ``[0, 1]``; 0.0 when no sample was scored.
    """
    correct = 0
    scored = 0
    for image, label in test_data:
        # Skipping (rather than auto-crediting) ignored labels keeps the
        # reported accuracy honest.
        if label in ignore_labels:
            continue
        sample = np.asarray(image, dtype=np.float32).reshape(1, -1)
        _, result = svm.predict(sample)
        scored += 1
        if result[0][0] == label:
            correct += 1
    accuracy = correct / float(scored) if scored else 0.0
    print('正确率 ' + str(accuracy))
    return accuracy


def svm_predict(svm, sample):
    """Classify one 2-D (single-channel) image with ``svm``.

    Non-empty images that are not 28x28 are first resized with bicubic
    interpolation. Pixel values are passed through unscaled, so they
    presumably need to match the scale of the training data -- confirm.

    Returns:
        The ``(retval, results)`` tuple from ``cv2.ml.SVM.predict``.
    """
    rows, cols = sample.shape
    if (rows, cols) != (IMAGE_SIDE, IMAGE_SIDE) and rows * cols > 0:
        sample = cv2.resize(sample, (IMAGE_SIDE, IMAGE_SIDE),
                            interpolation=cv2.INTER_CUBIC)
    return svm.predict(np.asarray(sample, dtype=np.float32).reshape(1, -1))


if __name__ == '__main__':
    # Only the MNIST test split is used below: train_svm() reads
    # train_data.txt / train_result.txt rather than the MNIST training split.
    tr, val, test = wrap_data()
    # Placeholder file name (the literal means "pick your own file name").
    save_path = os.path.join('data', '自己想个文件名')
    if os.path.exists(save_path):
        print('find it')
        svm = cv2.ml.SVM_load(save_path)
    else:
        svm = train_svm()
        svm.save(save_path)
    svm_test(svm, test)
转载请注明原文地址: https://www.6miu.com/read-49525.html

最新回复(0)