TensorFlow 2.0: Face Recognition with a Dataset Built into sklearn
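First, download the dataset. It is fetched from a server abroad, so the download may fail with an SSL certificate error.
In that case, the following workaround can be used (it disables HTTPS certificate verification):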
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
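The assignment above turns off certificate verification for every urllib-based HTTPS request made by the process from that point on. A narrower variant switches it off only while the download runs. The context manager below is a minimal sketch of that idea. `unverified_https` is a hypothetical helper name, and it relies on the same private `ssl._create_default_https_context` hook as the original workaround, so it may stop working in a future Python version. sklearn caches the data under ~/scikit_learn_data after the first successful download, so the workaround is only needed once.

```python
import contextlib
import ssl


@contextlib.contextmanager
def unverified_https():
    """Temporarily disable HTTPS certificate verification (hypothetical helper)."""
    original = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        # Restore normal verification even if the download raises
        ssl._create_default_https_context = original


# Usage (requires network access and sklearn):
# from sklearn import datasets
# with unverified_https():
#     faces = datasets.fetch_olivetti_faces()
```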
from sklearn import datasets
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
Fetch the data (the Olivetti faces: 400 grayscale 64×64 images of 40 people, 10 images each):
faces = datasets.fetch_olivetti_faces()
Check the fetched data by plotting all 400 faces, each labeled with its subject id:
plt.figure(figsize=(20, 25))
for index, img in enumerate(faces.images):
    plt.subplot(20, 20, index + 1)
    plt.imshow(img, cmap='gray')
    # hide the x-axis ticks
    plt.xticks([])
    # hide the y-axis ticks
    plt.yticks([])
    plt.xlabel(faces.target[index])
plt.show()
Get the features and labels:
X = faces.images
y = faces.target
# add a channel axis: (samples, height, width, channels)
X = X.reshape(400, 64, 64, 1)
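Conv2D layers expect 4-D input of shape (samples, height, width, channels) in Keras's default channels_last layout, so the grayscale images need an explicit channel axis of length 1. The sketch below uses a random stand-in array with the Olivetti shapes (an assumption so it runs offline, not real faces) to show that the hard-coded reshape, a reshape with -1, np.newaxis, and reshaping the flattened faces.data all produce the same array.

```python
import numpy as np

# Random stand-in with the same shapes as the Olivetti arrays (not real faces)
rng = np.random.default_rng(0)
images = rng.random((400, 64, 64), dtype=np.float32)  # like faces.images
data = images.reshape(400, -1)                        # like faces.data: 4096 pixels per row

X_fixed = images.reshape(400, 64, 64, 1)  # the reshape used above
X_infer = images.reshape(-1, 64, 64, 1)   # -1 lets NumPy infer the sample count
X_axis = images[..., np.newaxis]          # append a length-1 channel axis
X_flat = data.reshape(-1, 64, 64, 1)      # rebuild from the flattened rows

print(X_fixed.shape)  # (400, 64, 64, 1)
```

Using -1 or np.newaxis avoids hard-coding the sample count, which matters if the dataset size ever changes.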
Split into training and test sets (80/20, i.e. 320 training and 80 test images):
train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.2)
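Each of the 40 subjects has exactly 10 images, so a purely random 80/20 split can leave some subjects with fewer than 2 test images and others with more. Passing stratify=y to train_test_split preserves the class proportions, so every subject gets exactly 2 test and 8 training images. The sketch below uses placeholder inputs (just indices) instead of real images, an assumption to keep it light; random_state=0 is only there for reproducibility.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Labels laid out like faces.target: 40 subjects x 10 images each
y = np.repeat(np.arange(40), 10)
X = np.arange(400)  # placeholder "images"

_, _, _, test_y_random = train_test_split(X, y, test_size=0.2, random_state=0)
_, _, _, test_y_strat = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)

# Test images per subject: uneven for the random split, exactly 2 each when stratified
print(np.bincount(test_y_random, minlength=40))
print(np.bincount(test_y_strat, minlength=40))
```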
Build the model:
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(128, kernel_size=3, activation='relu', input_shape=X.shape[1:]))
model.add(tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(40, activation='softmax'))
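Neither Conv2D layer uses padding or pooling, so the feature maps shrink by only 2 pixels per layer and the Dense layer ends up holding nearly all of the weights. The plain-Python sketch below redoes that arithmetic by hand (assuming the Keras defaults of stride 1 and padding='valid'); the totals should match what model.summary() lists for this architecture. Roughly 9.3 million parameters against 320 training images may help explain how quickly training accuracy approaches 1.0 in the log below.

```python
def conv2d_valid(h, w, c_in, filters, k):
    """Output shape and parameter count of a stride-1 Conv2D with padding='valid'."""
    out_shape = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out_shape, params


def dense(n_in, units):
    """Parameter count of a fully connected layer: weight matrix + bias vector."""
    return n_in * units + units


shape1, p1 = conv2d_valid(64, 64, 1, 128, 3)  # first Conv2D(128, 3)
shape2, p2 = conv2d_valid(*shape1, 64, 3)     # second Conv2D(64, 3)
flat = shape2[0] * shape2[1] * shape2[2]      # Flatten
p3 = dense(flat, 40)                          # Dense(40)

print(shape1, p1)    # (62, 62, 128) 1280
print(shape2, p2)    # (60, 60, 64) 73792
print(flat, p3)      # 230400 9216040
print(p1 + p2 + p3)  # 9291112
```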
Compile:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
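sparse_categorical_crossentropy takes the integer labels in faces.target directly, while categorical_crossentropy would require one-hot labels; both compute the same value. The NumPy sketch below is a simplified re-implementation for illustration (the eps clipping is an assumption, not Keras's exact internals). For a model that predicts all 40 classes uniformly the loss is ln(40) ≈ 3.689, which is close to the val_loss of 3.6919 after the first epoch in the log below, i.e. roughly chance level.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, probs, eps=1e-7):
    """Mean cross-entropy for integer labels: -log of the probability given to the true class."""
    probs = np.clip(probs, eps, 1.0)                # avoid log(0)
    picked = probs[np.arange(len(y_true)), y_true]  # one probability per sample
    return float(np.mean(-np.log(picked)))


def categorical_crossentropy(y_onehot, probs, eps=1e-7):
    """The same loss computed from one-hot labels."""
    probs = np.clip(probs, eps, 1.0)
    return float(np.mean(-np.sum(y_onehot * np.log(probs), axis=1)))


y = np.array([0, 5, 39])            # integer subject ids, like faces.target
uniform = np.full((3, 40), 1 / 40)  # a model that has learned nothing
one_hot = np.eye(40)[y]

print(sparse_categorical_crossentropy(y, uniform))  # ≈ 3.6889, i.e. ln(40)
print(categorical_crossentropy(one_hot, uniform))   # same value
```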
Train:
model.fit(train_x, train_y, epochs=8, validation_data=(test_x, test_y))
Epoch 1/9
320/320 [==============================] - 7s 21ms/sample - loss: 5.0002 - accuracy: 0.0250 - val_loss: 3.6919 - val_accuracy: 0.0125
Epoch 2/9
320/320 [==============================] - 6s 19ms/sample - loss: 3.6688 - accuracy: 0.0656 - val_loss: 3.6290 - val_accuracy: 0.1250
Epoch 3/9
320/320 [==============================] - 6s 20ms/sample - loss: 3.4579 - accuracy: 0.1813 - val_loss: 3.5491 - val_accuracy: 0.0750
Epoch 4/9
320/320 [==============================] - 6s 20ms/sample - loss: 2.8424 - accuracy: 0.4563 - val_loss: 2.4117 - val_accuracy: 0.6000
Epoch 5/9
320/320 [==============================] - 7s 21ms/sample - loss: 1.5031 - accuracy: 0.7969 - val_loss: 1.5258 - val_accuracy: 0.7000
Epoch 6/9
320/320 [==============================] - 7s 22ms/sample - loss: 0.5492 - accuracy: 0.9187 - val_loss: 0.6792 - val_accuracy: 0.8500
Epoch 7/9
320/320 [==============================] - 7s 22ms/sample - loss: 0.1736 - accuracy: 0.9781 - val_loss: 0.7218 - val_accuracy: 0.7625
Epoch 8/9
320/320 [==============================] - 7s 21ms/sample - loss: 0.0640 - accuracy: 0.9969 - val_loss: 0.6178 - val_accuracy: 0.7875
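The log above comes from a 9-epoch run (hence "Epoch n/9") and is shown through epoch 8. By the 9th epoch the model overfits: after epoch 8 the training accuracy is already 0.9969, while validation accuracy (0.7875) has fallen back from its epoch-6 peak of 0.85. This is why training is limited to 8 epochs (epochs=8 in the code above).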
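Instead of fixing the epoch count by hand, training can be stopped automatically once the validation loss stops improving. Keras provides tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True), passed via model.fit(..., callbacks=[...]). The plain-Python sketch below only mimics its patience rule (stop once val_loss has failed to improve for `patience` consecutive epochs) on the val_loss values from the log above; it is a simplification, not Keras's implementation, and `early_stopping` is a hypothetical name.

```python
def early_stopping(val_losses, patience):
    """Return (epoch at which training stops, epoch with the best val_loss), 1-indexed."""
    best = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


# val_loss for epochs 1-8, copied from the training log above
val_loss = [3.6919, 3.6290, 3.5491, 2.4117, 1.5258, 0.6792, 0.7218, 0.6178]

print(early_stopping(val_loss, patience=1))  # (7, 6)
print(early_stopping(val_loss, patience=2))  # (8, 8)
```

With patience=1, training would halt after epoch 7 and, with restore_best_weights=True, roll back to the epoch-6 weights; patience=2 lets the epoch-8 improvement through.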