美文网首页Mnist,没你想的那么简单
(2) Mnist: LeNet-5, 手写数字识别 basel

(2) Mnist: LeNet-5, 手写数字识别 basel

作者: 是风车大渣渣啊 | 来源:发表于2019-05-13 11:26 被阅读0次

LeNet-5 出自论文 Gradient-Based Learning Applied to Document Recognition,LeCun大佬非常早期的作品,用于手写字母识别,在Mnist数据集上能够达到 98% 以上的准确率。

LeNet-5

LeNet-5 Structure

Implement in Keras

基于 Keras 实现 LeNet-5,损失函数使用交叉熵,优化器选择Adam

def lenet_v5(in_shape):
    in_x = Input(in_shape)
    conv1 = Conv2D(filters=6, kernel_size=(5, 5), padding='valid', activation='tanh')(in_x)
    map1 = MaxPooling2D((2, 2))(conv1)
    conv2 = Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh')(map1)
    map2 = MaxPooling2D((2, 2))(conv2)
    mac = Flatten()(map2)
    fc1 = Dense(120)(mac)
    fc2 = Dense(84)(fc1)
    prob = Dense(10, activation='softmax')(fc2)

    lenet = Model(inputs=[in_x], outputs=prob)
    lenet.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
    # lenet.summary()
    return lenet

模型评估

数据读取的部分,定义在之前的博客 data_loader

from data_loader import get_test, get_train
from model_defination import lenet_v5
from mnist_utils import train_scheduler


def model_evaluate(model, data_list):
    train_img, train_lbl, test_img, test_lbl = data_list
    model.fit(train_img, train_lbl, epochs=16, verbose=1)
    loss, acc = model.evaluate(test_img, test_lbl)
    print('test loss-%.2f, test accuracy-%.2f%%' % (loss, 100*acc))


if __name__ == "__main__":
    train_img, train_lbl = get_train()
    test_img, test_lbl = get_test()
    datas = (train_img, train_lbl, test_img, test_lbl)
    train_num, *img_shape = train_img.shape

    model = lenet_v5(img_shape)
    model_evaluate(model, datas)

运行结果:

test loss-0.07, test accuracy-98.07

特征可视化

对于训练集的每一个样本,提取最后一层的特征,并且进行以下处理

  • 将每个样本的特征向量标准化成单位向量,将他们映射到以原点为中心的单位圆上
  • 使用 PCA方法将其降低到3维,便于画图
  • 根据类别描点,不同类别赋予不同颜色
feat cluster--2d

可以看出,对于特征的分布均匀,相同类别的特征比较好的聚集在一起。

相关文章

网友评论

    本文标题:(2) Mnist: LeNet-5, 手写数字识别 basel

    本文链接:https://www.haomeiwen.com/subject/yqnioqtx.html