Introductory Deep Learning Datasets, Part 1: The CIFAR-10 Dataset

Author: ac619467fef3 | Published 2019-02-20 21:30

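    A while ago I wrote a series on getting started with machine learning. This series covers introductory datasets for deep learning, and the first one is CIFAR-10. CIFAR-10 is mainly used for image recognition. The dataset contains images and labels: each image is 32x32 pixels, and each label is one of 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).

    The purpose of the dataset is to train a deep learning model on this labeled data so that the model can recognize the object in an image. For example, such a neural network can tell a cat from a dog.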

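    1. Dataset

    Official site
    The official site provides the dataset in several formats; we use the binary (bin) version. First, look at the first 25 records. Because the images are only 32x32 pixels, many of them are hard to recognize even for the human eye.

    Figure: the first 25 CIFAR-10 records
    Code: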
    import numpy as np
    import matplotlib.pyplot as plt

    # Paths to the binary test batch and the class-name file (adjust to your machine)
    filename = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/test_batch.bin'
    label_meta = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/batches.meta.txt'

    with open(label_meta, "r") as f:
        labels_txt = f.read().strip().split("\n")

    # Each record is 1 label byte followed by 32*32*3 pixel bytes
    with open(filename, "rb") as bytestream:
        buf = bytestream.read(25 * (1 + 32 * 32 * 3))

    data = np.frombuffer(buf, dtype=np.uint8).reshape(25, 1 + 32 * 32 * 3)
    labels = data[:, 0]
    # Pixels are stored channel-first (R plane, G plane, B plane);
    # convert to height x width x channel for imshow
    images = data[:, 1:].reshape(25, 3, 32, 32).transpose(0, 2, 3, 1)

    # Show the 25 images in a 5x5 grid, each titled with its class name
    fig, axes = plt.subplots(5, 5, figsize=(4, 5))
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        ax.imshow(images[i])
        ax.set_title(labels_txt[labels[i]])
    plt.show()
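The record layout that the script above relies on (1 label byte, then 1024 red, 1024 green, and 1024 blue bytes) can be checked without the real files. The following is a minimal sketch; the function name `parse_cifar10_records` and the synthetic record are illustrative assumptions, not part of the original tutorial.

```python
import numpy as np

RECORD_BYTES = 1 + 3 * 32 * 32  # 1 label byte + 3072 pixel bytes


def parse_cifar10_records(buf):
    """Split raw CIFAR-10 binary bytes into labels and HWC images."""
    data = np.frombuffer(buf, dtype=np.uint8)
    if data.size % RECORD_BYTES != 0:
        raise ValueError("buffer length is not a multiple of 3073 bytes")
    records = data.reshape(-1, RECORD_BYTES)
    labels = records[:, 0]
    # Pixels are stored channel-first: 1024 red, 1024 green, 1024 blue bytes.
    images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return labels, images


# Synthetic record (hypothetical data): label 3 followed by a pure red image.
fake = bytes([3]) + bytes([255]) * 1024 + bytes([0]) * 2048
labels, images = parse_cifar10_records(fake * 2)
print(labels, images.shape, images[0, 0, 0])
```

Reshaping to (N, 3, 32, 32) before transposing is what keeps the color planes intact; reshaping straight to (N, 32, 32, 3) would interleave them incorrectly.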
    

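    2. Training

    The official TensorFlow tutorials include an example CIFAR-10 training program.
    Code download: https://github.com/tensorflow/models
    Code location: models/tutorials/image/cifar10/

    2.1 Running the training code

    Run python cifar10_train.py. If the dataset has not been downloaded yet, the script downloads it first. The output looks like this: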

    Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
    2019-02-20 13:42:05.167927: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
    2019-02-20 13:42:09.260566: step 0, loss = 4.67 (304.9 examples/sec; 0.420 sec/batch)
    2019-02-20 13:42:13.762996: step 10, loss = 4.63 (284.3 examples/sec; 0.450 sec/batch)
    2019-02-20 13:42:18.095651: step 20, loss = 4.49 (295.4 examples/sec; 0.433 sec/batch)
    2019-02-20 13:42:22.444906: step 30, loss = 4.50 (294.3 examples/sec; 0.435 sec/batch)
    2019-02-20 13:42:27.136578: step 40, loss = 4.40 (272.8 examples/sec; 0.469 sec/batch)
    2019-02-20 13:42:31.833072: step 50, loss = 4.32 (272.5 examples/sec; 0.470 sec/batch)
    

    The official benchmark figures are shown below. My machine, a 2018 MacBook Air (dual-core i7), nearly matches the training speed of a Tesla K20m.

    A binary to train CIFAR-10 using a single GPU.
    
    Accuracy:
    cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
    data) as judged by cifar10_eval.py.
    
    Speed: With batch_size 128.
    
    System        | Step Time (sec/batch)  |     Accuracy
    ------------------------------------------------------------------
    1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
    1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
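The figures quoted above can be sanity-checked with a little arithmetic. This is a rough sketch: the helper names are hypothetical, the 50,000-image training set size is the standard CIFAR-10 split, and 0.45 sec/batch is simply a typical value read from the log in section 2.1.

```python
TRAIN_IMAGES = 50000  # size of the CIFAR-10 training split


def epochs_for(steps, batch_size=128, train_images=TRAIN_IMAGES):
    """Number of passes over the training set after `steps` batches."""
    return steps * batch_size / train_images


def hours_for(steps, sec_per_batch):
    """Wall-clock hours needed for `steps` batches at a fixed step time."""
    return steps * sec_per_batch / 3600.0


epochs = epochs_for(100_000)       # matches the "256 epochs" in the docstring
hours = hours_for(100_000, 0.45)   # estimate for the laptop in the log above
print(epochs, hours)
```

At about 0.45 sec/batch, 100K steps would take roughly 12.5 hours on the laptop, assuming the step time stays constant.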
    

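    2.2 Running the evaluation code

    Once training has finished, you can run the evaluation script, which predicts on 10,000 images and reports the prediction accuracy.
    python cifar10_eval.py
    With training limited to 1,000 steps, the accuracy is about 60%.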

    2019-02-20 15:59:41.109588: precision @ 1 = 0.606
    

    According to experiments, the accuracy reaches 86% after 100K training steps.
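The "precision @ 1" value in the log is top-1 accuracy: the fraction of test images whose highest-scoring class equals the true label. As far as I can tell, cifar10_eval.py computes this with tf.nn.in_top_k; the NumPy version below is only an illustrative sketch with a hypothetical function name and made-up scores, and it may treat ties differently.

```python
import numpy as np


def precision_at_k(logits, labels, k=1):
    """Fraction of rows whose true label is among the k highest scores."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # Indices of the k largest scores per row (order within the k is irrelevant).
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return hits.mean()


# Made-up scores for 4 images and 3 classes.
scores = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.9], [0.0, 0.4, 0.3], [0.2, 0.1, 3.0]]
truth = [1, 2, 1, 2]
p1 = precision_at_k(scores, truth, k=1)
p2 = precision_at_k(scores, truth, k=2)
print(p1, p2)
```

Here three of the four images are classified correctly at k=1, and all four have the true label within the top two.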

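    3. Predicting an image with the model

    Test code

    • checkpoint_dir: directory holding the model parameters saved during training.
    • test_file: path of the image to classify.

    4. Actual prediction results

    Predictions on large images are poor; the image needs to be downscaled to under 50 px with a good resampling algorithm. In my tests, the prediction accuracy was below 50%.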
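The text does not say which downscaling algorithm was used. One common choice is area averaging, where every output pixel is the mean of a block of input pixels, so no part of a large image is skipped. The sketch below assumes an integer scale factor and uses a hypothetical helper name and a random stand-in image.

```python
import numpy as np


def downscale_area(image, factor):
    """Shrink an HxWxC image by an integer factor using block averaging."""
    h, w, c = image.shape
    h2, w2 = h // factor, w // factor
    # Crop so both sides divide evenly, then average each factor x factor block.
    cropped = image[:h2 * factor, :w2 * factor].astype(np.float32)
    blocks = cropped.reshape(h2, factor, w2, factor, c)
    return blocks.mean(axis=(1, 3))


# Random stand-in for a 480x480 photo, reduced to the 24x24 model input size.
big = np.random.default_rng(0).random((480, 480, 3)).astype(np.float32)
small = downscale_area(big, 20)
print(small.shape)
```

For non-integer factors, an image library with an antialiased resize would be the more practical route.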

    # -*- coding:utf-8 -*-
    # Note: this script uses the TensorFlow 1.x API (tf.app.flags, tf.Session).
    import tensorflow as tf
    from tensorflow.python.ops.image_ops_impl import ResizeMethod
    from prettytable import PrettyTable

    import cifar10
    import numpy as np
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    FLAGS = tf.app.flags.FLAGS
    # Directory where the training checkpoints were saved
    tf.app.flags.DEFINE_string('checkpoint_dir', '/Users/wangsen/ai/13/models-master/tutorials/image/cifar10/cifar10_train',
                 """Directory where to read model checkpoints.""")
    tf.app.flags.DEFINE_string('class_dir', '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/',
                               """Directory containing batches.meta.txt.""")
    tf.app.flags.DEFINE_string('test_file', '/Users/wangsen/Desktop/1.jpeg', """Image to classify.""")
    
    IMAGE_SIZE = 24
     
     
    def evaluate_images(images):  # run inference on the given images
        logits = cifar10.inference(images)
        load_trained_model(logits=logits)


    def load_trained_model(logits):
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Restore the trained parameters from the checkpoint
                saver = tf.train.Saver()
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found')
                return

            # Read the 10 class labels from the file as strings, tab-delimited
            cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t')
            # Top three classes; softmax turns the raw logits into probabilities
            top_k_pred = tf.nn.top_k(tf.nn.softmax(logits), k=3)
            output = sess.run(top_k_pred)
            probability = np.array(output[0]).flatten()  # flatten the probabilities to 1-D
            index = np.array(output[1]).flatten()
            # Display the result as a table
            table = PrettyTable(["index", "class", "probability"])
            table.align["index"] = "l"
            table.padding_width = 1
            for i in np.arange(index.size):
                table.add_row([index[i], cifar10_class[index[i]], probability[i]])
            print(table)
        image = mpimg.imread(FLAGS.test_file)  # read the test image
        plt.imshow(image)  # show the image
        plt.axis('off')  # hide the axes
        plt.show()
     
    def img_read(filename):
        if not tf.gfile.Exists(filename):
            tf.logging.fatal('File does not exist %s', filename)
        image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename),
                                                                       channels=3), dtype=tf.float32)
        height = IMAGE_SIZE
        width = IMAGE_SIZE
        # Resize to the 24x24 input size the model expects
        image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR)
        # Add a batch dimension: (24, 24, 3) -> (1, 24, 24, 3)
        image = tf.reshape(image, (1, IMAGE_SIZE, IMAGE_SIZE, 3))
        return image
     
    def main(argv=None):  # pylint: disable=unused-argument
        filename = FLAGS.test_file
        images = img_read(filename)
        evaluate_images(images)
     
    if __name__ == '__main__':
        tf.app.run()
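One possible contributor to the low accuracy, offered as an assumption rather than something verified in this article: the tutorial's input pipeline (cifar10_input.py) applies tf.image.per_image_standardization to every 24x24 image before it reaches the network, whereas the script above feeds pixel values scaled to [0, 1] directly. A NumPy sketch of that standardization, with a made-up input image, looks like this:

```python
import numpy as np


def per_image_standardization(image):
    """Scale one image to zero mean and unit variance."""
    image = np.asarray(image, dtype=np.float32)
    mean = image.mean()
    # Lower-bound the std so a uniform image does not divide by zero.
    adjusted_std = float(max(image.std(), 1.0 / np.sqrt(image.size)))
    return (image - mean) / adjusted_std


# Made-up 24x24 RGB image with increasing pixel values.
sample = np.arange(24 * 24 * 3).reshape(24, 24, 3)
standardized = per_image_standardization(sample)
print(standardized.mean(), standardized.std())
```

In the TensorFlow script, the corresponding change would be to call tf.image.per_image_standardization on the resized image before adding the batch dimension; whether that closes the accuracy gap would have to be tested.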
    
