美文网首页
百度飞桨-手写数字识别实验

百度飞桨-手写数字识别实验

作者: 竞媒体 | 来源:发表于2023-04-11 16:37 被阅读0次

1.打开上篇安装的 jupterhub


1681288404695.jpg

2.新建 python3 notebook

输入如下代码

import paddle
import numpy as np
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')

下载数据集并初始化 DataSet

train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

模型组网并初始化网络

lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)

模型训练的配置准备,准备损失函数,优化器和评价指标

model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy())

模型训练

model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)

模型评估

model.evaluate(test_dataset, batch_size=64, verbose=1)

保存模型

model.save('./output/mnist')

加载模型

model.load('output/mnist')

从测试集中取出一张图片

img, label = test_dataset[0]

将图片shape从12828变为1128*28,增加一个batch维度,以匹配模型输入格式要求

img_batch = np.expand_dims(img.astype('float32'), axis=0)

执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果

out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))

可视化图片

from matplotlib import pyplot as plt
plt.imshow(img[0])

3.点击运行


1681287139258.jpg

4.查看结果,可以正确的识别数字,整个流程和用到的关键 API 如下图所示


model_develop_flow.png

相关文章

网友评论

      本文标题:百度飞桨-手写数字识别实验

      本文链接:https://www.haomeiwen.com/subject/dcwgddtx.html