美文网首页
TensorFlow数据读取(batch)

TensorFlow数据读取(batch)

作者: MapleLuv | 来源:发表于2018-08-01 15:58 被阅读0次

    大神的文章中get新技能
    来不及解释先上代码(本人在Notebook运行):

    单列表

    import tensorflow as tf
    import numpy as np
    
    input_x = np.array([1,2,3,4,5,6,7,8,9])
    
    # 初始化对象
    dataset = tf.data.Dataset.from_tensor_slices(input_x)
    print("dataset:",dataset)
    
    ############# 一个一个读取 ###################
    
    # dataset = dataset.shuffle(buffer_size=1000)
    
    # 实例化一个Iterator(one_shot_iterator只能从头到尾读取一次)
    iterator = dataset.make_one_shot_iterator()
    print("iterator:",iterator)
    
    # 对Iterator进行迭代(从iterator里取出一个元素)
    # 由于这是非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
    one_element = iterator.get_next()
    print("one_element:",one_element)
    
    # run
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print("end!")
    
    ############# 组成batch读取 ###################
    batch_dataset = dataset.batch(batch_size=4)
    print("batch_dataset:",batch_dataset)
    
    # 实例化一个Iterator(one_shot_iterator只能从头到尾读取一次)
    batch_iterator = batch_dataset.make_one_shot_iterator()
    print("iterator:",batch_iterator)
    
    # 对Iterator进行迭代(从iterator里取出一个元素)
    # 由于这是非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
    batch_one_element = batch_iterator.get_next()
    print("one_element:",batch_one_element)
    
    # run
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(batch_one_element))
        except tf.errors.OutOfRangeError:
            print("end!")
    

    多列表

    import tensorflow as tf
    import numpy as np
    
    #### 本地读取图片和label
    # 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
    def _parse_function(input_x, input_y):
    #     image_string = tf.read_file(filename)
    #     image_decoded = tf.image.decode_image(image_string)
    #     image_resized = tf.image.resize_images(image_decoded, [28, 28])
        return input_x, input_y
    
    # # 图片文件的列表
    # filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
    # # label[i]就是图片filenames[i]的label
    # labels = tf.constant([0, 37, ...])
    input_x = np.array([1,2,3,4,5,6,7,8,9])
    input_y = np.array(['a','b','c','d','e','f','g','h','i'])
     
    # 此时dataset中的一个元素是(filename, label)
    dataset = tf.data.Dataset.from_tensor_slices((input_x, input_y))
    print(dataset)
    
    # map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset
    # 此时dataset中的一个元素是(image_resized, label)
    dataset = dataset.map(_parse_function)   # 这个针对_parse_function里面的操作有效,我这里因为是操作数字直接跳过
    print(dataset)
    
    # 此时dataset中的一个元素是(image_resized_batch, label_batch)
    # dataset = dataset.shuffle(buffer_size=1000).batch(4).repeat(3)
    dataset = dataset.batch(4).repeat(1)
    print(dataset)
    
    # 实例化一个Iterator(one_shot_iterator只能从头到尾读取一次)
    batch_iterator = dataset.make_one_shot_iterator()
    print("iterator:",batch_iterator)
    
    # 对Iterator进行迭代(从iterator里取出一个元素)
    # 由于这是非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
    batch_one_element = batch_iterator.get_next()
    print("one_element:",batch_one_element)
    
    # run
    with tf.Session() as sess:
        try:
            while True:
                print("result:",sess.run(batch_one_element))
        except tf.errors.OutOfRangeError:
            print("end!")
    
    • 代码可以直接运行,两个例子单独运行,随后有空再来解释

    相关文章

      网友评论

          本文标题:TensorFlow数据读取(batch)

          本文链接:https://www.haomeiwen.com/subject/rqzlvftx.html