美文网首页
keras 之 迁移学习,改变VGG16输出层,用imagene

keras 之 迁移学习,改变VGG16输出层,用imagene

作者: vola_lei | 来源:发表于2018-04-18 00:49 被阅读0次

    迁移学习, 用现成网络,跑自己数据: 保留已有网络除输出层以外其它层的权重, 改变已有网络的输出层的输出class 个数. 以已有网络权值为基础, 训练自己的网络,
    以keras 2.1.5 / VGG16Net为例.

    • 导入必要的库
    from keras.preprocessing.image import ImageDataGenerator
    from keras import optimizers
    from keras.models import Sequential
    from keras.layers import Dropout, Flatten, Dense
    from keras import Model
    from keras import initializers
    from keras.callbacks import ModelCheckpoint, EarlyStopping
    from keras.applications.vgg16 import VGG16
    
    • 设置输入图片augmentation方法
      这里对于train/test data只设置rescale, 即图片矩阵整体除以某个数值.
    # prepare data augmentation configuration
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
    #     shear_range=0.2,
    #     zoom_range=0.2,
    #     horizontal_flip=True
        )
    
    test_datagen = ImageDataGenerator(rescale=1. / 255)
    
    • 设置图片尺寸
      因为是保留VGG16所有网络结构(除输出层), 因此输入必须和VGGnet要求的一致.
    # the input is the same as original network
    input_shape = (224,224,3)
    
    • 设置图片路径
      设置图片路径, directory参数对应类别根目录, directory下的各子目录对应各个类别的图片. 子目录文件夹名即为class name. 读入的类别顺序按照类别目录的字母排序而定.
    train_generator = train_datagen.flow_from_directory(
        directory = './data/train/',
        target_size = input_shape[:-1],
        color_mode = 'rgb',
        classes = None,
        class_mode = 'categorical',
        batch_size = 10,
        shuffle = True)
    
    test_generator = test_datagen.flow_from_directory(
        directory = './data/test/',
        target_size = input_shape[:-1],
        batch_size = 10,
        class_mode = 'categorical')
    
    • 加载VGG16 网络
      input_shape : 图片输入尺寸
      include_top = True, 保留全连接层.
      classes = 10: 类别数目(我们的类别数目)
      weights = None: 不加载任何网络
    # build the VGG16 network, 加载VGG16网络, 改变输出层的类别数目.
    # include_top = True, load the whole network
    # set the new output classes to 10
    # weights = None, load no weights 
    base_model = VGG16(input_shape = input_shape, 
                         include_top = True, 
                         classes = 10, 
                         weights = None
                         )
    print('Model loaded.')
    
    • 改变最后一层的名称
    base_model.layers[-1].name = 'pred'
    
    • 查看最后一层的初始化方法
    base_model.layers[-1].kernel_initializer.get_config()
    

    将会得到:

    {'distribution': 'uniform', 'mode': 'fan_avg', 'scale': 1.0, 'seed': None}
    
    • 改变最后一层的权值初始化方法
    base_model.layers[-1].kernel_initializer = initializers.glorot_normal()
    
    • 加载VGG16 在imagenet训练得到的权重, 一定要按照 by_name = True的方式
      按照layer的名称加载权重(名称不对应的层级将不会加载权重), 这就是为什么我们一定要改变最后一层的名称了. 因为唯有如此, 这一步加载权重,将会加载除了最后一层的所有层的权重.
    base_model.load_weights('./vgg16_weights_tf_dim_ordering_tf_kernels.h5', by_name = True)
    
    • compile 网络
    # compile the model with a SGD/momentum optimizer
    # and a very slow learning rate.
    sgd = optimizers.SGD(lr=0.01, decay=1e-4, momentum=0.9, nesterov=True)
    
    base_model.compile(loss = 'categorical_crossentropy',
                  optimizer = sgd,
                  metrics=['accuracy'])
    
    • 开始训练
      训练过程中自动保存权重,并按要求停止训练.
    
    # fine-tune the model
    
    check = ModelCheckpoint('./', 
                    monitor='val_loss', 
                    verbose=0, 
                    save_best_only=False, 
                    save_weights_only=False, 
                    mode='auto', 
                    period=1)
    
    stop = EarlyStopping(monitor='val_loss',
                  min_delta=0, 
                  patience=0, 
                  verbose=0, 
                  mode='auto')
    
    base_model.fit_generator(
        generator = train_generator,
        epochs = 5,
        verbose = 1,
        validation_data = test_generator,
        shuffle = True,
        callbacks = [check, stop]
        
        )
    
    • 保存网络
    model.save_weights('fine_tuned_net.h5')
    

    相关文章

      网友评论

          本文标题:keras 之 迁移学习,改变VGG16输出层,用imagene

          本文链接:https://www.haomeiwen.com/subject/anymkftx.html