注意:不同的文件有不同的读取api
tensorflow中文件读取的过程
假设现在ABCD四个文件,每个文件100个样本
读取csv的步骤:
- 构造一个文件队列,将文件的路径+名字放入队列中
- 读取队列内容
- 常见的文件格式:csv、二进制文件、图片文件。
- 默认只读取一个样本。假如是csv文件,就只读取一行(一般是一行一个样本数据);如果是二进制文件,指定一个样本的bytes读取;如果是图片文件,按一张张的读取
- 解码:只有一个样本
- 批处理:相当于也是构造一个队列,里面可以放多个样本
主线程要做的就是取样本训练
文件读取流程 构造文件队列 构造文件阅读器 构造文件内容解码器 开启线程操作,只需要在会话中开启一个就行 批处理
读取一个数据
import tensorflow as tf
def csvread(filelist):
"""
读取csv文件
:return: None
"""
file_queue = tf.train.string_input_producer(filelist)
#构造cvs的阅读器读取队列
reader = tf.TextLineReader() #读取一行数据的阅读器
key, value = reader.read(file_queue) #从文件队列中随机读取一行数据,Returns the next record (key, value) pair produced by a reader.
"""reader.read
Returns:
A tuple of Tensors (key, value).
key: A string scalar Tensor.
value: A string scalar Tensor.
"""
print(value) #value是一个reader的op
#对每行内容解码,recorde_default指定每一个样本的每一列的类型和指定默认值
records = [["None"],["None"]] #必须是一个二维的列表。这里是分别指定两列数据的默认值和类型,因为这里是读字符串,所以列表里面为None
example, lable = tf.decode_csv(value, record_defaults=records) #返回每个样本的每一列的值,用example接收第一列,label接收第二列
# print(example, lable)
return example, lable
import os
if __name__ == "__main__":
#找到文件,放入列表 路径+名字——》列表
file_name = os.listdir("./csvdata/")
filelist = [os.path.join("./csvdata/", file) for file in file_name]
# print(filelist)
example, lable = csvread(filelist)
#开启会话运行结果
with tf.Session() as sess:
"""以下步骤都是固定写法了"""
#定义一个线程协调器
coord = tf.train.Coordinator()
#开启读取文件的线程
threads = tf.train.start_queue_runners(sess, coord=coord)
#打印读取内容
print(sess.run([example, lable]))
#回收线程
coord.request_stop()
coord.join(threads=threads)
"""end"""
批处理数据
import tensorflow as tf
"""
批处理大小,跟队列,数据的数量没有影响,只决定这批次取多少数据
batch_size才决定最终取多少数据训练,如果容量小于batch_size,就多取几次,直到数据量满足batch_size,就进行一批数据的训练
"""
def csvread(filelist):
"""
读取csv文件
:return: None
"""
file_queue = tf.train.string_input_producer(filelist)
#构造cvs的阅读器读取队列
reader = tf.TextLineReader()
key, value = reader.read(file_queue)
print(value) #value是一个reader的op
#对每行内容解码,recorde_default指定每一个样本的每一列的类型和指定默认值
records = [["None"],["None"]] #必须是一个二维的列表。这里是分别指定两列数据的默认值和类型,因为这里是读字符串,所以列表里面为None
example, lable = tf.decode_csv(value, record_defaults=records) #返回每个样本的每一列的值,用example接收第一列,label接收第二列
# print(example, lable)
return example, lable
import os
if __name__ == "__main__":
#找到文件,放入列表 路径+名字——》列表
file_name = os.listdir("./csvdata/")
filelist = [os.path.join("./csvdata/", file) for file in file_name]
# print(filelist)
example, lable = csvread(filelist)
# 想要读取多个数据,就需要批处理
example_batch, lable_batch = tf.train.batch([example, lable], batch_size=9, num_threads=1, capacity=9) #要批处理的数据,每批数据的大小,线程数,队列的容量
print(example_batch, lable_batch)
#开启会话运行结果
with tf.Session() as sess:
"""以下步骤都是固定写法了"""
#定义一个线程协调器
coord = tf.train.Coordinator()
#开启读取文件的线程
threads = tf.train.start_queue_runners(sess, coord=coord)
#打印读取内容
print(sess.run([example_batch, lable_batch]))
#回收线程
coord.request_stop()
coord.join(threads=threads)
"""end"""
网友评论