1、认识变量
tf.Variable(name,initial_value,trainable)
a.assign()
a.assign_add()
2、构建数据
dataset
tf.data.Dataset.from_tensor_slices(train,label) 无shape,有点类似迭代器Iterator
常用函数shuffle(1000).batch(32).repeat(10)
3、建立模型
(1)API
包括sequential api 和 functional api
tf.keras.models.Sequential([..])
functional api是为了建立更复杂的模型,比如多输入多输出
model=tf.keras.Model(inputs=inputs,outputs=outputs)
下面有三个重要的过程:
第一是训练的配置过程
compile方法,定义优化器 tf.keras.optimizers、损失函数tf.keras.losses、评估指标tf.keras.metris
例如:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=[tf.keras.metrics.sparse_categorical_accuracy])
fit方法,训练数据、目标(label)数据、训练轮数、Batch_size、验证数据
例如:
model.fit(data_loader.train_data,data_loader.train_label,epochs=num_epochs,batch_size=batch_size)
evaluate 方法,评估测试数据(数据/label),用tf.keras.Model.evaluate
print(model.evaluate(data_loader.test_data,data_loader.test_label))
详见网址:https://tf.wiki/zh_hans/basic/models.html#zh-hans-custom-layer
(2)自定义
网友评论