# 作者很懒，还是先只贴个代码：OpenCV 3 + Python，自己体会。
# -*- coding: utf-8 -*-
import os
import cv2
from cv2.cv2 import *
import codecs
from cv2.ml import VAR_ORDERED
from canny import *
from find_contours import *
import numpy as np
import cPickle
import gzip
# decoding:utf-8
def revel(a):
    """Return a list holding the first element of every item in *a*.

    Args:
        a: any iterable whose items support ``item[0]``.

    Returns:
        A new list ``[item[0] for item in a]``; empty when *a* is empty.
    """
    # Comprehension instead of a manual append loop; also avoids shadowing
    # the ``list`` builtin as the previous local variable did.
    return [item[0] for item in a]
def load_data(path=os.path.join('data', 'mnist.pkl.gz')):
    """Load the three pickled MNIST splits from a gzip-compressed file.

    Args:
        path: location of the gzipped pickle. Defaults to
            ``data/mnist.pkl.gz`` relative to the current working directory,
            which is the path previously hard-coded here.

    Returns:
        ``(training_data, validation_data, test_data)`` exactly as stored in
        the pickle; wrap_data() indexes each split as ``[inputs, labels]``.
    """
    # SECURITY NOTE: unpickling can execute arbitrary code. Only point this
    # at a trusted local copy of the dataset.
    # ``with`` guarantees the gzip handle is closed even if loading fails.
    with gzip.open(path, 'rb') as mnist:
        training_data, validation_data, test_data = cPickle.load(mnist)
    return training_data, validation_data, test_data
def wrap_data():
    """Turn the raw MNIST splits into (column-vector input, target) pairs.

    Returns:
        ``(training_data, validation_data, test_data)``, each a list of pairs
        whose first element is the image reshaped to a (784, 1) array.
        Training targets are (10, 1) one-hot arrays from vectorized_result();
        validation and test targets are the raw labels, untouched.
    """
    train_split, valid_split, test_split = load_data()

    def as_column(image):
        # Reshape one image into a (784, 1) column.
        return np.reshape(image, (784, 1))

    training_data = [(as_column(img), vectorized_result(label))
                     for img, label in zip(train_split[0], train_split[1])]
    validation_data = [(as_column(img), label)
                       for img, label in zip(valid_split[0], valid_split[1])]
    test_data = [(as_column(img), label)
                 for img, label in zip(test_split[0], test_split[1])]
    return training_data, validation_data, test_data
def vectorized_result(j):
    """Return a (10, 1) float array that is 1.0 at row *j* and 0.0 elsewhere."""
    one_hot = np.zeros((10, 1))
    one_hot[j, 0] = 1.0
    return one_hot
def create_ANN(hidden=20):
    """Build an untrained OpenCV MLP with layer sizes [64, hidden, 10].

    The network is configured for RPROP training with the identity
    activation, and term criteria of at most 20 iterations or epsilon 1
    (whichever triggers first).

    Args:
        hidden: number of units in the single hidden layer.

    Returns:
        The configured ``cv2.ml.ANN_MLP`` instance.
    """
    # NOTE(review): with an identity activation every layer is linear, so the
    # hidden layer adds no non-linearity. Confirm this choice is intentional.
    layer_sizes = np.array([64, hidden, 10])
    stop_criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 20, 1)

    network = cv2.ml.ANN_MLP_create()
    network.setLayerSizes(layer_sizes)
    network.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    network.setActivationFunction(cv2.ml.ANN_MLP_IDENTITY)
    network.setTermCriteria(stop_criteria)
    return network
def train(ann, samples=10000, epochs=1,
          train_file='train_data.txt', result_file='train_result.txt'):
    """Train *ann* on feature/response matrices stored in text files.

    Args:
        ann: an OpenCV ``ANN_MLP``, e.g. from create_ANN().
        samples: currently unused; kept so existing callers keep working.
        epochs: number of training passes. The first pass initializes the
            weights, and every later pass continues from the current weights
            instead of starting over.
        train_file: text file loaded with ``np.loadtxt`` as float32, with one
            sample per row.
        result_file: text file loaded with ``np.loadtxt`` as int32, with one
            response row per sample.

    Returns:
        ``(ann, test_data)`` where ``test_data`` is the test split produced
        by wrap_data(), which test() uses for a visual spot check.
    """
    # wrap_data() loads the full MNIST pickle, but only its test split is
    # needed here.
    test_data = wrap_data()[2]
    features = np.loadtxt(train_file, np.float32)
    responses = np.loadtxt(result_file, np.int32)
    # The same TrainData that ann.train(samples, layout, responses) builds
    # internally; building it explicitly lets us pass training flags.
    train_set = cv2.ml.TrainData_create(features, cv2.ml.ROW_SAMPLE, responses)
    for epoch in xrange(epochs):
        # Without UPDATE_WEIGHTS each train() call re-initializes the weights,
        # which would silently throw away every earlier epoch.
        flags = 0 if epoch == 0 else cv2.ml.ANN_MLP_UPDATE_WEIGHTS
        ann.train(train_set, flags)
    return ann, test_data
def test(ann, test_data):
# for i in range(10):
# name = ['sample']
# name.append(str(i))
# sample = np.array([test_data[i][0].ravel()], dtype=np.float32).reshape(28, 28)
# cv2.imshow(str(''.join(name)), sample)
# while (cv2.waitKey()!=27):
# pass
# print ann.predict(np.array([test_data[i][0].ravel()], dtype=np.float32))
sample = np.array([test_data[4][0].ravel()], dtype=np.float32).reshape(28, 28)
print 'sample', sample
cv2.imshow('sample', sample)
while (cv2.waitKey()!=27):
pass
sample_tem = cv2.resize(sample, (8, 8), interpolation=cv2.INTER_CUBIC)
print ann.predict(np.array([sample_tem.ravel()], dtype=np.float32))
def predict(ann, sample):
    """Classify a single 2-D image with *ann*.

    The image is resized to the square input size the network expects. That
    size comes from the first entry of ``ann.getLayerSizes()``: 64 gives
    8x8, which is the layout create_ANN() builds, and 784 gives 28x28. A
    sample that already has that shape, or an empty sample, is used as-is.

    Args:
        ann: an OpenCV ``ANN_MLP`` with its layer sizes set.
        sample: a 2-D array (rows, cols).

    Returns:
        Whatever ``ann.predict()`` returns for the flattened float32 row.

    Raises:
        ValueError: if the network has no layers, or its input layer size is
            not a perfect square.
    """
    sizes = np.asarray(ann.getLayerSizes()).ravel()
    if sizes.size == 0:
        raise ValueError('ann has no layer sizes configured')
    n_inputs = int(sizes[0])
    side = int(round(n_inputs ** 0.5))
    if side * side != n_inputs:
        raise ValueError('input layer size %d is not a square image' % n_inputs)

    rows, cols = sample.shape
    image = sample
    if (rows, cols) != (side, side) and rows * cols > 0:
        image = cv2.resize(sample, (side, side), interpolation=cv2.INTER_CUBIC)
    # np.array(...) already copies, so the input is never modified.
    return ann.predict(np.array([image.ravel()], dtype=np.float32))
def main():
    """Load a saved network if one exists, otherwise train and save one, then run test().

    A freshly trained network uses 58 hidden units, ``samples=50000`` and
    10 epochs. It is saved to ``data/best_ann_test_rp+ident+ddddd``.
    """
    model_path = os.path.join('data', 'best_ann_test_rp+ident+ddddd')
    if not os.path.exists(model_path):
        ann, test_data = train(create_ANN(58), 50000, 10)
        ann.save(model_path)
    else:
        print('find it')
        ann = cv2.ml.ANN_MLP_load(model_path)
        # Only the test split is needed for the spot check below.
        test_data = wrap_data()[2]
    test(ann, test_data)
# Run the train/load-and-test demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()