目的:尝试搭建一个CNN网络,该CNN网络含有两个卷积层、两个下采样层(池化层)、两个全连接层
#
# Copyright @2017 R&D, CINS Inc. (cins.com)
#
# Author: PengjunZhu <1512568691@qq.com>
#
# Function: 使用keras搭建一个CNN网络,该网络包括两个卷积层、两个池化层、两个全连接层
#
# time: 2018/04/24
#
# 参考:https://blog.csdn.net/Miracle_guy/article/details/73124077
print("starting...")
# 调用此次训练的数据集
from keras.datasets import mnist
import numpy as np
np.random.seed(1337)
# 加载数据
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()#X.shape(60000,28,28),Y.shape(10000)
# 数据预处理
from keras.utils import np_utils
X_train = X_train.reshape(-1,1,28,28)
X_test = X_test.reshape(-1,1,28,28)
Y_train = np_utils.to_categorical(Y_train,num_classes = 10)
Y_test = np_utils.to_categorical(Y_test,num_classes = 10)
# 调用模型
from keras.models import Sequential
# 用于模型初始化,Conv2D模型初始化、Activation激活函数,MaxPooling2D是池化层
# Flatten作用是将多位输入进行一维化
# Dense是全连接层
from keras.layers import Conv2D, Activation, MaxPool2D, Flatten, Dense
# 优化方法选用Adam(其实可选项有很多,如SGD)
from keras.optimizers import Adam
# 初始化一个模型
model = Sequential()
# 模型卷积层设计
model.add(Conv2D(
nb_filter=32, # 第一层设置32个滤波器
nb_row=5,
nb_col=5, # 设置滤波器的大小为5*5
padding='same', # 选择滤波器的扫描方式,即是否考虑边缘
input_shape=(1,28,28), # 设置输入的形状
))
# 选择激活函数
model.add(Activation('relu'))
# 设置下采样(池化层)
model.add(MaxPool2D(
pool_size=(2,2), # 下采样格为2*2
strides=(2,2), # 向右向下的步长
padding='same', # padding mode is 'same'
))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('relu'))
model.add(MaxPool2D(strides=(2, 2), padding='same'))
# 使用Flatten函数,将输入数据扁平化(因为输入数据是一个多维的形式,需要将其扁平化)
model.add(Flatten()) # 将多维的输入一维化
model.add(Dense(1024)) # 全连接层1024个点
model.add(Activation('relu'))
# 在建设一层
model.add(Dense(10)) # 输入是个类别
model.add(Activation('softmax')) # 用于分类的softmax函数
adam = Adam() # 学习速率lr=0.0001
model.compile(optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy'])
print("training ==========~~~~~~~~=======")
model.fit(X_train, Y_train, epochs=1, batch_size=64) # 全部训练次数epochs=1次,每次训练批次大小batch_size=64
print("Testing ==========~~~~~~~~~~~~======")
loss, accuracy = model.evaluate(X_test, Y_test)
print("\nloss:", loss)
print("\nTest:", accuracy)
# 参考:https://blog.csdn.net/Miracle_guy/article/details/73124077
网友评论