用ANN处理MNIST手写数字库

xiaoxiao2021-02-28  48

作者很懒 还是先知贴个代码,open3 + python 自己体会 # decoding:utf-8 import os import cv2 from cv2.cv2 import * import codecs from cv2.ml import VAR_ORDERED from canny import * from find_contours import * import numpy as np import cPickle import gzip # decoding:utf-8 def revel(a):     list = []     for i in a:         list.append(i[0])     return list def load_data():     mnist = gzip.open(os.path.join('data', 'mnist.pkl.gz'), 'rb')     training_data, classification_data, test_data = cPickle.load(mnist)     mnist.close()     return training_data, classification_data, test_data def wrap_data():     tr_d, va_d, te_d = load_data()     # print type(tr_d), type(va_d), type(te_d)     training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]     training_results = [vectorized_result(y) for y in tr_d[1]]     training_data = zip(training_inputs, training_results)     validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]     validation_data = zip(validation_inputs, va_d[1])     test_input = [np.reshape(x, (784, 1)) for x in te_d[0]]     test_data = zip(test_input, te_d[1])     return training_data, validation_data, test_data def vectorized_result(j):     e = np.zeros((10, 1))     e[j] = 1.0     return e def create_ANN(hidden=20):     ann = cv2.ml.ANN_MLP_create()     ann.setLayerSizes(np.array([64, hidden, 10]))     ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)     ann.setActivationFunction(cv2.ml.ANN_MLP_IDENTITY)     ann.setTermCriteria((cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1))     return ann def train(ann, samples=10000, epochs=1):     tr, val, test = wrap_data()     t_d = np.loadtxt('train_data.txt', np.float32)     m_d = np.loadtxt('train_result.txt', np.int32)     for x in xrange(epochs):         ann.train(t_d, cv2.ml.ROW_SAMPLE, m_d)     # for x in xrange(epochs):     #     counter = 0     #     for img in tr:     #         if counter > samples:     #             break     #         if counter % 1000 == 0:     #             print "Epoch %d : Trained %d/%d" % (x, 
counter, samples)     #         counter += 1     #         data, digit = img     #         t_d = np.loadtxt(train_file, np.float32)     #         m_d = np.loadtxt(test_file, np.int32)     #         # print 'data', np.array([data], dtype=np.float32).reshape(28, 28),\     #         #                digit     #         ann.train(np.array([data.ravel()], dtype=np.float32),\     #         cv2.ml.ROW_SAMPLE, np.array([digit.ravel()], dtype=np.float32))     #         # print '看一下训练数据的young', np.array([data.ravel()], dtype=np.float32)     #         # cv2.imshow('img', np.array([data.ravel()], dtype=np.float32).reshape(28, 28))     #         # while cv2.waitKey() is not 27:     #         #     pass     #         # cv2.destroyWindow('img')     #     print 'Epoch %d complete' % x     return ann, test def test(ann, test_data):     # for i in range(10):     #     name = ['sample']     #     name.append(str(i))     #     sample = np.array([test_data[i][0].ravel()], dtype=np.float32).reshape(28, 28)     #     cv2.imshow(str(''.join(name)), sample)     #     while (cv2.waitKey()!=27):     #         pass     #     print ann.predict(np.array([test_data[i][0].ravel()], dtype=np.float32))     sample = np.array([test_data[4][0].ravel()], dtype=np.float32).reshape(28, 28)     print 'sample', sample     cv2.imshow('sample', sample)     while (cv2.waitKey()!=27):         pass     sample_tem = cv2.resize(sample, (8, 8), interpolation=cv2.INTER_CUBIC)     print ann.predict(np.array([sample_tem.ravel()], dtype=np.float32)) def predict(ann, sample):     resized = sample.copy()     rows, cols = resized.shape     if (rows != 28 or cols != 28) and rows * cols > 0:         resized = cv2.resize(resized, (28, 28), interpolation=cv2.INTER_CUBIC)     return ann.predict(np.array([resized.ravel()], dtype=np.float32)) def main():     save_path = os.path.join('data', 'best_ann_test_rp+ident+ddddd')     if os.path.exists(save_path):         print 'find it'         ann = cv2.ml.ANN_MLP_load(save_path)      
   a, b, test_data = wrap_data()     else:         ann, test_data = train(create_ANN(58), 50000, 10)         ann.save(save_path)     test(ann, test_data) if __name__ == '__main__':     main()
转载请注明原文地址: https://www.6miu.com/read-56306.html

最新回复(0)