摘要:简介读取数据共有三种方法当运行每步计算的时候,从获取数据。数据直接预加载到的中,再把传入运行。在中定义好文件读取的运算节点,把传入运行时,执行读取文件的运算,这样可以避免在和执行环境之间反复传递数据。本文讲解的代码。
简介
TensorFlow读取数据共有三种方法:
Feeding:当TensorFlow运行每步计算的时候,从Python获取数据。在Graph的设计阶段,用placeholder占住Graph的位置,完成Graph的表达;当Graph传给Session后,在运算时再把需要的数据从Python传过来。
Preloaded data:数据直接预加载到TensorFlow的Graph中,再把Graph传入Session运行。只适用于小数据。
Reading from file:在Graph中定义好文件读取的运算节点,把Graph传入Session运行时,执行读取文件的运算,这样可以避免在Python和TensorFlow C++执行环境之间反复传递数据。
本文讲解Reading from file的代码。
其他关于TensorFlow的学习笔记,请点击入门教程
实现#!/usr/bin/env python # -*- coding=utf-8 -*- # @author: 陈水平 # @date: 2017-02-19 # @description: modified program to illustrate reading from file based on TF offitial tutorial # @ref: https://www.tensorflow.org/programmers_guide/reading_data def read_my_file_format(filename_queue): """从文件名队列读取一行数据 输入: ----- filename_queue:文件名队列,举个例子,可以使用`tf.train.string_input_producer(["file0.csv", "file1.csv"])`方法创建一个包含两个CSV文件的队列 输出: ----- 一个样本:`[features, label]` """ reader = tf.SomeReader() # 创建Reader key, record_string = reader.read(filename_queue) # 读取一行记录 example, label = tf.some_decoder(record_string) # 解析该行记录 processed_example = some_processing(example) # 对特征进行预处理 return processed_example, label def input_pipeline(filenames, batch_size, num_epochs=None): """ 从一组文件中读取一个批次数据 输入: ----- filenames:文件名列表,如`["file0.csv", "file1.csv"]` batch_size:每次读取的样本数 num_epochs:每个文件的读取次数 输出: ----- 一批样本,`[[example1, label1], [example2, label2], ...]` """ filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) # 创建文件名队列 example, label = read_my_file_format(filename_queue) # 读取一个样本 # 将样本放进样本队列,每次输出一个批次样本 # - min_after_dequeue:定义输出样本后的队列最小样本数,越大随机性越强,但start up时间和内存占用越多 # - capacity:队列大小,必须比min_after_dequeue大 min_after_dequeue = 10000 capacity = min_after_dqueue + 3 * batch_size example_batch, label_batch = tf.train.shuffle_batch( [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue) return example_batch, label_batch def main(_): x, y = input_pipeline(["file0.csv", "file1.csv"], 1000, 5) train_op = some_func(x, y) init_op = tf.global_variables_initializer() local_init_op = tf.local_variables_initializer() # local variables like epoch_num, batch_size sess = tf.Session() sess.run(init_op) sess.run(local_init_op) # `QueueRunner`用于创建一系列线程,反复地执行`enqueue` op # `Coordinator`用于让这些线程一起结束 # 典型应用场景: # - 多线程准备样本数据,执行enqueue将样本放进一个队列 # - 一个训练线程从队列执行dequeu获取一批样本,执行training op # `tf.train`的许多函数会在graph中添加`QueueRunner`对象,如`tf.train.string_input_producer` # 在执行training op之前,需要保证Queue里有数据,因此需要先执行`start_queue_runners` coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: while not coord.should_stop(): sess.run(train_op) except tf.errors.OutOfRangeError: print "Done training -- epoch limit reached" finally: coord.request_stop() # Wait for threads to finish coord.join(threads) sess.close() if __name__ == "__main__": tf.app.run()
文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。
转载请注明本文地址:https://www.ucloud.cn/yun/38460.html
摘要:贡献者飞龙版本最近总是有人问我,把这些资料看完一遍要用多长时间,如果你一本书一本书看的话,的确要用很长时间。为了方便大家,我就把每本书的章节拆开,再按照知识点合并,手动整理了这个知识树。 Special Sponsors showImg(https://segmentfault.com/img/remote/1460000018907426?w=1760&h=200); 贡献者:飞龙版...
摘要:前言本文基于官网的写成。输入数据是,全称是,是一组由这个机构搜集的手写数字扫描文件和每个文件对应标签的数据集,经过一定的修改使其适合机器学习算法读取。这个数据集可以从牛的不行的教授的网站获取。 前言 本文基于TensorFlow官网的Tutorial写成。输入数据是MNIST,全称是Modified National Institute of Standards and Technol...
摘要:本文的目的是聚焦于数据操作能力,讲述中比较重要的一些,帮助大家实现各自的业务逻辑。传入输入值,指定输出的基本数据类型。 引言 用TensorFlow做好一个机器学习项目,需要具备多种代码能力: 工程开发能力:怎么读取数据、怎么设计与运行Computation Graph、怎么保存与恢复变量、怎么保存统计结果、怎么共享变量、怎么分布式部署 数据操作能力:怎么将原始数据一步步转化为模型需...
阅读 3474·2021-10-13 09:39
阅读 1457·2021-10-08 10:05
阅读 2259·2021-09-26 09:56
阅读 2274·2021-09-03 10:28
阅读 2672·2019-08-29 18:37
阅读 2032·2019-08-29 17:07
阅读 599·2019-08-29 16:23
阅读 2190·2019-08-29 11:24