说明一些tensorflow课程中的用法.
1. 加载数据
def normalize(images, labels):
images = tf.cast(images, tf.float32)
images /= 255
return images, labels
train_dataset = train_dataset.map(normalize)
train_dataset = train_dataset.cache()
normalize
将图片数据由[0, 255]归一化到[0, 1].
-
cast
将tensorflow的数据改为另一种数据类型, 这里只是改为浮点型, 避免除操作使其全为0. -
map
将数据进行重新映射tf.data.Dataset.map
, 可以参考官方文件 -
cache
要求数据迭代之后(包括map
)需要进行cache
, 否则下一次迭代不会使用已经cache
的数据.
2. 构造神经网络
l0 = tf.keras.layers.Flatten(input_shape = (28, 28, 1))
-
Flatten
为展平输入. - 这里是灰度图, 因此是[28, 28, 1].
3. 构造误差函数
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
- crossentropy loss 交叉熵损失.
4. 训练
BATCH_SIZE = 32
train_dataset = train_dataset.cache().repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.cache().batch(BATCH_SIZE)
model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))
这里要打乱原有的数据
-
repeat()
表示重复使用 -
shuffle
为随机打乱次序 -
batch(32)
表示每次迭代32个数训练.
网友评论