美文网首页
Mnist数据集的处理

Mnist数据集的处理

作者: 带带吴腾跃 | 来源:发表于2019-11-14 17:03 被阅读0次

    本文不使用tensorflow或keras等平台提供的、已经转换好的mnist数据集,而是直接手动处理mnist官方的数据集。官方数据集共有四个文件,分别对应训练集图像、训练集标签、测试集图像、测试集标签。

    官网给的数据集并不是图像数据格式,而是编码后的二进制格式。这是官网的数据说明:


    image.png

    图像文件的前16个字节分为4个大端序整型数据,每个4字节,分别代表魔数(magic number,用于标识文件类型)、图像数量、行数、列数,之后的数据全部为像素,每个像素占1字节,像素值为0-255。标签文件的前8个字节分为2个整型数据,分别代表魔数和标签数量,之后每个字节是一个标签。
    代码如下:

    import numpy as np
    import struct
    # Directory holding the four MNIST IDX files; passed to fetch_mnist below.
    mnist_dir = r'./digit/'
    def _read_idx_images(path):
        """Read an IDX3 image file into a float array of shape (count, rows*cols).

        The 16-byte header holds four big-endian unsigned 32-bit integers:
        magic number, image count, row count and column count. Every
        following byte is one pixel. If the file is truncated, the images
        that are fully present are filled in and the remaining rows stay zero.
        """
        with open(path, 'rb') as f:
            _magic, img_nums, row, col = struct.unpack_from('>IIII', f.read(16), 0)
            pixels = row * col
            raw = f.read(img_nums * pixels)

        images = np.zeros((img_nums, pixels))
        # Only whole images are copied; a trailing partial image is ignored.
        complete = len(raw) // pixels if pixels else 0
        if complete:
            images[:complete] = np.frombuffer(
                raw, dtype=np.uint8, count=complete * pixels
            ).reshape(complete, pixels)
        return images


    def _read_idx_labels(path):
        """Read an IDX1 label file into a float array of shape (count, 1).

        The 8-byte header holds two big-endian unsigned 32-bit integers:
        magic number and label count. Every following byte is one label.

        Raises:
            struct.error: if the header or the label data is truncated.
        """
        with open(path, 'rb') as f:
            _magic, label_nums = struct.unpack_from('>II', f.read(8), 0)
            raw = f.read(label_nums)

        if len(raw) < label_nums:
            raise struct.error(
                'label file %r is truncated: expected %d labels, found %d'
                % (path, label_nums, len(raw))
            )
        return np.frombuffer(raw, dtype=np.uint8).astype(np.float64).reshape(label_nums, 1)


    def fetch_mnist(mnist_dir, data_type):
        """Load the official MNIST IDX files found in *mnist_dir*.

        Args:
            mnist_dir: directory containing 'train-images.idx3-ubyte',
                'train-labels.idx1-ubyte', 't10k-images.idx3-ubyte' and
                't10k-labels.idx1-ubyte'. A trailing separator is optional.
            data_type: 'train', 'test' or 'all'.

        Returns:
            (train_x, train_y) for 'train', (test_x, test_y) for 'test',
            (train_x, train_y, test_x, test_y) for 'all'. Image arrays have
            shape (count, rows*cols) and label arrays (count, 1); both are
            float64. For any other *data_type*, 'type error' is printed and
            None is returned.

        Only the files needed for the requested split are opened.
        """
        import os  # local import: only needed to build the file paths

        if data_type not in ('train', 'test', 'all'):
            print('type error')
            return None

        result = ()
        if data_type in ('train', 'all'):
            result += (
                _read_idx_images(os.path.join(mnist_dir, 'train-images.idx3-ubyte')),
                _read_idx_labels(os.path.join(mnist_dir, 'train-labels.idx1-ubyte')),
            )
        if data_type in ('test', 'all'):
            result += (
                _read_idx_images(os.path.join(mnist_dir, 't10k-images.idx3-ubyte')),
                _read_idx_labels(os.path.join(mnist_dir, 't10k-labels.idx1-ubyte')),
            )
        return result
    
    if __name__ == '__main__':
        # Demo: load every split, then display one training and one test
        # digit and print their labels.
        tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir, 'all')
        import matplotlib.pyplot as plt  # plt is used to display the images

        # Each image gets its own subplot; drawing both with imshow on the
        # same axes would leave only the second one visible.
        img_0 = tr_x[59999, :].reshape(28, 28)
        plt.subplot(1, 2, 1)
        plt.imshow(img_0)
        print(tr_y[59999, :])

        img_1 = te_x[500, :].reshape(28, 28)
        plt.subplot(1, 2, 2)
        plt.imshow(img_1)
        print(te_y[500, :])

        plt.show()
    ————————————————
    https://blog.csdn.net/jinxiaonian11/article/details/78172613
    

    相关文章

      网友评论

          本文标题:Mnist数据集的处理

          本文链接:https://www.haomeiwen.com/subject/pybkictx.html