在TensorFlow中使用pipeline加载数据

xiaoxiao2021-02-28  26

正文共2028个字,6张图,预计阅读时间6分钟。

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:

首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。

我们简单地生成3个样本文件。

#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签

data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')

然后,创建pipeline数据流。

#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])

#定义ExampleQueue

example_queue = tf.RandomShuffleQueue(    capacity=1000,    min_after_dequeue=0,    dtypes=[tf.int32,tf.int32],    shapes=[[4],[1]] )

#读取CSV文件,每次读一行

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

#对一行数据进行解码

record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv(    value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])

#将特征和标签push进ExampleQueue

enq_op = example_queue.enqueue([features, [col5]])

#使用QueueRunner创建两个进程加载数据到ExampleQueue

qr = tf.train.QueueRunner(example_queue, [enq_op]*2)

#使用此方法方便后面tf.train.start_queue_runner统一开始进程

tf.train.add_queue_runner(qr) xs = example_queue.dequeue()

with tf.Session() as sess:    coord = tf.train.Coordinator()

#开始所有进程    threads = tf.train.start_queue_runners(coord=coord)    

for i in range(200):        x = sess.run(xs)        print(x)    coord.request_stop()    coord.join(threads)

以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。

filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...

with tf.Session() as sess:    sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错!    coord = tf.train.Coordinator()

#开始所有进程

   threads = tf.train.start_queue_runners(coord=coord)    

try:        

while not coord.should_stop():            x = sess.run(xs)            print(x)    

except tf.errors.OutOfRangeError:        print('Done training -- epch limit reached')    

finally:        coord.request_stop()

捕获到异常时,请求结束所有进程。

原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)

原文链接:https://www.jianshu.com/p/12b52e54a63c

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看

LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

转载请注明原文地址: https://www.6miu.com/read-1900075.html

最新回复(0)