美文网首页
[Tensorflow2] 数据加载

[Tensorflow2] 数据加载

作者: LZhan | 来源:发表于2019-09-27 13:26 被阅读0次

针对小型常用数据集,tensorflow2中加载数据通常有两种方法:
1、使用keras.datasets


image.png

有几种数据集调用load_data()方法可以加载。

2、使用tf.data.Dataset.from_tensor_slices()方法
相应的tf.data.Dataset还有map,shuffle,range,batch,repeat等方法可供使用

但是针对大型数据集,使用Input Pipeline的方式,进行多线程加载数据。

1 keras.datasets

    (x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
    print(x.shape)
    print(y.shape)
    print(type(x))
    print(type(y))
    print(x.min(), x.max(), x.mean())
     # 获取y的前4个,y的取值为0-9,共10个值,所以one-hot的depth为10
    print(y[:4])
    y_onehot = tf.one_hot(y, depth=10)
    print(y_onehot[:4])

返回结果:
(60000, 28, 28)
(60000,)
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
0 255 33.318421449829934
[5 0 4 1]
tf.Tensor(
[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(4, 10), dtype=float32)

说明keras.datasets加载的数据是numpy格式,并不是tensor格式,因此在求最小值,最大值,平均值没有用reduce_min,reduce_max和reduce_mean等。

2 tf.data.Dataset.from_tensor_slices()

tf.data.Dataset模块,使其数据读入的操作变得更为方便,而支持多线程(进程)的操作,也在效率上获得了一定程度的提高。
from_tensor_slices方法会对tensor和numpy array的处理一视同仁,所以该方法既可以使用tensor参数,也可以直接使用numpy array作为参数

    db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    print(x_test.shape)
    # 必须先取到迭代器
    print(next(iter(db))[0].shape)

返回结果:
(10000, 28, 28)
(28, 28)

上述方法是面对小数据集的情况,面对大数据集的情况有如下方法,参考博客:
https://www.jianshu.com/p/f580f4fc2ba0

3 tf.data.Dataset的其他相关方法

<1> tf.data.Dataset.shuffle(buffer_size):对数据集进行打散,shuffle可以给定参数,代表在多大的一个范围内进行打散,该参数可以给大一些。
数据预处理功能
<2> tf.data.Dataset.map:

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y
# 使用map操作,对单值调用preprocess方法
db2=db.map(preprocess)

<3> batch方法
通常我们是对每一批元素进行操作,可以指定批的大小

db3=db2.batch(32)
res=next(iter(db3))

<4> repeat方法
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(2)就可以将之变成两个epoch。
注意,如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常。

4 总结

# 数据集加载的过程:
def prepare_mnist_features_and_labels(x, y):
    x = tf.cast(x, tf.float32) / 255.
    y = tf.cast(y, tf.int64)
    return x, y


def mnist_dataset():
    (x, y), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    ds = ds.shuffle(60000).batch(100)
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(60000).batch(100)
    return ds, ds_val

相关文章

  • [Tensorflow2] 数据加载

    针对小型常用数据集,tensorflow2中加载数据通常有两种方法:1、使用keras.datasets 有几种数...

  • TensorFlow操作进阶

    在本章脑图中主要是介绍了TensorFlow2的进阶操作,包含: 合并和分割 数据统计,聚合函数 数据限幅实现 填...

  • SparkSql之数据的加载与保存

    加载数据 创建SaparkSession 加载数据方式 * 表示加载的方式 format指定加载数据类型 spar...

  • 数据列表涉及的基本需求点

    1. 数据来源 2.数据排序规则 3. 数据加载: 1)一页展示多少条数据 2)加载规则: --进入加载(正在加载...

  • TensorFlow2 数据管道 Dataset

    如果需要训练的数据大小不大,例如不到 1G,那么可以直接全部读入内存中进行训练,这样一般效率最高。但如果需要训练的...

  • CustomWaittingView

    加载数据时显示加载状态,加载完毕恢复 只需两句代码,搞定加载数据图片 [[LWaittingFullView sh...

  • 《机器学习Python实践 》读书笔记-数据理解

    1. 导入数据 加载数据集的方式有很多种,从数据库中加载,从文件中加载 这里涉及函数:read_csv加载数据集的...

  • TensorFlow2.0的一些常用的操作

    1、数据的加载 MNIST数据集的加载: CIFAR10数据集的加载: 2、tf.data.Dataset.fro...

  • 基因结构图

    加载R包 加载数据 绘图 选择部分数据绘图

  • Mybatis延迟加载

    延迟加载概念:需要用到数据时才进行加载,不需要用到数据时就不加载数据,延迟加载也叫做懒加载。 优点:先从单表查询,...

网友评论

      本文标题:[Tensorflow2] 数据加载

      本文链接:https://www.haomeiwen.com/subject/fzmouctx.html