版权声明:本文为博主原创文章,未经博主允许不得转载。
有很多机器学习的公开数据都需要手工编码读取,当然自己写代码读取是机器学习应用的基本能力,这里为了大家方便开发代码,避免重复发明轮子。
关于cifar数据集,点击这里,因为其下载比较慢,所以可以用csdn的下载地址下载cifar-10,cifar-10 csdn地址
下载后将其解压,如路径为: /xxx/cifar-10-batches-py/
代码很简单没有写注释,读取代码如下:
[python] view plain copy import cPickle import numpy as np import os class Cifar10DataReader(): def __init__(self,cifar_folder,onehot=True): self.cifar_folder=cifar_folder self.onehot=onehot self.data_index=1 self.read_next=True self.data_label_train=None self.data_label_test=None self.batch_index=0 def unpickle(self,f): fo = open(f, 'rb') d = cPickle.load(fo) fo.close() return d def next_train_data(self,batch_size=100): assert 10000