TensorFlow Implementation of a Semantic Segmentation Model: DeepLab V3+

Author: 公输睚信 | Published 2019-03-11 17:57

    This article implements the DeepLab v3+ model (reference: the official DeepLab open-source code).

    # -*- coding: utf-8 -*-
    """
    Created on Mon Dec  3 17:57:46 2018
    
    @author: shirhe-lyh
    
    
    Implementation of DeepLab V3+:
        Encoder-Decoder with Atrous Separable Convolution for Semantic Image
        Segmentation, Liang-Chieh Chen, et al., arXiv:1802.02611v3.
    """
    
    import numpy as np
    import tensorflow as tf
    
    from tensorflow.contrib.slim import nets
    
    import preprocessing
    import resnet_v1_beta
    
    slim = tf.contrib.slim
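    # NOTE: This code targets TensorFlow 1.x; tf.contrib (and with it slim)
    # was removed in TensorFlow 2.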
    
    
    class DeepLab(object):
        """Implementation of DeepLab V3+."""
        
        def __init__(self,
                     is_training,
                     num_classes=3,
                     output_stride=16,
                     atrous_rates=[6, 12, 18],  # [12, 24, 36] for output_stride=8
                     decoder_output_stride=4,
                     default_image_size=513,
                     fine_tune_batch_norm=False):
            """Constructor.
            
            Args:
                is_training: A boolean indicating whether the training version of
                    computation graph should be constructed.
                num_classes: The number of classes.
                defualt_image_size: The input size of the model.
            """
            self._is_training = is_training
            self._num_classes = num_classes
            self._output_stride = output_stride
            self._atrous_rates = atrous_rates
            self._decoder_output_stride = decoder_output_stride
            self._default_image_size = default_image_size
            
            # When fine_tune_batch_norm=True, use a batch size larger than 12
            # (a batch size of more than 16 is better). Otherwise, use a
            # smaller batch size and set fine_tune_batch_norm=False.
            _is_training = is_training and fine_tune_batch_norm
            self._batch_norm_params = {'is_training': _is_training,
                                       'epsilon': 1e-5,
                                       'decay': 0.9997,
                                       'scale': True}
            
        @property
        def default_image_size(self):
            return self._default_image_size
            
        def preprocess(self, images=None, masks=None):
            """Preprocessing.
            
            Args:
                images: A float32 tensor with shape [batch_size, height, width,
                    3] representing a batch of images. Only passed at test time
                    (in the training case images=None).
                masks: A float32 tensor with shape [batch_size, height, width,
                    1] representing a batch of groundtruth masks.
                
            Returns:
                The preprocessed inputs.
            """
            # NOTE: The body of this method is elided in the original article
            # (it calls into the imported `preprocessing` module, which is not
            # shown). A minimal stand-in: resize to the default input size and
            # map pixel values to [-1, 1].
            images_preprocessed = None
            trimaps_preprocessed = None
            if images is not None:
                images = tf.image.resize_bilinear(
                    images,
                    [self._default_image_size, self._default_image_size],
                    align_corners=True)
                images_preprocessed = self._preprocess_zero_mean_unit_range(
                    images)
            if masks is not None:
                trimaps_preprocessed = tf.image.resize_nearest_neighbor(
                    masks,
                    [self._default_image_size, self._default_image_size],
                    align_corners=True)
            preprocessed_dict = {'images': images_preprocessed,
                                 'masks': trimaps_preprocessed}
            return preprocessed_dict
        
        def _preprocess_zero_mean_unit_range(self, inputs):
            """Map image values from [0, 255] to [-1, 1].
            
            Only for beta version.
            """
            return (2.0 / 255.0) * tf.to_float(inputs) - 1.0
        
        def predict(self, preprocessed_inputs):
            """Predict prediction tensors from inputs tensor.
            
            Outputs of this function can be passed to loss or postprocess functions.
            
            Args:
                preprocessed_inputs: A 4-D float32 tensor with shape [batch_size, 
                    height, width, channels].
                
            Returns:
                The prediction tensors to be passed to the Loss or Postprocess 
                functions.
            """
            # ResNet-50
            with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
                net, end_points = resnet_v1_beta.resnet_v1_50_beta(
                    preprocessed_inputs, num_classes=None,
                    is_training=self._is_training,
                    multi_grid=[1, 2, 4],
                    global_pool=False,
                    output_stride=self._output_stride)
                
            # Reuse the same variable scope as ResNet-50
            scope = 'resnet_v1_50'
            
            # Atrous spatial pyramid pooling
            net = self._atrous_spatial_pyramid_pooling(
                net, atrous_rates=self._atrous_rates, scope=scope)
            
            # Refine by decoder
            decoder_height = self.default_image_size // self._decoder_output_stride
            decoder_width = self.default_image_size // self._decoder_output_stride
            net = self._refine_by_decoder(
                net,
                end_points,
                decoder_height=decoder_height,
                decoder_width=decoder_width,
                decoder_use_separable_conv=True,
                is_training=self._is_training)
            
            # 1x1 convolution to per-class logits, then bilinear upsampling
            # back to the input resolution.
            net = self._get_branch_logits(net, self._num_classes,
                                          self._atrous_rates, kernel_size=1)
            net = tf.image.resize_bilinear(net, size=[self._default_image_size,
                                                      self._default_image_size],
                                           align_corners=True,
                                           name='upsampling_logits')
            return net
        
        def split_separable_conv2d(self,
                                   inputs,
                                   filters,
                                   kernel_size=3,
                                   rate=1,
                                   weight_decay=0.00004,
                                   depthwise_weights_initializer_stddev=0.33,
                                   pointwise_weights_initializer_stddev=0.06,
                                   scope=None):
            """Splits a seperable conv2d into depthwise and pointwise conv2d.
            
            This operation differs from `tf.layers.separable_conv2d` as this 
            operation applies activation function between depthwise and pointwise 
            conv2d.
            
            Copy from:
                https://github.com/tensorflow/models/blob/master/research/deeplab/
                core/utils.py
                
            Args:
                inputs: Input tensor with shape [batch, height, width, channels].
                filters: Number of filters in the 1x1 pointwise convolution.
                kernel_size: A list of length 2: [kernel_height, kernel_width]
                    of the filters. Can be an int if both values are the same.
                rate: Atrous convolution rate for the depthwise convolution.
                weight_decay: The weight decay to use for regularizing the model.
                depthwise_weights_initializer_stddev: The standard deviation of the
                    truncated normal weight initializer for depthwise convolution.
                pointwise_weights_initializer_stddev: The standard deviation of the
                    truncated normal weight initializer for pointwise convolution.
                scope: Optional scope for the operation.
                
            Returns:
                Computed features after split separable conv2d.
            """
            outputs = slim.separable_conv2d(
                inputs,
                None,
                kernel_size=kernel_size,
                depth_multiplier=1,
                rate=rate,
                weights_initializer=tf.truncated_normal_initializer(
                    stddev=depthwise_weights_initializer_stddev),
                weights_regularizer=None,
                scope=scope + '_depthwise')
            return slim.conv2d(
                outputs,
                filters,
                1,
                weights_initializer=tf.truncated_normal_initializer(
                    stddev=pointwise_weights_initializer_stddev),
                weights_regularizer=slim.l2_regularizer(weight_decay),
                scope=scope + '_pointwise')
                  
        def _atrous_spatial_pyramid_pooling(self, feature_map, weight_decay=0.0001,
                                            atrous_rates=[12, 24, 36],
                                            scope='resnet_v1_50'):
            """Atrous spatial pyramid pooling for DeepLab v3."""
            branch_nets = []
            # Convolution
            with tf.variable_scope(scope):
                with slim.arg_scope([slim.conv2d, slim.separable_conv2d], 
                                    weights_regularizer=slim.l2_regularizer(
                                        weight_decay),
                                    normalizer_fn=slim.batch_norm,
                                    normalizer_params=self._batch_norm_params):
                    depth = 256
                    
                    # Image pooling feature
                    shape = tf.shape(feature_map)[1:3]
                    image_feature = tf.reduce_mean(feature_map, axis=[1, 2],
                                                   keep_dims=True)
                    image_feature = slim.conv2d(image_feature, kernel_size=1,
                                                num_outputs=depth,
                                                scope='global_pool')
                    image_feature = tf.image.resize_bilinear(image_feature, 
                                                             size=shape,
                                                             align_corners=True)
                    branch_nets.append(image_feature)
                    
                    # Employ a 1x1 convolution
                    branch_nets.append(slim.conv2d(feature_map, kernel_size=1,
                                                   num_outputs=depth,
                                                   scope='aspp' + str(0)))          
                    
                    # Employ 3x3 convolutions with different atrous rates.
                    for i, rate in enumerate(atrous_rates, 1):
                        branch_scope = 'aspp' + str(i)
                        aspp_net = self.split_seperable_conv2d(
                            feature_map,
                            filters=depth,
                            rate=rate,
                            weight_decay=weight_decay,
                            scope=branch_scope)
                        branch_nets.append(aspp_net)
            
                    # Concatenate the branches and project back to `depth`
                    # channels; kept inside the arg_scope so the projection
                    # also gets batch norm and weight decay.
                    net = tf.concat(branch_nets, axis=3, name='aspp_concat')
                    net = slim.conv2d(net, depth, kernel_size=1,
                                      scope='concat_projection')
                    net = slim.dropout(net, keep_prob=0.9,
                                       is_training=self._is_training,
                                       scope='concat_projection_dropout')
            return net
        
        def _refine_by_decoder(self,
                               feature_map,
                               end_points,
                               decoder_height,
                               decoder_width,
                           decoder_use_separable_conv=False,
                               weight_decay=0.0001,
                               reuse=None,
                               is_training=False,
                               scope='resnet_v1_50'):
            """Adds the decoder to obtain sharper segmentation results.
            
            Args:
                feature_map: A tensor with shape [batch_size, height, width, depth].
                end_points: A dictionary from components of the network to the 
                    corresponding activation.
                decoder_height: The height of decoder feature maps.
                decoder_width: The width of decoder feature maps.
                decoder_use_separable_conv: Whether to employ separable
                    convolution in the decoder.
                weight_decay: The weight decay for model variables.
                reuse: Reuse the model variables or not.
                is_training: Is training or not.
                
            Returns:
                Decoder output size [batch_size, decoder_height, decoder_width,
                decoder_depth].
            """
            with tf.variable_scope(scope):
                with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                                    weights_regularizer=slim.l2_regularizer(
                                        weight_decay),
                                    normalizer_fn=slim.batch_norm,
                                    normalizer_params=self._batch_norm_params,
                                    reuse=reuse):
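                    # Fuse the ASPP output with a low-level backbone feature
                    # (stride 4 for ResNet block1), first projected to 48
                    # channels so it does not dominate the concatenation.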
                    feature_list = ['block1/unit_2/bottleneck_v1/conv3']
                    decoder_features = feature_map
                    for i, name in enumerate(feature_list):
                        decoder_features_list = [decoder_features]
                        feature_name = '{}/{}'.format(scope, name)
                        decoder_features_list.append(
                            slim.conv2d(end_points[feature_name], 48, 1,
                                        scope='feature_project' + str(i)))
                        for j, feature in enumerate(decoder_features_list):
                            decoder_features_list[j] = tf.image.resize_bilinear(
                                feature, [decoder_height, decoder_width], 
                                align_corners=True)
                            h = (None if isinstance(decoder_height, tf.Tensor)
                                 else decoder_height)
                            w = (None if isinstance(decoder_width, tf.Tensor)
                                 else decoder_width)
                            decoder_features_list[j].set_shape([None, h, w, None])
                        decoder_depth = 256
                        if decoder_use_separable_conv:
                            decoder_features = self.split_separable_conv2d(
                                tf.concat(decoder_features_list, axis=3),
                                filters=decoder_depth,
                                rate=1,
                                weight_decay=weight_decay,
                                scope='decoder_conv0')
                            decoder_features = self.split_separable_conv2d(
                                decoder_features,
                                filters=decoder_depth,
                                rate=1,
                                weight_decay=weight_decay,
                                scope='decoder_conv1')
                        else:
                            num_convs = 2
                            decoder_features = slim.repeat(
                                tf.concat(decoder_features_list, axis=3),
                                num_convs,
                                slim.conv2d,
                                decoder_depth,
                                3,
                                scope='decoder_conv' + str(i))
                    return decoder_features
                
        def _get_branch_logits(self,
                               feature_map,
                               num_classes,
                               atrous_rates=[12, 24, 36],
                               kernel_size=1,
                               weight_decay=0.0001,
                               reuse=None,
                               scope_suffix='seg_logits',
                               scope='resnet_v1_50'):
            """Gets the logits from each model's branch.
            
            The underlying model is branched out in the last layer when atrous
            spatial pyramid pooling is employed, and all branches are sum-merged
            to form the final logits.
            
            Args:
                feature_map: A float32 tensor with shape [batch_size, height,
                    width, channels].
                num_classes: Number of classes to predict.
                atrous_rates: A list of atrous convolution rates for last layer.
                kernel_size: Kernel size for convolution.
                weight_decay: Weight decay for the model variables.
                reuse: Reuse model variables or not.
                scope_suffix: Scope suffix for the model variables.
                
            Returns:
                Merged logits with shape [batch_size, height, width, num_classes].
            """
            with tf.variable_scope(scope):
                with slim.arg_scope(
                    [slim.conv2d],
                    weights_regularizer=slim.l2_regularizer(
                        weight_decay),
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    reuse=reuse):
                    branch_logits = []
                    for i, rate in enumerate(atrous_rates):
                        branch_scope = scope_suffix
                        if i:
                            branch_scope += '_%d' % i
                            
                        branch_logits.append(
                            slim.conv2d(feature_map,
                                        num_classes,
                                        kernel_size=kernel_size,
                                        rate=rate,
                                        activation_fn=None,
                                        normalizer_fn=None,
                                    scope=branch_scope))
                return tf.add_n(branch_logits)
        
        def postprocess(self, prediction_tensors):
            """Convert predicted output tensors to final forms.
            
            Args:
                prediction_tensors: The prediction tensors.
                    
            Returns:
                The postprocessed results.
            """
            probabilities = tf.nn.softmax(prediction_tensors, axis=3)
            return probabilities
        
        def loss(self, prediction_tensors, groundtruth_tensors):
            """Compute scalar loss tensors with respect to provided groundtruth."""
            logits = tf.reshape(prediction_tensors, shape=[-1, self._num_classes])
            labels = tf.reshape(groundtruth_tensors, shape=[-1,])
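            # The groundtruth mask is a float trimap in [0, 1]: values above
            # 0.8 are mapped to class 1 (presumably foreground), values in
            # (0, 0.8] to class 2 (presumably the unknown/boundary band), and
            # 0 stays class 0 (background).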
            labels = tf.where(tf.greater(labels, 0.8),
                              tf.ones_like(labels),
                              labels)
            labels = tf.where(tf.logical_and(tf.less_equal(labels, 0.8),
                                             tf.greater(labels, 0.0)),
                              2 * tf.ones_like(labels),
                              labels)
            labels = tf.cast(labels, dtype=tf.int32)
            # slim.losses is deprecated in TF 1.x; tf.losses is equivalent.
            tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
            loss = tf.losses.get_total_loss()
            return loss
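
    To put the pieces together, here is a minimal usage sketch (not from the original article; the placeholder names and shapes are illustrative assumptions) that wires preprocess, predict, and loss into a training graph:

    # A minimal usage sketch (assumes the TensorFlow 1.x graph/session style
    # used above; placeholder names are illustrative, not from the article).
    model = DeepLab(is_training=True, num_classes=3)
    size = model.default_image_size  # 513
    
    images = tf.placeholder(tf.float32, [None, size, size, 3], name='images')
    masks = tf.placeholder(tf.float32, [None, size, size, 1], name='masks')
    
    preprocessed_dict = model.preprocess(images, masks)
    logits = model.predict(preprocessed_dict['images'])
    loss = model.loss(logits, preprocessed_dict['masks'])
    probabilities = model.postprocess(logits)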
    
