python读取MNIST image数据

xiaoxiao2021-02-28  100

标签: python、MNIST、image 3670人阅读 评论(0) 收藏 举报 分类: machine learning(15) computer language(10)

Lecun Mnist数据集下载

import os
import struct

import numpy as np


def _load_idx(path, header_fmt):
    """Read an uncompressed IDX file and return ``(header, data)``.

    ``header`` is the tuple of big-endian unsigned 32-bit fields described by
    ``header_fmt`` (e.g. ``'>IIII'``).  ``data`` is a flat integer array with
    one entry per payload byte; the number of entries is the product of every
    header field after the first one.

    Raises:
        IOError/OSError: if ``path`` cannot be opened.
        struct.error: if the file is shorter than its header announces.
    """
    # The context manager closes the file even when reading fails.
    with open(path, 'rb') as binfile:
        buffers = binfile.read()

    head = struct.unpack_from(header_fmt, buffers, 0)
    offset = struct.calcsize(header_fmt)

    count = 1
    for dim in head[1:]:
        count *= dim

    if len(buffers) - offset < count:
        # Same exception type struct.unpack_from raises on a short buffer.
        raise struct.error(
            "%s: expected %d payload bytes, found %d"
            % (path, count, len(buffers) - offset))

    # Interpret the payload bytes directly instead of unpacking them into a
    # huge tuple of Python ints first; astype(int) keeps the default integer
    # dtype that np.reshape() of such a tuple would produce, and yields a
    # writable copy.
    data = np.frombuffer(buffers, dtype=np.uint8, count=count, offset=offset)
    return head, data.astype(int)


def loadImageSet(which=0, data_dir="..//dataset"):
    """Load an MNIST image file in raw IDX format.

    Args:
        which: 0 loads ``train-images-idx3-ubyte``; any other value loads
            ``t10k-images-idx3-ubyte``.
        data_dir: directory containing the uncompressed IDX files.

    Returns:
        Integer array of shape ``[imgNum, d2, d3]`` where ``imgNum``, ``d2``
        and ``d3`` are the second, third and fourth header fields.
    """
    print("load image set")
    if which == 0:
        filename = "train-images-idx3-ubyte"
    else:
        filename = "t10k-images-idx3-ubyte"

    head, pixels = _load_idx(os.path.join(data_dir, filename), '>IIII')
    print("head, %s" % (head,))

    # For the MNIST files this is [60000] * 28 * 28 values.
    imgNum, rows, cols = head[1], head[2], head[3]
    imgs = np.reshape(pixels, [imgNum, rows, cols])
    print("load imgs finished")
    return imgs


def loadLabelSet(which=0, data_dir="..//dataset"):
    """Load an MNIST label file in raw IDX format.

    Args:
        which: 0 loads ``train-labels-idx1-ubyte``; any other value loads
            ``t10k-labels-idx1-ubyte``.
        data_dir: directory containing the uncompressed IDX files.

    Returns:
        Integer array of shape ``[imgNum, 1]`` where ``imgNum`` is the second
        header field.
    """
    print("load label set")
    if which == 0:
        filename = "train-labels-idx1-ubyte"
    else:
        filename = "t10k-labels-idx1-ubyte"

    head, values = _load_idx(os.path.join(data_dir, filename), '>II')
    print("head, %s" % (head,))

    imgNum = head[1]
    labels = np.reshape(values, [imgNum, 1])
    print('load label finished')
    return labels


if __name__ == "__main__":
    imgs = loadImageSet()
    # A plotting helper can be used here to inspect imgs[0].
    loadLabelSet()

或者通过另一种方式加载 MNIST 数据(来自 Theano Tutorial 的手写数字数据集)。该数据集已经过预处理,因此相比 LeCun 的原始数据集,加载速度更快一些。 mnist.pkl.gz下载地址

import gzip
import os
import sys

try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3 has no separate cPickle module


def loadIMGData(dataset):
    """Load a gzip-compressed pickle holding three datasets.

    Args:
        dataset: path to the ``.pkl.gz`` file.

    Returns:
        ``[train_set, valid_set, test_set]`` exactly as stored in the pickle.

    NOTE: unpickling can execute arbitrary code. Only load files obtained
    from a source you trust.
    """
    # The context manager closes the file even if unpickling fails.
    with gzip.open(dataset, 'rb') as f:
        if sys.version_info[0] >= 3:
            # Pickles written by Python 2 need an explicit byte-string
            # encoding to be readable on Python 3.
            train_set, valid_set, test_set = cPickle.load(f, encoding='latin1')
        else:
            train_set, valid_set, test_set = cPickle.load(f)
    return [train_set, valid_set, test_set]


def getIMGSets(dataset="your_path/mnist.pkl.gz"):
    """Load the pickled dataset and return its six arrays as a flat list.

    Args:
        dataset: path to the ``.pkl.gz`` file; the default is a placeholder
            that must be replaced with a real location.

    Returns:
        ``[train_x, train_y, valid_x, valid_y, test_x, test_y]``.
    """
    print("load dataset")
    sets = loadIMGData(dataset)
    train_x, train_y = sets[0]
    valid_x, valid_y = sets[1]
    test_x, test_y = sets[2]
    print("train image,label shape: %s %s" % (train_x.shape, train_y.shape))
    print("valid image,label shape: %s %s" % (valid_x.shape, valid_y.shape))
    print("test image,label shape: %s %s" % (test_x.shape, test_y.shape))
    print("load dataset end")
    return [train_x, train_y, valid_x, valid_y, test_x, test_y]

以及一个方便训练时按批次读取数据的 reader

import gzip
import struct
import sys

import numpy as np

try:
    import cPickle  # Python 2
except ImportError:
    import pickle as cPickle  # Python 3 has no separate cPickle module


class MnistReader():
    """Serve mini-batches from a gzip-compressed MNIST pickle."""

    def __init__(self, mnist_path, data_dim=1, one_hot=True):
        """
        mnist_path: the path of mnist.pkl.gz
        data_dim=1: images are returned as stored in the pickle ([N,784])
        data_dim=3: every image is reshaped to [28,28,1]
        one_hot: if true, labels are one-hot vectors of length 10
                 (like: [0,1,0,0,0,0,0,0,0,0])
        """
        self.mnist_path = mnist_path
        self.data_dim = data_dim
        self.one_hot = one_hot
        self.load_minist(mnist_path)
        # list() is required on Python 3, where zip() returns a one-shot
        # iterator that supports neither len(), slicing nor shuffling.
        self.train_datalabel = list(zip(self.train_x, self.train_y))
        self.valid_datalabel = list(zip(self.valid_x, self.valid_y))
        # Index of the next training batch within the current epoch.
        self.batch_offset_train = 0

    def _build_batch(self, pairs):
        """Turn (image, label) pairs into the (imgs, labels) lists of a batch."""
        imgs = []
        labels = []
        for d, l in pairs:
            if self.data_dim == 3:
                d = np.reshape(d, [28, 28, 1])
            imgs.append(d)
            if self.one_hot:
                a = np.zeros(10)
                a[int(l)] = 1
                # Append the encoded vector, not the raw label.
                labels.append(a)
            else:
                labels.append(l)
        return imgs, labels

    def next_batch_train(self, batch_size):
        """
        Return the next training batch as (imgs, labels).

        imgs is a list of N images, each [784] or [28,28,1] depending on
        self.data_dim; labels is a list of N scalars, or N one-hot vectors
        of length 10, depending on self.one_hot.

        Batches within an epoch do not overlap. When fewer than batch_size
        unseen samples remain, the training data is reshuffled and a new
        epoch starts.

        Raises ValueError if batch_size is not between 1 and the number of
        training samples.
        """
        total = len(self.train_datalabel)
        if batch_size <= 0 or batch_size > total:
            raise ValueError(
                "batch_size must be in [1, %d], got %r" % (total, batch_size))

        if self.batch_offset_train >= total // batch_size:
            # Epoch exhausted: start over with a fresh ordering.
            self.batch_offset_train = 0
            np.random.shuffle(self.train_datalabel)

        # batch_offset_train counts batches, so scale it to a sample index.
        start = self.batch_offset_train * batch_size
        self.batch_offset_train += 1
        return self._build_batch(self.train_datalabel[start:start + batch_size])

    def next_batch_val(self, batch_size):
        """
        Return a random validation batch as (imgs, labels).

        imgs is a list of images, each [784] or [28,28,1] depending on
        self.data_dim; labels is a list of scalars, or one-hot vectors of
        length 10, depending on self.one_hot. At most batch_size samples
        are returned (fewer if the validation set is smaller).
        """
        np.random.shuffle(self.valid_datalabel)
        # Draw from the validation data that was just shuffled.
        return self._build_batch(self.valid_datalabel[0:batch_size])

    def load_minist(self, dataset):
        """
        Load the train/valid/test sets from the gzip-compressed pickle at
        `dataset` into self.train_x/y, self.valid_x/y and self.test_x/y.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        print("load dataset")
        # The context manager closes the file even if unpickling fails.
        with gzip.open(dataset, 'rb') as f:
            if sys.version_info[0] >= 3:
                # Pickles written by Python 2 need an explicit byte-string
                # encoding to be readable on Python 3.
                train_set, valid_set, test_set = cPickle.load(f, encoding='latin1')
            else:
                train_set, valid_set, test_set = cPickle.load(f)
        self.train_x, self.train_y = train_set
        self.valid_x, self.valid_y = valid_set
        self.test_x, self.test_y = test_set
        print("train image,label shape: %s %s" % (self.train_x.shape, self.train_y.shape))
        print("valid image,label shape: %s %s" % (self.valid_x.shape, self.valid_y.shape))
        print("test image,label shape: %s %s" % (self.test_x.shape, self.test_y.shape))
        print("load dataset end")


if __name__ == "__main__":
    mnist = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
    data, label = mnist.next_batch_train(batch_size=1)
    print(data)
    print(label)

第三种加载方式需要 gzip、struct 和 numpy

import gzip
import os
import struct

import numpy as np


def _read(image, label, minist_dir='your_dir/'):
    """Read one gzip-compressed image file and its matching label file.

    Args:
        image: file name of the ``.gz`` image file inside ``minist_dir``.
        label: file name of the ``.gz`` label file inside ``minist_dir``.
        minist_dir: directory holding both files; the default is a
            placeholder that must be replaced with a real location.

    Returns:
        ``(images, labels)``: ``images`` is a uint8 array of shape
        ``(len(labels), rows, cols)`` with ``rows``/``cols`` taken from the
        image header, and ``labels`` is a 1-D int8 array.

    Raises:
        ValueError: if the image payload does not hold exactly
            ``len(labels) * rows * cols`` bytes.
    """
    with gzip.open(os.path.join(minist_dir, label), 'rb') as flbl:
        # Label header: magic number and item count (both unused here).
        _magic, _num = struct.unpack(">II", flbl.read(8))
        # frombuffer yields a read-only view; copy() keeps the result
        # writable, as the array returned by np.fromstring was.
        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()

    with gzip.open(os.path.join(minist_dir, image), 'rb') as fimg:
        # Image header: magic number, item count, rows, cols.
        _magic, _num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8).copy()

    # Sizing by the label count makes a mismatched image/label pair fail here.
    images = pixels.reshape(len(labels), rows, cols)
    return images, labels


def get_data(minist_dir='your_dir/'):
    """Load the MNIST training and test sets from gzip-compressed files.

    Args:
        minist_dir: directory containing the four ``*-ubyte.gz`` files.

    Returns:
        ``[train_img, train_label, test_img, test_label]``.
    """
    train_img, train_label = _read(
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        minist_dir)
    test_img, test_label = _read(
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
        minist_dir)
    return [train_img, train_label, test_img, test_label]
转载请注明原文地址: https://www.6miu.com/read-57573.html

最新回复(0)