美文网首页
Tensorflow(五)用VGG实现迁移学习(代码实现)

Tensorflow(五)用VGG实现迁移学习(代码实现)

作者: 续袁 | 来源:发表于2019-05-14 11:51 被阅读0次

    0.文件目录

    image.png

    1. 所有层参数都更新

    # -*- coding: utf-8 -*-
    '''
    利用已经训练好的vgg16网络对flowers数据集进行微调
    把最后一层分类由2000->5 然后重新训练,我们也可以冻结其它所有层,只训练最后一层
    '''
    from tensorflow.contrib.slim.python.slim.nets import vgg
    #from nets import vgg
    import matplotlib.pyplot as plt
    import tensorflow as tf
    import numpy as np
    import input_data
    import os
    
    slim = tf.contrib.slim
    
    DATA_DIR = './datasets/data/flowers'
    # 输出类别
    NUM_CLASSES = 5
    
    # 获取图片大小
    IMAGE_SIZE = vgg.vgg_16.default_image_size
    
    
    def flowers_fine_tuning():
        '''
        演示一个VGG16的例子 
        微调 这里只调整VGG16最后一层全连接层,把1000类改为5类 
        对网络进行训练
        '''
    
        '''
        1.设置参数,并加载数据
        '''
        # 用于保存微调后的检查点文件和日志文件路径
        train_log_dir = './log/vgg16/fine_tune'
        train_log_file = 'flowers_fine_tune.ckpt'
    
        # 官方下载的检查点文件路径
        checkpoint_file = './log/vgg16/vgg_16.ckpt'
    
        # 设置batch_size
        batch_size = 16  #32  #128  #256
    
        learning_rate = 1e-4
    
        # 训练集数据长度
        n_train =  32 #3320
        # 测试集数据长度
        # n_test = 350
        # 迭代轮数
        training_epochs = 3
    
        display_epoch = 1
    
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)
    
        # 加载数据
        train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, True,
                                                                           IMAGE_SIZE, IMAGE_SIZE)
        test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, False,
                                                                         IMAGE_SIZE, IMAGE_SIZE)
    
        # 获取模型参数的命名空间
        arg_scope = vgg.vgg_arg_scope()
    
        # 创建网络
        with  slim.arg_scope(arg_scope):
    
            '''
            2.定义占位符和网络结构
            '''
            # 输入图片
            input_images = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
            # 图片标签
            input_labels = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
            # 训练还是测试?测试的时候弃权参数会设置为1.0
            is_training = tf.placeholder(dtype=tf.bool)
    
            # 创建vgg16网络  如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
            logits, end_points = vgg.vgg_16(input_images, is_training=is_training, num_classes=NUM_CLASSES)
            # print(end_points)  每个元素都是以vgg_16/xx命名
    
    
            '''
            #从当前图中搜索指定scope的变量,然后从检查点文件中恢复这些变量(即vgg_16网络中定义的部分变量)  
            #如果指定了恢复检查点文件中不存在的变量,则会报错 如果不知道检查点文件有哪些变量,我们可以打印检查点文件查看变量名
            params = []
            conv1 = slim.get_variables(scope="vgg_16/conv1")
            params.extend(conv1)            
            conv2 = slim.get_variables(scope="vgg_16/conv2")
            params.extend(conv2)
            conv3 = slim.get_variables(scope="vgg_16/conv3")
            params.extend(conv3)
            conv4 = slim.get_variables(scope="vgg_16/conv4")
            params.extend(conv4)
            conv5 = slim.get_variables(scope="vgg_16/conv5")
            params.extend(conv5)
            fc6 = slim.get_variables(scope="vgg_16/fc6")
            params.extend(fc6)
            fc7 = slim.get_variables(scope="vgg_16/fc7")
            params.extend(fc7)          
            '''
    
            # Restore only the convolutional layers: 从检查点载入当前图除了fc8层之外所有变量的参数
            params = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
            # 用于恢复模型  如果使用这个保存或者恢复的话,只会保存或者恢复指定的变量
            restorer = tf.train.Saver(params)
    
            # 预测标签
            pred = tf.argmax(logits, axis=1)
    
            '''
    定义代价函数和优化器
            '''
            # 代价函数
            cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))
    
            # 设置优化器
            optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)  #全部参数重新训练
    
            # 预测结果评估
            correct = tf.equal(pred, tf.argmax(input_labels, 1))  # 返回一个数组 表示统计预测正确或者错误
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))  # 求准确率
    
            num_batch = int(np.ceil(n_train / batch_size))
    
            # 用于保存检查点文件
            save = tf.train.Saver(max_to_keep=1)
    
            # 恢复模型
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
    
                # 检查最近的检查点文件
                ckpt = tf.train.latest_checkpoint(train_log_dir)
                if ckpt != None:
                    save.restore(sess, ckpt)
                    print('从上次训练保存后的模型继续训练!')
                else:
                    restorer.restore(sess, checkpoint_file)
                    print('从官方模型加载训练!')
    
                # 创建一个协调器,管理线程
                coord = tf.train.Coordinator()
    
                # 启动QueueRunner, 此时文件名才开始进队。
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
                '''
                查看预处理之后的图片
                '''
                imgs, labs = sess.run([train_images, train_labels])
                print('原始训练图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original train image')
                plt.show()
    
                imgs, labs = sess.run([test_images, test_labels])
                print('原始测试图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original test image')
                plt.show()
    
                print('开始训练!')
                for epoch in range(training_epochs):
                    total_cost = 0.0
                    print("训练111111!!!")
                    for i in range(num_batch):
                        print("批次:"+str(i))
                        print("训练222222!!!")
                        imgs, labs = sess.run([train_images, train_labels])
                        _, loss = sess.run([optimizer, cost],
                                           feed_dict={input_images: imgs, input_labels: labs, is_training: True})
                        total_cost += loss
                    print("训练ing!!!")
                    # 打印信息
                    if epoch % display_epoch == 0:
                        print('Epoch {}/{}  average cost {:.9f}'.format(epoch + 1, training_epochs, total_cost / num_batch))
    
                    # 进行预测处理
                    imgs, labs = sess.run([test_images, test_labels])
                    cost_values, accuracy_value = sess.run([cost, accuracy],
                                                           feed_dict={input_images: imgs, input_labels: labs,
                                                                      is_training: False})
                    print('Epoch {}/{}  Test cost {:.9f}'.format(epoch + 1, training_epochs, cost_values))
                    print('准确率:', accuracy_value)
    
                    # 保存模型
                    save.save(sess, os.path.join(train_log_dir, train_log_file), global_step=epoch)
                    print('Epoch {}/{}  模型保存成功'.format(epoch + 1, training_epochs))
    
                print('训练完成')
    
                # 终止线程
                coord.request_stop()
                coord.join(threads)
    
    
    def flowers_test():
        '''
        使用微调好的网络进行测试
        '''
        '''
        1.设置参数,并加载数据
        '''
        # 微调后的检查点文件和日志文件路径
        save_dir = './log/vgg16/fine_tune'
    
        # 设置batch_size
        batch_size = 16 #128
    
        # 加载数据
        train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, True,
                                                                           IMAGE_SIZE, IMAGE_SIZE)
        test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, False,
                                                                         IMAGE_SIZE, IMAGE_SIZE)
    
        # 获取模型参数的命名空间
        arg_scope = vgg.vgg_arg_scope()
    
        # 创建网络
        with  slim.arg_scope(arg_scope):
            '''
            2.定义占位符和网络结构
            '''
            # 输入图片
            input_images = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
            # 训练还是测试?测试的时候弃权参数会设置为1.0
            is_training = tf.placeholder(dtype=tf.bool)
    
            # 创建vgg16网络
            logits, end_points = vgg.vgg_16(input_images, is_training=is_training, num_classes=NUM_CLASSES)
    
            # 预测标签
            pred = tf.argmax(logits, axis=1)
    
            restorer = tf.train.Saver()
    
            # 恢复模型
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                ckpt = tf.train.latest_checkpoint(save_dir)
                if ckpt != None:
                    # 恢复模型
                    restorer.restore(sess, ckpt)
                    print("Model restored.")
    
                # 创建一个协调器,管理线程
                coord = tf.train.Coordinator()
    
                # 启动QueueRunner, 此时文件名才开始进队。
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
                '''
                查看预处理之后的图片
                '''
                imgs, labs = sess.run([test_images, test_labels])
                print('原始测试图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original test image')
                plt.show()
    
                pred_value = sess.run(pred, feed_dict={input_images: imgs, is_training: False})
                print('预测结果为:', pred_value)
                print('实际结果为:', np.argmax(labs, 1))
                correct = np.equal(pred_value, np.argmax(labs, 1))
                print('准确率为:', np.mean(correct))
    
                # 终止线程
                coord.request_stop()
                coord.join(threads)
    
    
    if __name__ == '__main__':
        tf.reset_default_graph()
        #flowers_fine_tuning()
        flowers_test()
    

    2.只更新最后一层的参数

    # -*- coding: utf-8 -*-
    '''
    利用已经训练好的vgg16网络对flowers数据集进行微调
    把最后一层分类由2000->5 然后重新训练,我们也可以冻结其它所有层,只训练最后一层
    '''
    from tensorflow.contrib.slim.python.slim.nets import vgg
    #from nets import vgg
    import matplotlib.pyplot as plt
    import tensorflow as tf
    import numpy as np
    import input_data
    import os
    
    slim = tf.contrib.slim
    
    DATA_DIR = './datasets/data/flowers'
    # 输出类别
    NUM_CLASSES = 5
    
    # 获取图片大小
    IMAGE_SIZE = vgg.vgg_16.default_image_size
    
    
    def flowers_fine_tuning():
        '''
        演示一个VGG16的例子 
        微调 这里只调整VGG16最后一层全连接层,把1000类改为5类 
        对网络进行训练
        '''
    
        '''
        1.设置参数,并加载数据
        '''
        # 用于保存微调后的检查点文件和日志文件路径
        train_log_dir = './log/vgg16/fine_tune'
        train_log_file = 'flowers_fine_tune.ckpt'
    
        # 官方下载的检查点文件路径
        checkpoint_file = './log/vgg16/vgg_16.ckpt'
    
        # 设置batch_size
        batch_size = 16  #32  #128  #256
    
        learning_rate = 1e-4
    
        # 训练集数据长度
        n_train =  32 #3320
        # 测试集数据长度
        # n_test = 350
        # 迭代轮数
        training_epochs = 3
    
        display_epoch = 1
    
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)
    
        # 加载数据
        train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, True,
                                                                           IMAGE_SIZE, IMAGE_SIZE)
        test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, False,
                                                                         IMAGE_SIZE, IMAGE_SIZE)
    
        # 获取模型参数的命名空间
        arg_scope = vgg.vgg_arg_scope()
    
        # 创建网络
        with  slim.arg_scope(arg_scope):
    
            '''
            2.定义占位符和网络结构
            '''
            # 输入图片
            input_images = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
            # 图片标签
            input_labels = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
            # 训练还是测试?测试的时候弃权参数会设置为1.0
            is_training = tf.placeholder(dtype=tf.bool)
    
            # 创建vgg16网络  如果想冻结所有层,可以指定slim.conv2d中的 trainable=False
            logits, end_points = vgg.vgg_16(input_images, is_training=is_training, num_classes=NUM_CLASSES)
            # print(end_points)  每个元素都是以vgg_16/xx命名
    
    
            '''
            #从当前图中搜索指定scope的变量,然后从检查点文件中恢复这些变量(即vgg_16网络中定义的部分变量)  
            #如果指定了恢复检查点文件中不存在的变量,则会报错 如果不知道检查点文件有哪些变量,我们可以打印检查点文件查看变量名
            params = []
            conv1 = slim.get_variables(scope="vgg_16/conv1")
            params.extend(conv1)            
            conv2 = slim.get_variables(scope="vgg_16/conv2")
            params.extend(conv2)
            conv3 = slim.get_variables(scope="vgg_16/conv3")
            params.extend(conv3)
            conv4 = slim.get_variables(scope="vgg_16/conv4")
            params.extend(conv4)
            conv5 = slim.get_variables(scope="vgg_16/conv5")
            params.extend(conv5)
            fc6 = slim.get_variables(scope="vgg_16/fc6")
            params.extend(fc6)
            fc7 = slim.get_variables(scope="vgg_16/fc7")
            params.extend(fc7)          
            '''
    
            # Restore only the convolutional layers: 从检查点载入当前图除了fc8层之外所有变量的参数
            params = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
            # 用于恢复模型  如果使用这个保存或者恢复的话,只会保存或者恢复指定的变量
            restorer = tf.train.Saver(params)
    
            # 预测标签
            pred = tf.argmax(logits, axis=1)
    
            '''
    定义代价函数和优化器
            '''
            # 代价函数
            cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))
            # 选择待优化的参数  xpb添加
            output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_16/fc8')  #'vgg_16/fc8'
            print("output_vars:",output_vars)
            # 设置优化器
            optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost,var_list=output_vars)  #'vgg_16/fc8'层参数重新训练
    
            # 预测结果评估
            correct = tf.equal(pred, tf.argmax(input_labels, 1))  # 返回一个数组 表示统计预测正确或者错误
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))  # 求准确率
    
            num_batch = int(np.ceil(n_train / batch_size))
    
            # 用于保存检查点文件
            save = tf.train.Saver(max_to_keep=1)
    
            # 恢复模型
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
    
                # 检查最近的检查点文件
                ckpt = tf.train.latest_checkpoint(train_log_dir)
                if ckpt != None:
                    save.restore(sess, ckpt)
                    print('从上次训练保存后的模型继续训练!')
                else:
                    restorer.restore(sess, checkpoint_file)
                    print('从官方模型加载训练!')
    
                # 创建一个协调器,管理线程
                coord = tf.train.Coordinator()
    
                # 启动QueueRunner, 此时文件名才开始进队。
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                output_vars1 = sess.run([output_vars])
                print("output_vars:", output_vars1)
    
    
                '''
                查看预处理之后的图片
                '''
                imgs, labs = sess.run([train_images, train_labels])
                print('原始训练图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original train image')
                plt.show()
    
                imgs, labs = sess.run([test_images, test_labels])
                print('原始测试图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original test image')
                plt.show()
    
                print('开始训练!')
                for epoch in range(training_epochs):
                    total_cost = 0.0
                    print("训练111111!!!")
                    for i in range(num_batch):
                        print("批次:"+str(i))
                        print("训练222222!!!")
                        imgs, labs = sess.run([train_images, train_labels])
                        _, loss = sess.run([optimizer, cost],
                                           feed_dict={input_images: imgs, input_labels: labs, is_training: True})
                        total_cost += loss
                    print("训练ing!!!")
                    # 打印信息
                    if epoch % display_epoch == 0:
                        print('Epoch {}/{}  average cost {:.9f}'.format(epoch + 1, training_epochs, total_cost / num_batch))
    
                    # 进行预测处理
                    imgs, labs = sess.run([test_images, test_labels])
                    cost_values, accuracy_value = sess.run([cost, accuracy],
                                                           feed_dict={input_images: imgs, input_labels: labs,
                                                                      is_training: False})
                    print('Epoch {}/{}  Test cost {:.9f}'.format(epoch + 1, training_epochs, cost_values))
                    print('准确率:', accuracy_value)
    
                    # 保存模型
                    save.save(sess, os.path.join(train_log_dir, train_log_file), global_step=epoch)
                    print('Epoch {}/{}  模型保存成功'.format(epoch + 1, training_epochs))
    
                print('训练完成')
    
                # 终止线程
                coord.request_stop()
                coord.join(threads)
    
    
    def flowers_test():
        '''
        使用微调好的网络进行测试
        '''
        '''
        1.设置参数,并加载数据
        '''
        # 微调后的检查点文件和日志文件路径
        save_dir = './log/vgg16/fine_tune'
    
        # 设置batch_size
        batch_size = 16 #128
    
        # 加载数据
        train_images, train_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, True,
                                                                           IMAGE_SIZE, IMAGE_SIZE)
        test_images, test_labels = input_data.get_batch_images_and_label(DATA_DIR, batch_size, NUM_CLASSES, False,
                                                                         IMAGE_SIZE, IMAGE_SIZE)
    
        # 获取模型参数的命名空间
        arg_scope = vgg.vgg_arg_scope()
    
        # 创建网络
        with  slim.arg_scope(arg_scope):
            '''
            2.定义占位符和网络结构
            '''
            # 输入图片
            input_images = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
            # 训练还是测试?测试的时候弃权参数会设置为1.0
            is_training = tf.placeholder(dtype=tf.bool)
    
            # 创建vgg16网络
            logits, end_points = vgg.vgg_16(input_images, is_training=is_training, num_classes=NUM_CLASSES)
    
            # 预测标签
            pred = tf.argmax(logits, axis=1)
    
            restorer = tf.train.Saver()
    
            # 恢复模型
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                ckpt = tf.train.latest_checkpoint(save_dir)
                if ckpt != None:
                    # 恢复模型
                    restorer.restore(sess, ckpt)
                    print("Model restored.")
    
                # 创建一个协调器,管理线程
                coord = tf.train.Coordinator()
    
                # 启动QueueRunner, 此时文件名才开始进队。
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
                '''
                查看预处理之后的图片
                '''
                imgs, labs = sess.run([test_images, test_labels])
                print('原始测试图片信息:', imgs.shape, labs.shape)
                show_img = np.array(imgs[0], dtype=np.uint8)
                plt.imshow(show_img)
                plt.title('Original test image')
                plt.show()
    
                pred_value = sess.run(pred, feed_dict={input_images: imgs, is_training: False})
                print('预测结果为:', pred_value)
                print('实际结果为:', np.argmax(labs, 1))
                correct = np.equal(pred_value, np.argmax(labs, 1))
                print('准确率为:', np.mean(correct))
    
                # 终止线程
                coord.request_stop()
                coord.join(threads)
    
    
    if __name__ == '__main__':
        tf.reset_default_graph()
        flowers_fine_tuning()
        flowers_test()
    

    参考资料

    TensorFlow 的所有用于图像分类的预训练模型的下载地址

    [1] models/research/slim

    相关文章

      网友评论

          本文标题:Tensorflow(五)用VGG实现迁移学习(代码实现)

          本文链接:https://www.haomeiwen.com/subject/iazjaqtx.html