Image Inpainting for Irregular H

Author: LuDon | Published 2019-07-24 15:58

    Introduction

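    This paper, from NVIDIA, proposes an image inpainting algorithm based on partial convolutions.
    Image inpainting fills in missing regions of an image. It can be used for image editing, for example to remove unwanted content and fill the region plausibly. The paper performs inpainting with a network of partial convolution layers that update the mask automatically.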

    Method

    Partial convolutional layer

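    Let W be the weights of the convolution filter and b its bias. X holds the feature values in the current convolution window and M is the corresponding binary mask (1 for valid pixels, 0 for holes). The partial convolution at each location is
    x'=\begin{cases} W^T (X \odot M) \frac{\mathrm{sum}(\mathbf{1})}{\mathrm{sum}(M)} + b, & \mathrm{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}
    where \odot denotes element-wise multiplication and \mathbf{1} is an all-ones tensor with the same shape as M. The output therefore depends only on the unmasked (valid) inputs, and the factor \mathrm{sum}(\mathbf{1})/\mathrm{sum}(M) rescales it to compensate for the varying number of valid inputs in each window.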
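    After each partial convolution, the mask is updated: a location becomes valid if at least one valid input fell inside its window.
    m'=\begin{cases} 1, & \mathrm{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}
    Applied layer after layer, this update shrinks the holes, so with enough layers every location becomes valid as long as the input contained any valid pixels.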
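The two formulas above can be checked with a minimal NumPy sketch. It assumes a single-channel image, stride 1 and no padding, and loops over windows for readability rather than speed; the name `partial_conv2d` is ours, not the paper's. Like most deep-learning frameworks, it computes cross-correlation.

```python
import numpy as np

def partial_conv2d(x, m, w, b=0.0):
    """Single-channel partial convolution (stride 1, no padding).

    x: (H, W) image, m: (H, W) binary mask (1 = valid, 0 = hole),
    w: (kh, kw) filter, b: scalar bias.
    Returns the output feature map and the updated mask.
    """
    kh, kw = w.shape
    H, W = x.shape
    out_h, out_w = H - kh + 1, W - kw + 1
    out = np.zeros((out_h, out_w))
    new_mask = np.zeros((out_h, out_w))
    window_size = kh * kw  # sum(1) in the formula
    for i in range(out_h):
        for j in range(out_w):
            X = x[i:i + kh, j:j + kw]
            M = m[i:i + kh, j:j + kw]
            valid = M.sum()
            if valid > 0:
                # W^T (X * M) * sum(1) / sum(M) + b
                out[i, j] = np.sum(w * X * M) * window_size / valid + b
                new_mask[i, j] = 1.0
            # otherwise both the output and the new mask stay 0
    return out, new_mask
```

Because of the sum(1)/sum(M) factor, a constant image yields the same output at every location that sees at least one valid pixel, whatever the shape of the hole; this is the behavior the rescaling is meant to provide.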

    Network architecture

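    The overall network is a U-Net in which every convolution is replaced by a partial convolution layer, and the decoder upsamples with nearest-neighbor interpolation. Each skip link concatenates two feature maps and two masks, which together form the input of the next partial convolution layer. The input of the last partial convolution layer is the concatenation of the original input image with holes and the original mask, which lets the model copy the non-hole pixels directly.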
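A sketch of how the inputs to one decoder layer could be assembled, assuming channels-last NumPy arrays in which each mask has as many channels as its feature map. The helpers `upsample_nearest` and `decoder_inputs` are hypothetical names used only for illustration.

```python
import numpy as np

def upsample_nearest(t, factor=2):
    # t: (H, W, C); repeat every pixel factor x factor times (nearest neighbor)
    return t.repeat(factor, axis=0).repeat(factor, axis=1)

def decoder_inputs(dec_feat, dec_mask, skip_feat, skip_mask):
    """Build the (features, mask) pair fed to a decoder partial-conv layer."""
    up_feat = upsample_nearest(dec_feat)
    up_mask = upsample_nearest(dec_mask)
    # Skip link: concatenate feature maps and masks along the channel axis
    feat = np.concatenate([up_feat, skip_feat], axis=-1)
    mask = np.concatenate([up_mask, skip_mask], axis=-1)
    return feat, mask

feat, mask = decoder_inputs(np.zeros((4, 4, 8)), np.ones((4, 4, 8)),
                            np.zeros((8, 8, 4)), np.ones((8, 8, 4)))
print(feat.shape, mask.shape)  # (8, 8, 12) (8, 8, 12)
```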

    Loss functions

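    Given the input image with holes I_{in}, the initial binary mask M (0 inside holes, 1 elsewhere), the network prediction I_{out}, and the ground-truth image I_{gt}:
    1. Per-pixel losses:
    L_{hole} = \frac{1}{N_{I_{gt}}} ||(1-M) \odot (I_{out}-I_{gt})||_1
    L_{valid} = \frac{1}{N_{I_{gt}}} ||M \odot (I_{out}-I_{gt})||_1
    where \odot is element-wise multiplication and N_{I_{gt}} = C \cdot H \cdot W is the number of elements of I_{gt}.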
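    2. Perceptual loss:
    L_{perceptual} = \sum_{n=0}^{N-1}||\psi_n(I_{out}) - \psi_n(I_{gt}) ||_1 + \sum_{n=0}^{N-1}||\psi_n(I_{comp}) - \psi_n(I_{gt}) ||_1
    Here I_{comp} is the raw output I_{out} with the non-hole pixels set directly to the ground truth, i.e. I_{comp} = M \odot I_{gt} + (1-M) \odot I_{out}, and \psi_n is the activation map of the n-th selected layer of an ImageNet-pretrained VGG-16; the paper uses pool1, pool2 and pool3 (N = 3).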
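A sketch of I_{comp} and the perceptual loss in NumPy. To keep it self-contained, `fake_features` is a hypothetical stand-in for the VGG-16 pool1 to pool3 activations (it simply applies 2×2 average pooling three times); in practice ψ_n would come from a pretrained VGG-16. Averaging each L1 term over its elements is an assumed normalization.

```python
import numpy as np

def composite(mask, y_true, y_pred):
    # I_comp: ground truth where mask == 1 (valid), prediction inside the holes
    return mask * y_true + (1 - mask) * y_pred

def avg_pool2(x):
    # 2x2 average pooling of an (H, W, C) array with even H and W
    h, w, c = x.shape
    return x.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

def fake_features(img, n_layers=3):
    # Hypothetical stand-in for VGG-16 pool1..pool3: repeated 2x2 average pooling
    feats, x = [], img
    for _ in range(n_layers):
        x = avg_pool2(x)
        feats.append(x)
    return feats

def perceptual_loss(feats_out, feats_comp, feats_gt):
    # Sum over layers of the mean absolute feature differences (out vs gt, comp vs gt)
    return sum(np.mean(np.abs(o - g)) + np.mean(np.abs(c - g))
               for o, c, g in zip(feats_out, feats_comp, feats_gt))
```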
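    3. Style loss:
    L_{style_{out}} = \sum_{n=0}^{N-1}\left\| K_n \left( \psi_n(I_{out})^T \psi_n(I_{out}) - \psi_n(I_{gt})^T \psi_n(I_{gt}) \right) \right\|_1
    L_{style_{comp}} = \sum_{n=0}^{N-1}\left\| K_n \left( \psi_n(I_{comp})^T \psi_n(I_{comp}) - \psi_n(I_{gt})^T \psi_n(I_{gt}) \right) \right\|_1
    where each \psi_n is reshaped into an (H_n W_n) \times C_n matrix, so that \psi_n^T \psi_n is the C_n \times C_n Gram matrix, and K_n = 1/(C_n H_n W_n) is a normalization factor.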
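The Gram-matrix term can be sketched as follows, with one (H, W, C) feature map per layer. The division by C·H·W implements K_n; averaging the absolute differences over the C×C entries is an assumed normalization, and the function names are illustrative.

```python
import numpy as np

def gram_matrix(feat):
    """K_n * psi^T psi for an (H, W, C) feature map, with K_n = 1 / (C * H * W)."""
    h, w, c = feat.shape
    psi = feat.reshape(h * w, c)        # (H*W, C) matrix
    return psi.T @ psi / (c * h * w)    # (C, C) Gram matrix

def style_loss(feats_a, feats_gt):
    # Sum over layers of the mean absolute difference between Gram matrices
    return sum(np.mean(np.abs(gram_matrix(a) - gram_matrix(g)))
               for a, g in zip(feats_a, feats_gt))
```

The same function serves both terms: pass the features of I_{out} for L_{style_{out}} and those of I_{comp} for L_{style_{comp}}.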

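    4. Total variation loss:
    L_{tv} = \sum_{(i,j) \in P,\,(i,j+1) \in P} ||I_{comp}^{i,j+1} - I_{comp}^{i,j}||_1 + \sum_{(i,j) \in P,\,(i+1,j) \in P} ||I_{comp}^{i+1,j} - I_{comp}^{i,j}||_1
    where P is the hole region dilated by one pixel, so the term smooths the filled-in region and its border.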
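A NumPy sketch of the total variation term on a single-channel image, following the definition above: P is obtained by dilating the hole region (1 - M) with a 3×3 square, and only neighboring pairs with both pixels in P contribute. No normalization is applied, matching the plain sums in the formula; `dilate3x3` and `tv_loss` are illustrative names.

```python
import numpy as np

def dilate3x3(region):
    # Binary dilation with a 3x3 square structuring element (1-pixel dilation)
    h, w = region.shape
    padded = np.pad(region, 1)  # zero padding
    out = np.zeros_like(region, dtype=bool)
    for di in range(3):
        for dj in range(3):
            out |= padded[di:di + h, dj:dj + w].astype(bool)
    return out

def tv_loss(mask, comp):
    """mask: (H, W) with 1 = valid, 0 = hole; comp: (H, W) composited image."""
    P = dilate3x3(1 - mask)        # hole region dilated by one pixel
    horiz = P[:, 1:] & P[:, :-1]   # pairs (i, j), (i, j+1) both in P
    vert = P[1:, :] & P[:-1, :]    # pairs (i, j), (i+1, j) both in P
    return (np.abs(comp[:, 1:] - comp[:, :-1])[horiz].sum()
            + np.abs(comp[1:, :] - comp[:-1, :])[vert].sum())
```

The Keras `loss_tv` in the code analysis instead multiplies I_{comp} by the dilated mask and takes differences over the whole image (then averages), so it also counts pairs that straddle the border of P.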
    The total loss is:
    L_{total} = L_{valid} + 6L_{hole} + 0.05 L_{perceptual} + 120 (L_{style_{out}} + L_{style_{comp}}) + 0.1L_{tv}
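The weighted sum itself is simple; the sketch below only makes the weights explicit. The argument names are ours, and the individual terms are assumed to be precomputed scalars.

```python
# Weights of the total loss as given above
W_VALID, W_HOLE, W_PERCEPTUAL, W_STYLE, W_TV = 1.0, 6.0, 0.05, 120.0, 0.1

def total_loss(l_valid, l_hole, l_perceptual, l_style_out, l_style_comp, l_tv):
    """Weighted sum of the precomputed scalar loss terms."""
    return (W_VALID * l_valid
            + W_HOLE * l_hole
            + W_PERCEPTUAL * l_perceptual
            + W_STYLE * (l_style_out + l_style_comp)
            + W_TV * l_tv)
```

For example, with L_valid = L_hole = L_perceptual = L_tv = 1 and both style terms equal to 0.01, the total is 1 + 6 + 0.05 + 2.4 + 0.1 = 9.55, which shows how strongly the factor 120 amplifies the small Gram-matrix differences.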

    Code analysis

    1. Partial convolution layer

    from keras.utils import conv_utils
    from keras import backend as K
    from keras.engine import InputSpec
    from keras.layers import Conv2D
    
    class PConv2D(Conv2D):
        def __init__(self, *args, n_channels=3, mono=False, **kwargs):
            super().__init__(*args, **kwargs)
            self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
        def build(self, input_shape):        
            if self.data_format == 'channels_first':
                channel_axis = 1
            else:
                channel_axis = -1
                
            if input_shape[0][channel_axis] is None:
                raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
                
            self.input_dim = input_shape[0][channel_axis]
      
            kernel_shape = self.kernel_size + (self.input_dim, self.filters)
            self.kernel = self.add_weight(shape=kernel_shape,
                                          initializer=self.kernel_initializer,
                                          name='img_kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)
    
            self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))
    
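            # Explicit zero-padding so the output size matches padding='same' (odd kernels);
            # the second pair uses kernel_size[1] so non-square kernels are padded correctly
            self.pconv_padding = (
                ((self.kernel_size[0] - 1) // 2, (self.kernel_size[0] - 1) // 2),
                ((self.kernel_size[1] - 1) // 2, (self.kernel_size[1] - 1) // 2),
            )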
    
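            # Window size sum(1) used for normalization: kernel_mask sums the mask over
            # kh * kw * input_dim entries, so sum(1) must count the same entries
            self.window_size = self.kernel_size[0] * self.kernel_size[1] * self.input_dim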
            
            if self.use_bias:
                self.bias = self.add_weight(shape=(self.filters,),
                                            initializer=self.bias_initializer,
                                            name='bias',
                                            regularizer=self.bias_regularizer,
                                            constraint=self.bias_constraint)
            else:
                self.bias = None
            self.built = True
    
        def call(self, inputs, mask=None):
    
            if type(inputs) is not list or len(inputs) != 2:
                raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))
    
            # Padding done explicitly so that padding becomes part of the masked partial convolution
            images = K.spatial_2d_padding(inputs[0], self.pconv_padding, self.data_format)
            masks = K.spatial_2d_padding(inputs[1], self.pconv_padding, self.data_format)
            # Apply convolutions to mask
            mask_output = K.conv2d(
                masks, self.kernel_mask, 
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )
            # Apply convolutions to image
            img_output = K.conv2d(
                (images*masks), self.kernel, 
                strides=self.strides,
                padding='valid',
                data_format=self.data_format,
                dilation_rate=self.dilation_rate
            )        
            # Calculate the mask ratio on each pixel in the output mask
            mask_ratio = self.window_size / (mask_output + 1e-8)
            # Clip output to be between 0 and 1
            mask_output = K.clip(mask_output, 0, 1)
            # Remove ratio values where there are holes
            mask_ratio = mask_ratio * mask_output
            # Normalize image output
            img_output = img_output * mask_ratio
            # Apply bias only to the image (if chosen to do so)
            if self.use_bias:
                img_output = K.bias_add(
                    img_output,
                    self.bias,
                    data_format=self.data_format)
            
            # Apply activations on the image
            if self.activation is not None:
                img_output = self.activation(img_output)
             
            return [img_output, mask_output]
        
        def compute_output_shape(self, input_shape):
            if self.data_format == 'channels_last':
                space = input_shape[0][1:-1]
                new_space = []
                for i in range(len(space)):
                    new_dim = conv_utils.conv_output_length(
                        space[i],
                        self.kernel_size[i],
                        padding='same',
                        stride=self.strides[i],
                        dilation=self.dilation_rate[i])
                    new_space.append(new_dim)
                new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
                return [new_shape, new_shape]
            if self.data_format == 'channels_first':
                space = input_shape[0][2:]
                new_space = []
                for i in range(len(space)):
                    new_dim = conv_utils.conv_output_length(
                        space[i],
                        self.kernel_size[i],
                        padding='same',
                        stride=self.strides[i],
                        dilation=self.dilation_rate[i])
                    new_space.append(new_dim)
                new_shape = (input_shape[0][0], self.filters) + tuple(new_space)
                return [new_shape, new_shape]
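In `call`, the image output is scaled by `window_size / mask_output`. Because `kernel_mask` has shape `kernel_size + (input_dim, filters)`, `mask_output` counts valid entries over all kh·kw·C_in positions of the window, so sum(1) must be counted over the same positions for the ratio to match the formula. The small NumPy check below compares the two choices on a hypothetical 3×3 window with 4 channels; `mask_ratios` is an illustrative name.

```python
import numpy as np

def mask_ratios(kh, kw, c_in, window_mask):
    """Return (spatial-only ratio, full ratio) for one channel-replicated mask window.

    window_mask: (kh, kw, c_in) array, 1 = valid, 0 = hole.
    """
    mask_sum = window_mask.sum()              # what conv2d(mask, all-ones kernel) gives here
    ratio_spatial = (kh * kw) / mask_sum      # sum(1) counted over kh*kw only
    ratio_full = (kh * kw * c_in) / mask_sum  # sum(1) counted over kh*kw*c_in
    return ratio_spatial, ratio_full

# Hypothetical window whose top row lies in a hole (6 of 9 spatial positions valid)
kh, kw, c_in = 3, 3, 4
window = np.ones((kh, kw, c_in))
window[0, :, :] = 0
spatial, full = mask_ratios(kh, kw, c_in, window)
print(spatial, full)  # 0.375 1.5
```

Only the full count gives 1.5 = 9/6, the factor the formula prescribes for this window; the spatial-only count is smaller by exactly C_in, and a fully valid window gets ratio 1 only under the full count.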
    

    2. Loss functions

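        # Assumed helpers (not part of the original excerpt): a minimal sketch for
        # channels_last tensors; K is keras.backend, imported as in the layer code
        @staticmethod
        def l1(y_true, y_pred):
            """Mean absolute error per sample, averaged over all non-batch axes"""
            axes = list(range(1, K.ndim(y_true)))
            return K.mean(K.abs(y_pred - y_true), axis=axes)
    
        @staticmethod
        def gram_matrix(x):
            """Gram matrix psi^T psi / (C*H*W) of a (B, H, W, C) tensor, shape (B, C, C)"""
            shape = K.shape(x)
            B, H, W, C = shape[0], shape[1], shape[2], shape[3]
            features = K.reshape(x, K.stack([B, H * W, C]))
            gram = K.batch_dot(features, features, axes=1)
            return gram / K.cast(C * H * W, x.dtype)
    
        def loss_hole(self, mask, y_true, y_pred):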
            """Pixel L1 loss within the hole / mask"""
            return self.l1((1-mask) * y_true, (1-mask) * y_pred)
        
        def loss_valid(self, mask, y_true, y_pred):
            """Pixel L1 loss outside the hole / mask"""
            return self.l1(mask * y_true, mask * y_pred)
        
        def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp): 
            """Perceptual loss based on VGG16, see. eq. 3 in paper"""       
            loss = 0
            for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
                loss += self.l1(o, g) + self.l1(c, g)
            return loss
            
        def loss_style(self, output, vgg_gt):
            """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
            loss = 0
            for o, g in zip(output, vgg_gt):
                loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
            return loss
        
        def loss_tv(self, mask, y_comp):
            """Total variation loss, used for smoothing the hole region, see. eq. 6"""
    
            # Create dilated hole region using a 3x3 kernel of all 1s.
            kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
            dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
    
            # Cast values to be [0., 1.], and compute dilated hole region of y_comp
            dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
            P = dilated_mask * y_comp
    
            # Calculate total variation loss
            a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
            b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])        
            return a+b
    

    References

    [1] Image Inpainting for Irregular Holes Using Partial Convolutions


          Permalink: https://www.haomeiwen.com/subject/brexrctx.html