Using the Pretrained ResNet-50 Model in TensorFlow


Author: 公输睚信 | Published 2018-04-29 23:46

Previous articles have covered how to build, train, save, and export models with TensorFlow. This article explains how to call a pretrained model in TensorFlow and fine-tune the network. For simplicity, we take the pretrained ResNet-50 for image classification as the example; the module we rely on is still tf.contrib.slim.

All of TensorFlow's pretrained image-classification models can be downloaded from models/research/slim, which includes the commonly used VGG, Inception, ResNet, and MobileNet families as well as the newer NasNet models. The key to using these pretrained models is loading the pretrained parameters correctly into the network you have defined, which can be done conveniently with the function slim.assign_from_checkpoint_fn. The code below walks through the details.

1. Fine-Tuning: Model Definition

As mentioned above, all of TensorFlow's pretrained models live in the GitHub project models/research/slim, and the corresponding network implementations are in its nets subfolder. Taking ResNet-50 as the example (other models work the same way), we first define the network structure:

    import tensorflow as tf
    
    from tensorflow.contrib.slim import nets
    
    slim = tf.contrib.slim
    
    
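    # Excerpt: this is the predict() method of the Model class defined in the
    # full model.py listing below; `self` refers to an instance of that class.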
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
    
        Outputs of this function can be passed to loss or postprocess functions.
    
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
                
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        net, endpoints = nets.resnet_v1.resnet_v1_50(
            preprocessed_inputs, num_classes=None,
            is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        net = slim.fully_connected(net, num_outputs=self.num_classes,
                                   activation_fn=None, scope='Predict')
        prediction_dict = {'logits': net}
        return prediction_dict
    

Suppose the images to be classified fall into self.num_classes classes. We take a batch of images, preprocess them, and pass them to the predict function, which simply calls the TensorFlow-slim implementation nets.resnet_v1.resnet_v1_50 to extract image features. Because ResNet-50 was built for 1000-class classification, we set num_classes=None to disable its final output layer. If the input batch has shape [None, 224, 224, 3], resnet_v1_50 returns a tensor of shape [None, 1, 1, 2048]; before feeding it into a fully connected layer, tf.squeeze removes the two singleton dimensions at indices 1 and 2. Finally, a fully connected layer is attached to produce the prediction logits for the self.num_classes classes.
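
To make these shapes concrete, here is a standalone sketch (not from the original article; the class count of 5 is arbitrary) that builds the same layers and prints the intermediate shapes:

    import tensorflow as tf
    
    from tensorflow.contrib.slim import nets
    
    slim = tf.contrib.slim
    
    images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    # With num_classes=None the network stops after the global average pooling.
    net, _ = nets.resnet_v1.resnet_v1_50(images, num_classes=None,
                                         is_training=False)
    print(net.get_shape())     # (?, 1, 1, 2048)
    net = tf.squeeze(net, axis=[1, 2])
    print(net.get_shape())     # (?, 2048)
    logits = slim.fully_connected(net, num_outputs=5,
                                  activation_fn=None, scope='Predict')
    print(logits.get_shape())  # (?, 5)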

As you can see, with the tf.contrib.slim module, calling a network such as ResNet-50 becomes remarkably simple. The remaining key question is how to load the pretrained parameters so that the model can then be fine-tuned on our own data. Before addressing that, here is the complete model definition file model.py for reference:

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Fri Mar 30 16:54:02 2018
    
    @author: shirhe-lyh
    """
    
    import tensorflow as tf
    
    from abc import ABCMeta
    from abc import abstractmethod
    from tensorflow.contrib.slim import nets
    
    slim = tf.contrib.slim
    
    
    class BaseModel(object):
        """Abstract base class for any model."""
        __metaclass__ = ABCMeta
        
        def __init__(self, num_classes):
            """Constructor.
            
            Args:
                num_classes: Number of classes.
            """
            self._num_classes = num_classes
            
        @property
        def num_classes(self):
            return self._num_classes
        
        @abstractmethod
        def preprocess(self, inputs):
            """Input preprocessing. To be override by implementations.
            
            Args:
                inputs: A float32 tensor with shape [batch_size, height, width,
                    num_channels] representing a batch of images.
                
            Returns:
                preprocessed_inputs: A float32 tensor with shape [batch_size, 
                    height, width, num_channels] representing a batch of images.
            """
            pass
        
        @abstractmethod
        def predict(self, preprocessed_inputs):
            """Predict prediction tensors from inputs tensor.
            
            Outputs of this function can be passed to loss or postprocess functions.
            
            Args:
                preprocessed_inputs: A float32 tensor with shape [batch_size,
                    height, width, num_channels] representing a batch of images.
                
            Returns:
                prediction_dict: A dictionary holding prediction tensors to be
                    passed to the Loss or Postprocess functions.
            """
            pass
        
        @abstractmethod
        def postprocess(self, prediction_dict, **params):
            """Convert predicted output tensors to final forms.
            
            Args:
                prediction_dict: A dictionary holding prediction tensors.
                **params: Additional keyword arguments for specific implementations
                    of specified models.
                    
            Returns:
                A dictionary containing the postprocessed results.
            """
            pass
        
        @abstractmethod
        def loss(self, prediction_dict, groundtruth_lists):
            """Compute scalar loss tensors with respect to provided groundtruth.
            
            Args:
                prediction_dict: A dictionary holding prediction tensors.
                groundtruth_lists: A list of tensors holding groundtruth
                    information, with one entry for each image in the batch.
                    
            Returns:
                A dictionary mapping strings (loss names) to scalar tensors
                    representing loss values.
            """
            pass
        
            
    class Model(BaseModel):
        """xxx definition."""
        
        def __init__(self, is_training, num_classes):
            """Constructor.
            
            Args:
                is_training: A boolean indicating whether the training version of
                    computation graph should be constructed.
                num_classes: Number of classes.
            """
            super(Model, self).__init__(num_classes=num_classes)
            
            self._is_training = is_training
            
        def preprocess(self, inputs):
            """Preprocess a batch of images.
            
            Args:
                inputs: A float32 tensor with shape [batch_size, height, width,
                    num_channels] representing a batch of images.
                
            Returns:
                preprocessed_inputs: A float32 tensor with shape [batch_size,
                    height, width, num_channels] representing a batch of
                    preprocessed images.
            """
            channel_means = [123.68, 116.779, 103.939]
            preprocessed_inputs = tf.to_float(inputs)
            preprocessed_inputs = preprocessed_inputs - [[channel_means]]
            return preprocessed_inputs
        
        def predict(self, preprocessed_inputs):
            """Predict prediction tensors from inputs tensor.
            
            Outputs of this function can be passed to loss or postprocess functions.
            
            Args:
                preprocessed_inputs: A float32 tensor with shape [batch_size,
                    height, width, num_channels] representing a batch of images.
                
            Returns:
                prediction_dict: A dictionary holding prediction tensors to be
                    passed to the Loss or Postprocess functions.
            """
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
            net = tf.squeeze(net, axis=[1, 2])
            logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                          activation_fn=None, scope='Predict')
            prediction_dict = {'logits': logits}
            return prediction_dict
        
        def postprocess(self, prediction_dict):
            """Convert predicted output tensors to final forms.
            
            Args:
                prediction_dict: A dictionary holding prediction tensors.
                **params: Additional keyword arguments for specific implementations
                    of specified models.
                    
            Returns:
                A dictionary containing the postprocessed results.
            """
            logits = prediction_dict['logits']
            logits = tf.nn.softmax(logits)
            classes = tf.argmax(logits, axis=1)
            postprocessed_dict = {'classes': classes}
            return postprocessed_dict
        
        def loss(self, prediction_dict, groundtruth_lists):
            """Compute scalar loss tensors with respect to provided groundtruth.
            
            Args:
                prediction_dict: A dictionary holding prediction tensors.
                groundtruth_lists: A tensor holding the groundtruth class
                    labels, with one entry for each image in the batch.
                    
            Returns:
                A dictionary mapping strings (loss names) to scalar tensors
                    representing loss values.
            """
            logits = prediction_dict['logits']
            slim.losses.sparse_softmax_cross_entropy(
                logits=logits, 
                labels=groundtruth_lists,
                scope='Loss')
            loss = slim.losses.get_total_loss()
            loss_dict = {'loss': loss}
            return loss_dict
            
        def accuracy(self, postprocessed_dict, groundtruth_lists):
            """Calculate accuracy.
            
            Args:
                postprocessed_dict: A dictionary containing the postprocessed 
                    results
                groundtruth_lists: A tensor holding the groundtruth class
                    labels, with one entry for each image in the batch.
                    
            Returns:
                accuracy: The scalar accuracy.
            """
            classes = postprocessed_dict['classes']
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
            return accuracy
    
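As a quick sanity check, the only trainable variables without a counterpart in the pretrained checkpoint should be those of the newly added Predict layer; a minimal sketch to verify this (not part of the original article; the class count of 5 is arbitrary):

    import tensorflow as tf
    
    import model
    
    images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    cls_model = model.Model(is_training=False, num_classes=5)
    cls_model.predict(cls_model.preprocess(images))
    
    # Expect only 'Predict/weights' and 'Predict/biases' to be printed; every
    # other trainable variable lives under the 'resnet_v1_50' scope and has a
    # pretrained counterpart in the checkpoint.
    for var in tf.trainable_variables():
        if not var.op.name.startswith('resnet_v1_50'):
            print(var.op.name)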

2. Importing the Pretrained Model

To load the parameters of the pretrained ResNet-50 into the model defined above, we again rely on the tf.contrib.slim module, and the approach is simple: pass an initialization function init_fn to the training function slim.learning.train. Such a function can be conveniently created with

    slim.assign_from_checkpoint_fn(model_path, var_list,
                                   ignore_missing_vars=False,
                                   reshape_variables=False)
    

Here, the first argument model_path is the path to the pretrained xxx.ckpt file, and the second argument var_list lists all the variables that should receive their values from the pretrained checkpoint; it can be obtained quickly with

    slim.get_variables_to_restore(include=None,
                                  exclude=None)
    

If some variables need to be excluded, that is, if you want certain variables to be randomly initialized instead of restored from the pretrained model, simply name them in the exclude argument. The third argument, ignore_missing_vars, is very important and must be set to True: it tells TensorFlow to ignore variables that exist in your model definition but not in the pretrained checkpoint. Without it, restoring fails whenever your graph contains a variable that the xxx.ckpt file does not (there are many such variables, e.g. the bias terms of convolution layers, which pretrained models generally do not contain; they need to be ignored and left with their default zero initialization). The last argument, reshape_variables, reshapes certain variables on restore; it is rarely needed, so the default False is fine. A typical call is sketched below.
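
For instance, restoring only the ResNet-50 backbone while leaving the new Predict layer randomly initialized could look roughly like this (a sketch; the checkpoint path below is a placeholder assumption and should point to the downloaded ResNet-50 checkpoint):

    # Restore everything except the newly added 'Predict' layer, and ignore
    # any variables in the graph that are absent from the checkpoint.
    variables_to_restore = slim.get_variables_to_restore(exclude=['Predict'])
    init_fn = slim.assign_from_checkpoint_fn(
        'pretrained/resnet_v1_50.ckpt',   # hypothetical path to the checkpoint
        variables_to_restore,
        ignore_missing_vars=True)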

With the above in place, and assuming you have read the previous article TensorFlow-slim 训练 CNN 分类模型(续) (TensorFlow-slim: Training a CNN Classification Model, continued), the complete training script train.py that fine-tunes the pretrained model is easy to write. It is listed below (the key part is the last few lines):

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Fri Mar 30 19:27:44 2018
    
    @author: shirhe-lyh
    """
    
    """Train a CNN classification model via pretrained ResNet-50 model.
    
    Example Usage:
    ---------------
    python3 train.py \
        --resnet50_model_path path/to/resnet_v1_50.ckpt \
        --record_path path/to/train.record \
        --logdir path/to/logdir
    """
    
    import tensorflow as tf
    
    import model
    
    slim = tf.contrib.slim
    flags = tf.app.flags
    
    flags.DEFINE_string('record_path', None, 'Path to training tfrecord file.')
    flags.DEFINE_string('resnet50_model_path', None, 
                        'Path to pretrained ResNet-50 model.')
    flags.DEFINE_string('logdir', None, 'Path to log directory.')
    FLAGS = flags.FLAGS
    
    
    def get_record_dataset(record_path,
                           reader=None, image_shape=[224, 224, 3], 
                           num_samples=50000, num_classes=10):
        """Get a tensorflow record file.
        
        Args:
            
        """
        if not reader:
            reader = tf.TFRecordReader
            
        keys_to_features = {
            'image/encoded': 
                tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': 
                tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/class/label': 
                tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                                   dtype=tf.int64))}
            
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(shape=image_shape, 
                                                  #image_key='image/encoded',
                                                  #format_key='image/format',
                                                  channels=3),
            'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
        
        decoder = slim.tfexample_decoder.TFExampleDecoder(
            keys_to_features, items_to_handlers)
        
        labels_to_names = None
        items_to_descriptions = {
            'image': 'An image with shape image_shape.',
            'label': 'A single integer between 0 and num_classes - 1.'}
        return slim.dataset.Dataset(
            data_sources=record_path,
            reader=reader,
            decoder=decoder,
            num_samples=num_samples,
            num_classes=num_classes,
            items_to_descriptions=items_to_descriptions,
            labels_to_names=labels_to_names)
    
    
    def main(_):
        dataset = get_record_dataset(FLAGS.record_path, num_samples=79573, 
                                     num_classes=54)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
        image, label = data_provider.get(['image', 'label'])
        
        # Data augmentation
        image = tf.image.random_flip_left_right(image)
            
        inputs, labels = tf.train.batch([image, label],
                                        batch_size=64,
                                        allow_smaller_final_batch=True)
        
        cls_model = model.Model(is_training=True, num_classes=54)
        preprocessed_inputs = cls_model.preprocess(inputs)
        prediction_dict = cls_model.predict(preprocessed_inputs)
        loss_dict = cls_model.loss(prediction_dict, labels)
        loss = loss_dict['loss']
        postprocessed_dict = cls_model.postprocess(prediction_dict)
        acc = cls_model.accuracy(postprocessed_dict, labels)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', acc)
        
    
        #optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.99)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
        train_op = slim.learning.create_train_op(loss, optimizer,
                                                 summarize_gradients=True)
        
        variables_to_restore = slim.get_variables_to_restore()
        init_fn = slim.assign_from_checkpoint_fn(FLAGS.resnet50_model_path,
                                                 variables_to_restore,
                                                 ignore_missing_vars=True)
        
        slim.learning.train(train_op=train_op, logdir=FLAGS.logdir, 
                            init_fn=init_fn, 
                            save_summaries_secs=20, 
                            save_interval_secs=600)
        
    if __name__ == '__main__':
        tf.app.run()
    

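Once training has written checkpoints to --logdir, the fine-tuned weights can be restored for inference with an ordinary tf.train.Saver. A minimal sketch (not part of the original post; the log directory path and the class count of 54 are assumptions that must match your own training run):

    import tensorflow as tf
    
    import model
    
    images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
    cls_model = model.Model(is_training=False, num_classes=54)
    prediction_dict = cls_model.predict(cls_model.preprocess(images))
    classes = cls_model.postprocess(prediction_dict)['classes']
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore the newest checkpoint written by slim.learning.train.
        checkpoint_path = tf.train.latest_checkpoint('path/to/logdir')
        saver.restore(sess, checkpoint_path)
        # predictions = sess.run(classes, feed_dict={images: image_batch})
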
Preview: the next article will show how to train a multi-task, multi-label model with TensorFlow. Stay tuned!


Comments

  • infilimi: Hi, one question: when defining the network, why is there no softmax added after the new fully connected layer?
    公输睚信: You can add one, but I put it in the postprocess function instead.
  • 2f88b36d6d79: Thanks for sharing. A small question: which optimizer was used for the pretrained model? When I load it I always get many warnings like
    WARNING:tensorflow:Variable resnet_v2_50/block1/unit_1/bottleneck_v2/conv1/weights/Adam missing in checkpoint
    Why do these appear?
    2f88b36d6d79: @公输睚信 OK, thanks.
    公输睚信: @ItsPossible That warning does not matter. Also, if you are using the v2 model, change this line in the predict function
        net, endpoints = nets.resnet_v1.resnet_v1_50(
            preprocessed_inputs, num_classes=None,
            is_training=self._is_training)
    to the v2 version.
    2f88b36d6d79: By the way, I am using the v2 model.
  • 马儿柯孚: Thanks for sharing. One question: when importing the pretrained parameters, how are the parameters of the fully connected layer in front of the original 1000-class output handled? They exist in the pretrained model but are not needed in the new model (say, with 10 classes).
    公输睚信: @mi_tf Right, the parameters of those extra layers are ignored, and your own fully connected layer is initialized from scratch.
    马儿柯孚: @公输睚信 That part I understand. What I mean is: pretraining on ImageNet produces weights for that fully connected layer, but they are never used by the new model's fully connected layer (e.g. 10 classes). Doesn't importing the checkpoint directly cause an error? Although your code does run without problems.
    Or can it be understood like this: the modified model looks up the weights it needs in the pretrained .ckpt file and ignores the parameters of the superfluous fully connected layer?
    公输睚信: net, endpoints = nets.resnet_v1.resnet_v1_50(
        preprocessed_inputs, num_classes=None,
        is_training=self._is_training)
    Setting num_classes=None here disables the final fully connected layer; the returned net is a convolutional feature map, with no fully connected layer at all.
  • cpine: Hi, does the pretrained ResNet-50 ckpt file really contain no biases for any layer? Fine-tuning from the ckpt file I downloaded keeps reporting that biases cannot be found. Do you have a ckpt file that contains both weights and biases?
    cpine: The point is that the error says the bias cannot be found in the ckpt file, which means the network being built does define a bias and only the checkpoint lacks it. If the network had no bias while the ckpt file did, there would be no error; it would simply be ignored. So how can you say the convolution layers have no bias?
    公输睚信: Convolution layers generally have no biases.
  • P0ny: Excellent write-up, exactly the kind of up-to-date usage I like. The older approaches were far too cumbersome. Keep it up!
    公输睚信: @imagecainiao All the code is in this article and in "TensorFlow-slim 训练 CNN 分类模型(续)".
    a3a3b6e1bc37: @公输睚信 Hi, is the complete code available anywhere? I am new to this and would like to learn from it, thanks.
    公输睚信: Thanks, keep it up too!
