美文网首页
Contextual Attention论文与源码解析与结果

Contextual Attention论文与源码解析与结果

作者: 抬头挺胸才算活着 | 来源:发表于2020-04-13 16:17 被阅读0次

    参考资料:
    [1] Generative Image Inpainting with Contextual Attention
    [2] question about the inpaint_ops.py
    [3] about the contextual attention
    [4] inpaint_ops / contextual attention

    • Contextual Attention
      在GAN生成图片中,很多都不能从自己图片中复制到需要的内容,传统方法PatchMatch这点做的很好,但是却不能产生图中没有的内容,所以[1]构造的模型提出了Contextual Attention,可以很好地复制图片中的内容。

    下面是我看contextual_attention函数做的注释,大致弄懂了。
    跟上图不一样的是f和b实际上大小是一样的,mask指定b中被污染的地方,f是前一阶段产生出来的大致的预测,经过图中的卷积,可以得出f中大概跟背景那块比较相似,后面再接一个deconv层可以把b中相似的地方给"借"过来,做实验可以看到输出的y跟输入的b是完全一样的,说明经过这个操作,f可以借到b中任何想要的内容。

    def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                             fuse_k=3, softmax_scale=10., training=True, fuse=True):
        """ Contextual attention layer implementation.
    
        Contextual attention is first introduced in publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.
    
        Args:
            x: Input feature to match (foreground).
            t: Input feature for match (background).
            mask: Input mask for t, indicating patches not available.
                  跟b一样大小全零代表一张完整的好图,有缺陷的地方为1
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from t.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
            training: Indicating if current graph is training or inference.
    
        Returns:
            tf.Tensor: output
    
        """
        # sess = tf.InteractiveSession()
    
        # get shapes
        raw_fs = tf.shape(f)
        raw_int_fs = f.get_shape().as_list()
        raw_int_bs = b.get_shape().as_list()
        # extract patches from background with stride and rate
        kernel = 2*rate
        # 这里跟下面的不同是,卷积核是像空洞卷积一样还是连在一起,这个模型选择了连在一起
        # raw_w = tf.extract_image_patches(
        #     b, [1,kernel,kernel,1], [1,stride,stride,1], [1,rate,rate,1], padding='SAME')
        raw_w = tf.extract_image_patches(
            b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
    
        # 这里两步不能直接用reshape搞定,看test1
        raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
        raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    
        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
        # 为什么这里改为to_shape了??
        b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
        if mask is not None:
            mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    
        # 缩放后重新获取shape
        fs = tf.shape(f)
        int_fs = f.get_shape().as_list()
    
        # 每张图片切成一份tensor
        f_groups = tf.split(f, int_fs[0], axis=0)
    
        # from t(H*W*C) to w(b*k*k*c*h*w)
        bs = tf.shape(b)
        int_bs = b.get_shape().as_list()
        w = tf.extract_image_patches(
            b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
        w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
        w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    
        # process mask
        # mask跟b的区别:只有一个batch,只有一个通道,图片大小一样
        if mask is None:
            mask = tf.zeros([1, bs[1], bs[2], 1])
    
        m = tf.extract_image_patches(
            mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
        m = tf.reshape(m, [1, -1, ksize, ksize, 1])
        m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
        m = m[0]
        # 每个采样出来的patch对应一个mean,mm的大小跟b的像素点个数一样多
        mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
    
        # 每张图片对应一组w,k*k*c*hw
        w_groups = tf.split(w, int_bs[0], axis=0)
        raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    
        y = []
        offsets = []
        k = fuse_k
        scale = softmax_scale
        fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
        # 这里跟一般的卷积还不一样,一般的卷积都是每个batch都是一样的系数
        # 这里每幅图片都是各自的卷积核系数
        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            # conv for compare
            wi = wi[0]
    
            # 对权重进行归一化
            wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
    
            # 最关键的一步卷积
            yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")
    
            # conv implementation for fuse scores to encourage large patches
            if fuse:
                yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
                yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
                yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
                yi = tf.transpose(yi, [0, 2, 1, 4, 3])
                yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
                yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
                yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
                yi = tf.transpose(yi, [0, 2, 1, 4, 3])
    
            # 保留f的长宽
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
    
            # lyc1 = yi.eval()[0,0,0,:]
    
            # 如果mask所在领域有任意一个为1,代表有缺陷的值,那么这个邻域代表的filter卷积出来的跟f一样大小的feature map
            # 会乘完之后变成0
            # softmax to match
            yi *= mm  # mask
            # 对应f的每个像素点所在邻域,所有filter一起算,看那个出来的值比较大。
            yi = tf.nn.softmax(yi*scale, 3)
            yi *= mm  # mask
    
            # lyc1 = yi.eval()[0,:,:,0]
            # lyc2 = yi.eval()[0,:,:,1]
    
            # 在b中的偏移,//和%运算符是因为offset这个轴的长度是fs[1]*fs[2]
            offset = tf.argmax(yi, axis=3, output_type=tf.int32)
            offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
    
            # deconv for patch pasting
            # 3.1 paste center
            wi_center = raw_wi[0]
            # lyc3 = wi_center.eval()[:,:,0,:]
            # lyc4 = wi_center.eval()[:,:,0,0]
            # print(repr(lyc3))
            # print(repr(lyc4))
            # 为什么除以4??
            yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
            # lyc2 = yi.eval()[0,:,:,0]
    
            y.append(yi)
            offsets.append(offset)
    
        y = tf.concat(y, axis=0)
        y.set_shape(raw_int_fs)
        offsets = tf.concat(offsets, axis=0)
        offsets.set_shape(int_bs[:3] + [2])
        # case1: visualize optical flow: minus current position
        h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
        w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
        # 由绝对偏移转向相对偏移
        offsets = offsets - tf.concat([h_add, w_add], axis=3)
        # to flow image
        flow = flow_to_image_tf(offsets)
        # # case2: visualize which pixels are attended
        # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
        if rate != 1:
            flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
        return y, flow
    
    • wgan loss
      wgan loss只有两行,是根据下面公式计算出来的


    def gan_wgan_loss(pos, neg, name='gan_wgan_loss'):
        """
        wgan loss function for GANs.
    
        - Wasserstein GAN: https://arxiv.org/abs/1701.07875
        """
        with tf.variable_scope(name):
            d_loss = tf.reduce_mean(neg-pos)
            g_loss = -tf.reduce_mean(neg)
            scalar_summary('d_loss', d_loss)
            scalar_summary('g_loss', g_loss)
            scalar_summary('pos_value_avg', tf.reduce_mean(pos))
            scalar_summary('neg_value_avg', tf.reduce_mean(neg))
        return g_loss, d_loss
    

    相关文章

      网友评论

          本文标题:Contextual Attention论文与源码解析与结果

          本文链接:https://www.haomeiwen.com/subject/ydddkctx.html