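Note: For a long time, the MNIST dataset served as the benchmark for many classification algorithms in machine learning. For someone new to deep learning, training a working convolutional neural network on it is the equivalent of printing "Hello World!" when learning to program. This article builds a convolutional neural network on Fashion-MNIST, a dataset very similar to MNIST.
0. The Fashion-MNIST dataset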
In fact, MNIST is often the first dataset researchers try. "If it doesn't work on MNIST, it won't work at all", they said. "Well, if it does work on MNIST, it may still fail on others."
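Reasons given for moving beyond MNIST:
- MNIST is too easy: convolutional neural networks reach 99.7% accuracy on it, and traditional classification algorithms easily reach 97%;
- it is overused;
- it does not represent modern computer vision tasks well.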
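| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |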
Figure 0-1: Sample images from Fashion-MNIST
For convenience, TF collects commonly used datasets in a separate Python package, which can be installed with the command below:
- More information about this package is available at: https://github.com/tensorflow/datasets
<pre>pip install -U tensorflow_datasets</pre>
1. A plain neural network
1.1 Importing the dependencies
The code below imports the required packages (including tensorflow_datasets, installed above) and prints the version of TensorFlow (TF) in use. If TF 2.0 is not installed, the following command installs the version used in this article:
<pre>pip install tensorflow==2.0.0-alpha0  # the latest TF release at the time of writing</pre>
<pre>from __future__ import absolute_import, division, print_function

# Import TensorFlow and TensorFlow Datasets
import tensorflow as tf
import tensorflow_datasets as tfds

# Helper libraries
import math
import numpy as np
import matplotlib.pyplot as plt

# Improve progress bar display
import tqdm
import tqdm.auto
tqdm.tqdm = tqdm.auto.tqdm

print(tf.__version__)  # 2.0.0-alpha0

# This will go away in the future.
# If this gives an error, you might be running TensorFlow 2 or above.
# If so, then just comment out this line and run this cell again.
# tf.enable_eager_execution()</pre>
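1.2 Loading the dataset
With everything in place, the Fashion-MNIST dataset can be loaded from tensorflow_datasets:
- the data is shuffled automatically during loading;
- as with MNIST, train_dataset contains 60,000 images used as the training set and test_dataset contains 10,000 images used as the test set.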
<pre>dataset, metadata = tfds.load('fashion_mnist', as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']</pre>
<pre>class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']</pre>
metadata can be used to inspect information about the dataset:
- the code below prints the number of examples in the training set and the test set
<pre># metadata holds meta-information about the dataset, such as its description, URL, and version
num_train_examples = metadata.splits['train'].num_examples
num_test_examples = metadata.splits['test'].num_examples
print("Number of training examples: {}".format(num_train_examples))
print("Number of test examples: {}".format(num_test_examples))</pre>
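1.3 Preprocessing the data
In the raw data, every pixel is an integer in the range [0, 255]. To make training easier, all values are rescaled to the range [0, 1].
- In the author's tests, skipping this step reduced the final test-set accuracy by about 8%.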
<pre>def normalize(images, labels):
    images = tf.cast(images, tf.float32)  # Casts a tensor to a new type
    images /= 255
    return images, labels

# The map function applies the normalize function to each element in the train
# and test datasets
train_dataset = train_dataset.map(normalize)
test_dataset = test_dataset.map(normalize)</pre>
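The rescaling done by normalize can be illustrated without TensorFlow. The sketch below is a plain-Python stand-in (the helper name scale_pixels is hypothetical and not part of the tutorial) that maps 8-bit pixel values to floats in [0, 1]:

```python
def scale_pixels(pixels):
    """Map integer pixel values in [0, 255] to floats in [0.0, 1.0]."""
    return [p / 255 for p in pixels]


row = [0, 51, 255]           # three sample pixel intensities (made-up values)
scaled = scale_pixels(row)
print(scaled)
```

Black (0) stays 0.0, white (255) becomes 1.0, and everything in between is scaled linearly.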
<pre># Take a single image, and remove the color dimension by reshaping
for image, label in test_dataset.take(1):
    print(image.shape, label.shape)
    image = image.numpy().reshape((28,28))
# Plot the image - voila, a piece of fashion clothing
plt.imshow(image, cmap=plt.cm.binary)</pre>
<pre>plt.figure(figsize=(10,10))
i = 0
for (image, label) in train_dataset.take(25):
    image = image.numpy().reshape((28,28))
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(class_names[label])
    i += 1
plt.show()</pre>
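1.4 Building the model
1.4.1 Defining the network
Defining the network means choosing the following hyperparameters:
- the total number of layers in the network;
- the type of each layer, for example Flatten or Dense;
- the number of units (neurons) in each layer;
- the activation function of each layer, for example ReLU or Softmax; leaving this parameter unset means that no non-linear transformation is applied to that layer's output.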
<pre>model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])</pre>
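- The first layer is a Flatten layer (l0 in the accompanying figure). Each input sample is a 28×28 matrix (every element holds the value of the corresponding pixel in the image), and the output is a vector of length 784;
- The second layer is a Dense layer (l1 in the figure). Its input is the output of the previous layer, that is, the vector of length 784. It has 128 units with ReLU as the activation function and outputs a vector of length 128;
- The third layer is a Dense layer (l2 in the figure). Its input is the output of the previous layer. It has 10 units with Softmax as the activation function and outputs a vector of length 10. It is also the output layer of the network.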
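The layer sizes above fix the number of trainable parameters. As a rough check, the following plain-Python sketch counts them by hand, assuming each Dense unit has one weight per input plus one bias (the helper name dense_params is hypothetical):

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs per unit) plus one bias per unit."""
    return n_inputs * n_units + n_units


flat_size = 28 * 28 * 1                       # Flatten: 784 values, no parameters
hidden_params = dense_params(flat_size, 128)  # first Dense layer
output_params = dense_params(128, 10)         # output Dense layer
total_params = hidden_params + output_params
print(flat_size, hidden_params, output_params, total_params)
# 784 100480 1290 101770
```

Almost all of the parameters sit in the first Dense layer, because every one of its 128 units is connected to all 784 pixels.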
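1.4.2 Compiling
Compiling the model requires three settings:
- Loss function: measures how far the model's predictions are from the true labels;
- Optimizer: updates the parameters based on the error and its gradients in order to minimize the error;
- Metrics: also used to judge how good the model is.
Loss functions and metrics compare as follows:
- both measure the quality of the model and are highly correlated;
- the loss function must be differentiable; it is a function of the trainable parameters, and training the model is the process of optimizing it;
- a metric does not have to be differentiable and is easier to interpret, for example the accuracy in a classification problem.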
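The difference between a loss and a metric can be seen on two predictions that are both correct. The sketch below is plain Python with made-up probabilities. It assumes integer class labels, as in this dataset, and uses the negative log-probability of the true class (the quantity behind sparse categorical cross-entropy) as the loss. The helper names are hypothetical.

```python
import math


def sparse_cross_entropy(probs, label):
    """Negative log-probability assigned to the true class."""
    return -math.log(probs[label])


def is_correct(probs, label):
    """1 if the most probable class equals the true label, else 0."""
    return int(probs.index(max(probs)) == label)


true_label = 2
confident = [0.05, 0.05, 0.90]   # made-up model outputs
hesitant = [0.30, 0.30, 0.40]

for probs in (confident, hesitant):
    print(round(sparse_cross_entropy(probs, true_label), 4),
          is_correct(probs, true_label))
```

Both predictions count as correct for accuracy, yet the hesitant one has a much larger loss. The loss changes smoothly with the probabilities, which is what gives the optimizer a gradient to follow.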
<pre>model.compile(optimizer='adam',</pre>
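1.5 Training the model
Training requires two more settings:
- Batch size: the number of samples used in a single training step (set to 32 below, so each step uses only 32 samples from the training set, and going through the whole training set takes 60000/32 = 1875 steps);
- Number of epochs: the number of passes over the entire training set. With 5 epochs and a batch size of 32, the parameters are updated 5*1875 = 9375 times in total (that is, every image in the training set is used 5 times).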
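The step counts quoted above follow directly from the dataset sizes. A minimal sketch of the arithmetic in plain Python, using the sizes and batch size stated in this article:

```python
import math

num_train_examples = 60000
num_test_examples = 10000
batch_size = 32
epochs = 5

steps_per_epoch = math.ceil(num_train_examples / batch_size)   # batches per pass
total_updates = epochs * steps_per_epoch                        # parameter updates overall
eval_steps = math.ceil(num_test_examples / batch_size)         # batches in the test set
print(steps_per_epoch, total_updates, eval_steps)  # 1875 9375 313
```

math.ceil matters for the test set: 10000/32 is 312.5, so the last batch is only half full and a 313th step is needed to cover it.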
<pre>BATCH_SIZE = 32
train_dataset = train_dataset.repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))</pre>
<pre>Epoch 1/5
1875/1875 [==============================] - 24s 13ms/step - loss: 0.2735 - accuracy: 0.8981
Epoch 2/5
1875/1875 [==============================] - 14s 8ms/step - loss: 0.2719 - accuracy: 0.8995
Epoch 3/5
1875/1875 [==============================] - 14s 8ms/step - loss: 0.2613 - accuracy: 0.9018
Epoch 4/5
1875/1875 [==============================] - 13s 7ms/step - loss: 0.2457 - accuracy: 0.9087
Epoch 5/5
1875/1875 [==============================] - 13s 7ms/step - loss: 0.2407 - accuracy: 0.9091
<tensorflow.python.keras.callbacks.History at 0x7fe5305bca58></pre>
1.6 Final evaluation of the model
<pre>test_loss, test_accuracy = model.evaluate(test_dataset, steps=math.ceil(num_test_examples/BATCH_SIZE))
print('Accuracy on test dataset:', test_accuracy)</pre>
<pre>313/313 [==============================] - 2s 6ms/step - loss: 0.3582 - accuracy: 0.8772
Accuracy on test dataset: 0.8772</pre>
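1.7 Making predictions and visualizing the results
The code below takes one batch (32 samples) from the test set, runs the model on it, and stores the true labels in test_labels. For the first sample, the predicted class and the true class are both 6.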
<pre>for test_images, test_labels in test_dataset.take(1):
    test_images = test_images.numpy()
    test_labels = test_labels.numpy()
    predictions = model.predict(test_images)
np.argmax(predictions[0]), test_labels[0]  # (6, 6)</pre>
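Each row of predictions is the Softmax output for one image, and the predicted class is the index of its largest entry. The sketch below reproduces that last step in plain Python. The raw scores are made up for illustration, and the softmax helper is a hand-written stand-in rather than the TensorFlow implementation.

```python
import math

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]   # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores for one image; index 6 holds the largest value.
scores = [0.1, -1.2, 0.8, 0.0, 1.1, -2.0, 3.0, -1.5, 0.3, -0.7]
probs = softmax(scores)
predicted = probs.index(max(probs))              # plays the role of np.argmax
print(predicted, class_names[predicted])         # 6 Shirt
```

Because Softmax preserves the ordering of the scores, the class with the largest raw score is also the class with the largest probability.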
<pre>def plot_image(i, predictions_array, true_labels, images):
    predictions_array, true_label, img = predictions_array[i], true_labels[i], images[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img[...,0], cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100*np.max(predictions_array),
                                         class_names[true_label]),
               color=color)

def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

# Plot the first X test images, their predicted label, and the true label
# Color correct predictions in blue, incorrect predictions in red
num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, predictions, test_labels)</pre>
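2. Convolutional neural network
The network above, built only from fully connected layers and activation functions, already classifies well: its accuracy on the test set is 88%. Turning it into a convolutional neural network only requires changing the network structure (section 1.4.1):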
<pre>model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), padding='same', activation=tf.nn.relu,
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2), strides=2),
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation=tf.nn.relu),
    tf.keras.layers.MaxPooling2D((2, 2), strides=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])</pre>
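To see how the data changes shape on its way through this network, the following plain-Python sketch traces the output shape of every layer and counts the trainable parameters. It assumes stride-1 convolutions with 'same' padding and non-overlapping 2×2 pooling, as configured above. The helper names are hypothetical.

```python
def conv_same(shape, filters):
    """'same' padding with stride 1 keeps height and width; depth becomes `filters`."""
    h, w, _ = shape
    return (h, w, filters)


def max_pool(shape, pool=2, stride=2):
    """Unpadded pooling: floor((size - pool) / stride) + 1 along each spatial axis."""
    h, w, c = shape
    return ((h - pool) // stride + 1, (w - pool) // stride + 1, c)


def conv_params(kernel, in_channels, filters):
    """kernel*kernel*in_channels weights plus one bias, for each filter."""
    return (kernel * kernel * in_channels + 1) * filters


shape = (28, 28, 1)
shape = conv_same(shape, 32)             # (28, 28, 32)
shape = max_pool(shape)                  # (14, 14, 32)
shape = conv_same(shape, 64)             # (14, 14, 64)
shape = max_pool(shape)                  # (7, 7, 64)
flat = shape[0] * shape[1] * shape[2]    # 3136 values after Flatten

params = (conv_params(3, 1, 32) + conv_params(3, 32, 64)
          + (flat * 128 + 128) + (128 * 10 + 10))
print(shape, flat, params)               # (7, 7, 64) 3136 421642
```

The two convolutional layers together hold fewer than 19,000 parameters. Most of the 421,642 parameters still belong to the first Dense layer.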
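2.1 Convolutional layers
Conv2D is a 2D convolution layer. Its main parameters are:
- filters: the number n of filters (also called kernels). Each filter is convolved over the entire output of the previous layer and produces one activation map, so n filters produce n activation maps. The first convolutional layer above has n=32 filters, so its output has shape (28, 28, 32);
- kernel_size: the size of each filter. The images used here are grayscale with a single channel (color images have 3 channels), so the kernels of the first layer have depth 1 and only their height and width need to be set. Both convolutional layers above use (3, 3) filters;
- padding: how the borders are handled. Without padding, information at the edges of the image is lost after filtering. Here the parameter is always set to "same", which pads the input with zeros so that the filtered output keeps the input's height and width;
- activation: as in other layers, applies a non-linear transformation to the values of the units.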
Figure 2-1: Filtering in a convolutional layer
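The filtering step can be sketched in a few lines of plain Python. The example below is a simplified, single-channel illustration of "same" padding with stride 1. It leaves out the bias and the activation function, and the 3×3 filter of ones is a hypothetical choice that simply sums each neighbourhood.

```python
def conv2d_same(image, kernel):
    """Slide `kernel` over a zero-padded `image` (stride 1) so the output keeps the input size."""
    h, w = len(image), len(image[0])
    k = len(kernel)                  # square kernel with odd side length
    pad = k // 2
    # Surround the image with `pad` rows and columns of zeros.
    padded = [[0] * (w + 2 * pad) for _ in range(h + 2 * pad)]
    for r in range(h):
        for c in range(w):
            padded[r + pad][c + pad] = image[r][c]
    out = [[0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            out[r][c] = sum(kernel[i][j] * padded[r + i][c + j]
                            for i in range(k) for j in range(k))
    return out


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
box = [[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]]                    # hypothetical filter: sums each 3x3 neighbourhood
result = conv2d_same(image, box)
print(result)  # [[12, 21, 16], [27, 45, 33], [24, 39, 28]]
```

The output is still 3×3. The corner values are smaller than the centre because part of each corner window falls on the zero padding.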
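This article focuses on how to implement neural networks with TF 2.0. For more background on convolutional layers, see the Convolutional Neural Networks part of the following syllabus:
http://cs231n.stanford.edu/syllabus.html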
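2.2 Max pooling layers
MaxPooling2D is a 2D max pooling layer. It downsamples its input, which shrinks the image and makes training cheaper. Max pooling is usually used together with convolution. Its main parameters are:
- pool_size: the size of the pooling window. Both max pooling layers above use a (2, 2) window;
- strides: the stride, that is, the distance the window moves at each step. Both layers above use 2, so each window starts two pixels after the previous one.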
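Figure 2-2: Max pooling with a (2, 2) window and a stride of 2
Max pooling keeps only the largest value in each window. As the figure shows, with a (2, 2) window and a stride of 2, the (4, 4) image on the left contains only 4 windows, and taking the maximum of each window gives the result on the right.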
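The same operation can be written as a short plain-Python sketch. The 4×4 input below is made up for illustration and is not the data shown in Figure 2-2. The helper name max_pool2d is hypothetical.

```python
def max_pool2d(image, pool=2, stride=2):
    """Keep the largest value in each pool x pool window, moving `stride` pixels at a time."""
    h, w = len(image), len(image[0])
    return [[max(image[r + i][c + j] for i in range(pool) for j in range(pool))
             for c in range(0, w - pool + 1, stride)]
            for r in range(0, h - pool + 1, stride)]


image = [[1, 0, 2, 3],
         [4, 6, 6, 8],
         [3, 1, 1, 0],
         [1, 2, 2, 4]]       # hypothetical 4x4 input
pooled = max_pool2d(image)
print(pooled)  # [[6, 8], [3, 4]]
```

With a stride equal to the window size the windows do not overlap, so each spatial dimension is halved.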
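2.3 Position invariance of CNNs
The convolutional neural network built from convolutional layers and max pooling layers raises the classification accuracy on the Fashion-MNIST test set to 92%.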
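3. Summary
- Prepare the dataset: identify the features, labels, and total number of samples; split the data into a training set and a test set (sometimes also a validation set); preprocess the data (for example, normalization);
- Define the network structure: in Keras and TF 2.0 the layer is the basic building block of a network, and every type of network can be assembled from basic layer types. The hyperparameters to decide here include the number of layers and each layer's type, activation function, and number of units;
- Compile the model: compiling the network requires three settings, the loss function, the optimizer, and the metrics;
- Train the model: requires the batch size and the number of epochs;
- Evaluate the model: measure its performance on the test set.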
Common choices of loss function:
- binary classification: binary crossentropy
- multi-class classification: categorical crossentropy
- regression: mean-squared error
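For reference, the three losses listed above can be written out in plain Python. This is a minimal per-sample sketch with made-up inputs, not the Keras implementation, and the helper names are hypothetical.

```python
import math


def binary_crossentropy(y_true, p):
    """y_true is 0 or 1; p is the predicted probability of class 1."""
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))


def categorical_crossentropy(one_hot, probs):
    """one_hot marks the true class; probs is the predicted distribution."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def mean_squared_error(y_true, y_pred):
    """Average squared difference between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


bce = binary_crossentropy(1, 0.8)
cce = categorical_crossentropy([0, 0, 1], [0.1, 0.2, 0.7])
mse = mean_squared_error([1.0, 2.0, 3.0], [1.5, 2.0, 2.0])
print(round(bce, 4), round(cce, 4), round(mse, 4))
```

Categorical crossentropy in this form expects one-hot labels. When the labels are plain integers, as in the dataset used in this article, Keras also offers a sparse variant (sparse_categorical_crossentropy) that computes the same quantity.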
Deep Learning with Python, by François Chollet, 2017.11
Source: Belter (Weibo: 昕-2008)