Recognizing Handwritten MNIST Digits with a Convolutional Neural Network

Author: small瓜瓜 | Published 2020-04-09 06:57
import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers


def preprocess(pre_x, pre_y):
    # Scale pixel values from [0, 255] to [0, 1] and one-hot encode the labels.
    pre_x = tf.cast(pre_x, dtype=tf.float32) / 255.
    pre_y = tf.cast(pre_y, dtype=tf.int32)
    pre_y = tf.one_hot(pre_y, depth=10)
    return pre_x, pre_y


(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train, y_train = preprocess(x_train, y_train)
x_test, y_test = preprocess(x_test, y_test)
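The preprocessing above can be sanity-checked without TensorFlow. A minimal NumPy sketch of the same two steps (the label values below are hypothetical stand-ins for `y_train`; `tf.one_hot(labels, depth=10)` behaves like indexing a 10x10 identity matrix):

```python
import numpy as np

# Hypothetical labels standing in for y_train.
labels = np.array([3, 0, 9])

# Equivalent of tf.one_hot(labels, depth=10): pick rows of an identity matrix.
one_hot = np.eye(10, dtype=np.float32)[labels]

# Equivalent of the pixel scaling: uint8 [0, 255] -> float32 [0, 1].
pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = pixels.astype(np.float32) / 255.0
```

Each one-hot row sums to 1, with the 1 sitting at the label's index.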

# Inspect the shapes and dtypes of the MNIST tensors
print(x_train.shape, x_train.dtype, y_train.shape, y_train.dtype)
print(x_test.shape, x_test.dtype, y_test.shape, y_test.dtype)

network = Sequential([
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(10, activation=tf.nn.softmax)  # softmax: a probability distribution over the 10 digit classes, matching the categorical cross-entropy loss
])
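With `padding='valid'` (the Keras `Conv2D` default) each convolution shrinks the feature map as `(in - kernel) // stride + 1`. A plain-Python sketch of that arithmetic, using the filter/kernel/stride triples from the model above, reproduces the sizes and parameter counts that `network.summary()` reports below:

```python
# Trace (spatial size, channels) through the three Conv2D layers
# (padding='valid', so out = (in - kernel) // stride + 1).
def conv_out(size, kernel, stride):
    return (size - kernel) // stride + 1

size, ch = 28, 1
params = 0
for filters, kernel, stride in [(3, 4, 2), (12, 3, 1), (28, 5, 2)]:
    params += (kernel * kernel * ch + 1) * filters  # conv weights + biases
    params += 4 * filters                           # BatchNorm: gamma, beta, moving mean/var
    size, ch = conv_out(size, kernel, stride), filters

flat = size * size * ch      # Flatten: 4 * 4 * 28 = 448
params += (flat + 1) * 10    # Dense(10): weights + biases
print(size, flat, params)    # → 4 448 13477
```

The 13,477 total matches the summary; the 86 non-trainable parameters are the BatchNormalization moving means and variances (2 per channel: 6 + 24 + 56).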


# Hyperparameters
epochs = 20
batch_size = 128
learning_rate = 2e-2

network.build((None, 28, 28, 1))
network.summary()

d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

network.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=d_optimizer,
                metrics=['accuracy'])
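Because the labels were one-hot encoded in `preprocess`, `categorical_crossentropy` is the appropriate loss (with integer labels, `sparse_categorical_crossentropy` would be used instead). For a single sample it is just `-sum(y_true * log(y_pred))`; a NumPy sketch with a hypothetical softmax output:

```python
import numpy as np

# One-hot label for the digit "2".
y_true = np.zeros(10); y_true[2] = 1.0

# Hypothetical softmax output: mass 0.82 on class 2, rest spread evenly.
y_pred = np.full(10, 0.02); y_pred[2] = 0.82

# Categorical cross-entropy for one sample: -sum(y_true * log(y_pred)).
loss = -np.sum(y_true * np.log(y_pred))
```

Only the predicted probability of the true class enters the loss, so here `loss == -log(0.82)`.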

network.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
            validation_data=(x_test, y_test))

network.evaluate(x_test, y_test)

Output:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape (Reshape)            multiple                  0         
_________________________________________________________________
conv2d (Conv2D)              multiple                  51        
_________________________________________________________________
batch_normalization (BatchNo multiple                  12        
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  336       
_________________________________________________________________
batch_normalization_1 (Batch multiple                  48        
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  8428      
_________________________________________________________________
batch_normalization_2 (Batch multiple                  112       
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  4490      
=================================================================
Total params: 13,477
Trainable params: 13,391
Non-trainable params: 86
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
... (per-batch progress output trimmed) ...
60000/60000 [==============================] - 6s 105us/sample - loss: 0.0282 - accuracy: 0.9919 - val_loss: 0.0824 - val_accuracy: 0.9831
10000/10000 [==============================] - 2s 191us/sample - loss: 0.0824 - accuracy: 0.9831

Training with a convolutional neural network gives very high recognition accuracy — about 98.3% on the test set.

Original post: https://www.haomeiwen.com/subject/eizfmhtx.html