美文网首页
在PaddlePaddle中实现MNIST数据集训练:基础API

在PaddlePaddle中实现MNIST数据集训练:基础API

作者: LabVIEW_Python | 来源:发表于2021-02-23 05:25 被阅读0次

上一节《在PaddlePaddle中实现MNIST数据集训练:高层API》从MNIST数据集下载开始,详细介绍在PaddlePaddle中,基于高层API实现MNIST数据集训练。本节主要介绍在PaddlePaddle中,基于基础API实现MNIST数据集训练。高层API:Model.prepare()、Model.fit()、Model.evaluate()、Model.predict()都是由基础API封装而来,用基础API来实现模型创建与训练,就是用基础API来实现上述高层API的功能。

数据的载入与高层API实现部分一致,不同的是,需要用paddle.io.DataLoader类把paddle.io.Dataset类再封装一次,供基础API使用。

完整范例程序如下所示:

import gzip 
import struct 
import numpy as np 

# train-images-idx3-ubyte 文件格式, 参考:http://yann.lecun.com/exdb/mnist/
'''
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 
0 means background (white), 255 means foreground (black).
'''
def load_images(image_file):
    # 读取*.gz格式文件
    with gzip.open(image_file) as f:
        buf = f.read()

    idx = 0
    # 读取文件信息
    magic, num_images, rows, cols = struct.unpack_from('>IIII', buf, idx)
    idx += struct.calcsize('>IIII')
    length = int(num_images*rows*cols)
    # 读取图像数据
    images = struct.unpack_from('>'+str(length)+'B', buf, idx)
    images = np.array(images).astype('float32')
    images = images.reshape(num_images, rows, cols)
    # 返回np.ndarray类型, N*r*c 图像数据
    return images


# train-labels-idx1-ubyte.gz 文件格式
'''
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.
'''
def load_labels(label_file):
    # 读取*.gz格式文件
    with gzip.open(label_file) as f:
        buf = f.read()
    # 读取文件信息
    idx = 0
    magic, num_labels = struct.unpack_from('>II', buf, idx)
    # 读取标签数据
    idx += struct.calcsize('>II')
    labels = struct.unpack_from('>'+str(num_labels)+'B',buf,idx)
 
    labels = np.array(labels).astype('int64')
    # 返回np.ndarray类型, 标签数据
    return labels

train_images = load_images('train-images-idx3-ubyte.gz')
test_images  = load_images('t10k-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz').reshape(-1,1)
test_labels  = load_labels('t10k-labels-idx1-ubyte.gz').reshape(-1,1)

# 图像数据归一化
train_images = train_images / 255.0
test_images  = test_images / 255.0

num_train_samples = train_images.shape[0]
num_test_samples = test_images.shape[0]

import paddle
from paddle.io import Dataset
class TrainDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = train_images[index]
        label = train_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

class TestDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = test_images[index]
        label = test_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

# 测试定义的数据集
train_dataset = TrainDataSet(num_train_samples)
test_dataset = TestDataSet(num_test_samples)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=100, shuffle=True)

# 定义模型
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

# 设置模型为训练模式,这只会影响某些模块,如Dropout和BatchNorm
mnist.train()

# 模型训练相关配置,准备损失计算方法,优化器方法
loss_fn = paddle.nn.CrossEntropyLoss()
optim = paddle.optimizer.Adam(parameters=mnist.parameters())

# 设置迭代次数
epochs = 5
# 开始模型训练
for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader()):

        x_data = data[0]            # 训练数据
        y_data = data[1]            # 训练数据标签
        predicts = mnist(x_data)    # 预测结果
        # print(x_data.shape, y_data.shape)
        # 计算损失 等价于 prepare 中loss的设置
        loss = loss_fn(predicts, y_data)

        # 计算准确率 等价于 prepare 中metrics的设置
        acc = paddle.metric.accuracy(predicts, y_data)

        # 下面的反向传播、打印训练信息、更新参数、梯度清零都被封装到 Model.fit() 中

        # 反向传播
        loss.backward()

        if (batch_id+1) % 100 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id+1, loss.numpy(), acc.numpy()))

        # 更新参数
        optim.step()

        # 梯度清零
        optim.clear_grad()

# 用 evaluate 在测试集上对模型进行验证
mnist.eval()

for batch_id, data in enumerate(test_loader()):

    x_data = data[0]            # 测试数据
    y_data = data[1]            # 测试数据标签
    predicts = mnist(x_data)    # 预测结果

    # 计算损失与精度
    loss = loss_fn(predicts, y_data)
    acc = paddle.metric.accuracy(predicts, y_data)

    # 打印信息
    if (batch_id+1) % 100 == 0:
        print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id+1, loss.numpy(), acc.numpy()))

# 用 predict 在测试集上对模型进行测试
mnist.eval()
for batch_id, data in enumerate(test_loader()):
    x_data = data[0]
    predicts = mnist(x_data)
    # 获取预测结果
print(f"predict[0]:{predicts[0]}")
运行结果: 运行结果

相关文章

网友评论

      本文标题:在PaddlePaddle中实现MNIST数据集训练:基础API

      本文链接:https://www.haomeiwen.com/subject/jgvyxltx.html