美文网首页
TensorFlow_简单数据加载

TensorFlow_简单数据加载

作者: 刘璐_95d7 | 来源:发表于2019-11-27 10:31 被阅读0次

    建议预先导入这些模块

    # import os
    # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np 
    import sklearn
    import pandas as pd 
    import os 
    import sys
    import time
    import tensorflow as tf
    
    from tensorflow import keras
    
    print("*"*30+"下面这一行是tf_gpu测试代码"+"*"*30)
    print('GPU是否可用:', tf.test.is_gpu_available())
    print("*"*83)
    
    print(tf.__version__)
    print(sys.version_info)
    print("*"*83)
    for module in mpl, np, pd, sklearn, tf, keras:
        print(module.__name__,module.__version__)
    print("*"*83)    
    
    ******************************下面这一行是tf_gpu测试代码******************************
    GPU是否可用: True
    ***********************************************************************************
    2.0.0
    sys.version_info(major=3, minor=7, micro=4, releaselevel='final', serial=0)
    ***********************************************************************************
    matplotlib 3.1.1
    numpy 1.16.5
    pandas 0.25.1
    sklearn 0.21.3
    tensorflow 2.0.0
    tensorflow_core.keras 2.2.4-tf
    ***********************************************************************************
    

    导入Keras数据

    一般来说,获得数据集需要去网站上下载对应数据集的文件,下载以后把他放在某个目录,再写个python的一个解析格式,转化成numpy的这样一个格式。对于自带的数据集,使用下面这一行代码就可以自动的下载,管理,解析,读取,转换等工作,最终是一个numpy的格式

    (x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
    

    探索数据结构

    x.shape, y.shape, x.min(), x.max(), x.mean()
    
    ((60000, 28, 28), (60000,), 0, 255, 33.318421449829934)
    
    x_test.shape, y_test.shape
    y[:4]
    
    ((10000, 28, 28), (10000,))
    array([5, 0, 4, 1], dtype=uint8)
    

    独热编码

    y_onehot = tf.one_hot(y, depth=10)
    
    y_onehot[:4]
    

    加载cifar10、100

    (x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()
    
    x.shape, y.shape, x_test.shape, y_test.shape, x.min(), x.max(), x.mean()
    

    现在我们的数据是一个numpy的格式,我们应该转换成tensor的一个格式,再将其做一个迭代,这一部分直接用tf.data.Dataset.from_tensor_slices,因此数据加载到内存后,需要转化成dataset对象,以利用TensorFlow提供的各种便捷功能

    我们就可以通过dataset做一个迭代,首先取得一个迭代器iter,做迭代 可以看到里面的内容。
    问题:为什么不直接将x转成x_tensor,然后利用for x in x_tensor来做读取操作?

    首先这样拿到的image只能一张一张的拿,然后你还没有做一个图片的预处理,这一部分可以通过利用dataset提供的一个接口非常方便的完成,所以dataset比直接使用一个tensor,然后做一个硬的循环操作简单方便的多,而且它还支持多线程。至少来说,你一个个的读取是没有办法将他转化成一个batch的。

    db = tf.data.Dataset.from_tensor_slices(x_test)
    
    iter(db)
    

    返回的是一个迭代对象

    >> <tensorflow.python.data.ops.iterator_ops.IteratorV2 at 0x242bfa32a08>
    
    next(iter(db)).shape
    
    >>TensorShape([32, 32, 3])
    
    db = tf.data.Dataset.from_tensor_slices((x_test, y_test))  # 元组
    next(iter(db))[0].shape, next(iter(db))[1].shape, next(iter(db))[1]
    
    >>(TensorShape([32, 32, 3]),
     TensorShape([1]),
     <tf.Tensor: id=86, shape=(1,), dtype=uint8, numpy=array([3], dtype=uint8)>)
    

    shuffle功能

    打散功能:注意:dataset还有一个非常重要的shuffle功能,这个功能是非常必须的,因为我们做模型时数据量大,模型的记忆能力非常非常的强,如果数据总是按照原本的顺序出来,那其实它不需要看你这张图片就可以预测出来,直接预测序号

    所以我们做training时,一定的要做一个打散的功能

    db = db.shuffle(100000)
    

    数据预处理

    def preprocess(x, y):
        x = tf.cast(x, dtype=tf.float32)/255  # 在numpy默认是float64,deeplearning一般用64
        y = tf.cast(y, dtype=tf.int32)
        y = tf.one_hot(y, depth=10)
        return x, y
    
    db2 = db.map(preprocess) 
    
    res = next(iter(db2))
    res[0].shape, res[1].shape
    
    >>(TensorShape([32, 32, 3]), TensorShape([1, 10]))
    

    最最有用的batch功能

    db3 = db2.batch(32)
    
    res = next(iter(db3))
    
    res[0].shape, res[1].shape
    
    >>(TensorShape([32, 32, 32, 3]), TensorShape([32, 1, 10]))
    

    额外注意:TensorShape([32, 1, 10]),我们不希望看见这个中间的1,所以在导入数据集时,y.shape =(60000, 1),与增加维度一样,删除维度只能删除长度为1的维度,也不会改变张量的存储。例如考虑维度为[1,28,28,1]的x的例子,可以通过
    函数tf.squeeze(x,axis),axis参数为待删除的维度的索引号,x = tf.squeeze(x,axis=0)=>[28,28,1]

    那么这样子就可以对整个数据集进行迭代,比如说: for x, y in db,那么x的shape就是[batch,32,32,3],y就是[batch,10]

    但是如果迭代完以后这个地方会发生什么问题呢?你用for的话迭代完成就停止了,如果用while true:x,y=next(db_iter),当迭代完成就会报出异常StopIteration,那是因为已经没有元素了

    介绍repeat函数:db4=db3.repeat(),(db4=db3.repeat(2)), for x, y in db可以迭代2次,如果永远不想退出,就不要填2

    比较完整的预处理流程

    def prepare_mnist_features_and_lables(x, y):
        x = tf.cast(x, tf.float32) / 255.0
        y = tf.cast(y, tf,int64)
        return x, y
    
    def mnist_dataset():
        (x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
        y = tf.one_hot(y, depth=10)
        y_val = tf.one_hot(y_val, depth=10)
        
        ds = tf.data.Dataset.from_tensor_slices((x, y))
        ds = ds.map(prepare_mnist_features_and_lables)
        ds = ds.shuffle(60000).batch(100)
        
        ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
        ds_val = ds_val.map(prepare_mnist_features_and_lables)
        ds_val = ds_val.shuffle(60000).batch(100)
        
        return ds, ds_val
    

    相关文章

      网友评论

          本文标题:TensorFlow_简单数据加载

          本文链接:https://www.haomeiwen.com/subject/sonuwctx.html