美文网首页
详解 MNIST 数据集

详解 MNIST 数据集

作者: kamin | 来源:发表于2018-06-30 19:36 被阅读43次

    详解 MNIST 数据集

    代码解释见下面

    Label File

    先是一个32位的整形 表示的是Magic Number,这是用来标示文件格式的用的。一般默认不变,为2049

    第二是图片的的数量

    接下去就是一次排列图片的标示Label。

    -

    Image File

    也是Magic Number。同上。保持不变2051.

    图片的数量

    图片的高

    图片的宽

    图片的像素点[灰度 256位]。

    unpack(fmt, string)       按照给定的格式(fmt)解析字节流string,返回解析出来的tuple

    > big-endian standard       按原字节数

    见上图:图片宽高分别为28,所以有28*28=784个值

    代码:

    import os

    import struct

    import numpy as np

    def load_mnist(path, kind='train'):

    print("in load_mnist")

    """Load MNIST data from `path`"""  #注释

    labels_path = os.path.join(path,'%s-labels.idx1-ubyte'%kind) #路径+train-labels-idx1-ubyte(gz文件)

    images_path = os.path.join(path,'%s-images.idx3-ubyte'%kind) #路径+train-labels-idx1-ubyte(gz文件)

    with open(labels_path, 'rb') as lbpath: #以二进制格式打开文件train-labels-idx1-ubyte用于只读,lbpath代表此文件对象

    #从文件中读8个字节,1-4个字节为magic number,4-8个字节为图片数量,magic和n均为无符号整形     

    magic, n = struct.unpack('>II',lbpath.read(8)) #>  big-endian 高字节在高位 II两个无符号整形,每个占4个字节

    labels = np.fromfile(lbpath,dtype=np.uint8)

    print("labels length=%d"%len(labels))

    with open(images_path, 'rb') as imgpath:

    #从文件中读16个字节,1-4个字节为magic number,4-8个字节为图片数量,rows为图片的高,cols为图片的宽,magic,num,rows,cols均为无符号整形

    magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))#> big-endian 高字节在高位IIII四个无符号整形,每个占4个字节

    #读取图片数据,并转换为 60,000行784列的矩阵,也就是说一行是一张图片

    images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)

    print("images length=%d"%len(images))

    return images, labels

    if __name__=='__main__':

    images_train,labels_train=load_mnist('', kind='train')  #cd mnist  python load_mnist.py执行当前程序

    print("images")

    print (images_train)

    print("labels")

    print (labels_train)

    print('Rows: %d, columns: %d' % (images_train.shape[0], images_train.shape[1]))

    count = np.zeros(10)

    nTrain = len(images_train)

    for i in range(nTrain):

    label = labels_train[i]

    count[label] += 1

    filename = './train/' + str(label) + '/' + str(label) + '_' + str(int(count[label])) + '.png'

    print(filename)

    img = images_train[i].reshape(28,28)

    cv2.imwrite(filename, img) #找不到图片?

    print(str(int(count[label])))

    print("over")

    相关文章

      网友评论

          本文标题:详解 MNIST 数据集

          本文链接:https://www.haomeiwen.com/subject/cihxuftx.html