美文网首页
Learning Tensorflow part 2

Learning Tensorflow part 2

作者: 轻骑兵1390 | 来源:发表于2020-08-21 09:27 被阅读0次

说明一些tensorflow课程中的用法.

1. 加载数据

def normalize(images, labels):
  images = tf.cast(images, tf.float32)
  images /= 255
  return images, labels

train_dataset =  train_dataset.map(normalize)
train_dataset =  train_dataset.cache()

normalize将图片数据由[0, 255]归一化到[0, 1].

  • cast将tensorflow的数据改为另一种数据类型, 这里只是改为浮点型, 避免除操作使其全为0.
  • map将数据进行重新映射tf.data.Dataset.map, 可以参考官方文件
  • cache 要求数据迭代之后(包括map)需要进行cache, 否则下一次迭代不会使用已经cache的数据.

2. 构造神经网络

l0 = tf.keras.layers.Flatten(input_shape = (28, 28, 1))
  • Flatten为展平输入.
  • 这里是灰度图, 因此是[28, 28, 1].

3. 构造误差函数

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
  • crossentropy loss 交叉熵损失.

4. 训练

BATCH_SIZE = 32
train_dataset = train_dataset.cache().repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.cache().batch(BATCH_SIZE)

model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))

这里要打乱原有的数据

  • repeat()表示重复使用
  • shuffle为随机打乱次序
  • batch(32)表示每次迭代32个数训练.

相关文章

网友评论

      本文标题:Learning Tensorflow part 2

      本文链接:https://www.haomeiwen.com/subject/erbsjktx.html