最近在训练网络中会用到非图像类型的数据,我这里是将这种数据转换成LMDB类型作为一个数据层,加载进网络。主要用到caffe的Python接口。 1、在网络的中间层中,其接受一个1x6维的bottom数据作为输入; 2、每个训练样本对应的1x6维的数据存储到data.txt,同时记录其类别标签; 3、写入LMDB 。
#-*- coding: UTF-8 -*- import numpy as np import caffe import lmdb from caffe.proto import caffe_pb2 import sys,os # 读入数据和对应的类别标签 theta_file=open('./data.txt','r') label=open('./label.txt','r') theta_list=[] theta_label=[] for line in theta_file: content=line.strip().split(',') theta=[] for i in range(len(content)): theta.append(float(content[i])) theta_list.append(theta) del content,theta theta_file.close() for line in label: content=line.strip().split('\n') theta_label.append(int(content[0])) # 写入lmdb,需要将list转换为array db = lmdb.open('data_lmdb', map_size=int(1e12)) with db.begin(write=True) as in_txn: for i in range(len(theta_list)): datum = caffe.proto.caffe_pb2.Datum() datum.channels = 1 datum.height = 1 datum.width = 6 tmp_=theta_list[i] tmp=np.array(range(6), dtype=np.float) for j in range(6): tmp[j]=tmp_[j] label=int(theta_label[i]) datum.data = tmp.tobytes() # datum.data = tmp.tostring() datum.label=label in_txn.put('{:0>10d}'.format(i), datum.SerializeToString()) db.close()