美文网首页
TensorFlow2x神经网络处理fashion mnist数

TensorFlow2x神经网络处理fashion mnist数

作者: 锅碗瓢盆油盐酱醋 | 来源:发表于2021-02-07 21:18 被阅读0次

普通的全连接神经网络

1.导入相关包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
2.加载数据集
# 加载fashion_mnist数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
(image_train, label_train), (image_test, label_test) = fashion_mnist.load_data()
3.对训练数据进行normalization(使其处于0-1之间,提高模型效率)
# 对输入数据进行normalization
image_train = image_train/255
4.构建模型
# 构造模型(分类问题的输出层activation使用softmax)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28))) #输入层,将图片展平成一维
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu)) #隐藏层(全连接层),神经元数自定义
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax)) #输出层,输出数量和分类的label数一直,激活函数使用softmax
5.编译、训练、评估、测试
# 指定模型optimizer(优化算法)和loss function(损失函数)
model.compile(tf.optimizers.Adam(), tf.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])
# 训练模型
model.fit(image_train, label_train, epochs=5)
# 评估模型
model.evaluate(image_test/255, label_test)

# 验证模型
predict_result = model.predict(np.reshape([[image_test[0]/255]], newshape=(1, 28, 28, 1)))
print(np.argmax(predict_result))

# 保存模型
model.save('./resources/nn_model')

相关文章

网友评论

      本文标题:TensorFlow2x神经网络处理fashion mnist数

      本文链接:https://www.haomeiwen.com/subject/mcbmzktx.html