美文网首页
Contextual Attention论文与源码解析与结果

Contextual Attention论文与源码解析与结果

作者: 抬头挺胸才算活着 | 来源:发表于2020-04-13 16:17 被阅读0次

参考资料:
[1] Generative Image Inpainting with Contextual Attention
[2] question about the inpaint_ops.py
[3] about the contextual attention
[4] inpaint_ops / contextual attention

  • Contextual Attention
    在GAN生成图片中,很多都不能从自己图片中复制到需要的内容,传统方法PatchMatch这点做的很好,但是却不能产生图中没有的内容,所以[1]构造的模型提出了Contextual Attention,可以很好地复制图片中的内容。

下面是我看contextual_attention函数做的注释,大致弄懂了。
跟上图不一样的是f和b实际上大小是一样的,mask指定b中被污染的地方,f是前一阶段产生出来的大致的预测,经过图中的卷积,可以得出f中大概跟背景那块比较相似,后面再接一个deconv层可以把b中相似的地方给"借"过来,做实验可以看到输出的y跟输入的b是完全一样的,说明经过这个操作,f可以借到b中任何想要的内容。

def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True, fuse=True):
    """ Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        x: Input feature to match (foreground).
        t: Input feature for match (background).
        mask: Input mask for t, indicating patches not available.
              跟b一样大小全零代表一张完整的好图,有缺陷的地方为1
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from t.
        rate: Dilation for matching.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference.

    Returns:
        tf.Tensor: output

    """
    # sess = tf.InteractiveSession()

    # get shapes
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()
    # extract patches from background with stride and rate
    kernel = 2*rate
    # 这里跟下面的不同是,卷积核是像空洞卷积一样还是连在一起,这个模型选择了连在一起
    # raw_w = tf.extract_image_patches(
    #     b, [1,kernel,kernel,1], [1,stride,stride,1], [1,rate,rate,1], padding='SAME')
    raw_w = tf.extract_image_patches(
        b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')

    # 这里两步不能直接用reshape搞定,看test1
    raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw

    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for reconstruction.
    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    # 为什么这里改为to_shape了??
    b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
    if mask is not None:
        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)

    # 缩放后重新获取shape
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()

    # 每张图片切成一份tensor
    f_groups = tf.split(f, int_fs[0], axis=0)

    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_image_patches(
        b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw

    # process mask
    # mask跟b的区别:只有一个batch,只有一个通道,图片大小一样
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])

    m = tf.extract_image_patches(
        mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    m = m[0]
    # 每个采样出来的patch对应一个mean,mm的大小跟b的像素点个数一样多
    mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)

    # 每张图片对应一组w,k*k*c*hw
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)

    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    # 这里跟一般的卷积还不一样,一般的卷积都是每个batch都是一样的系数
    # 这里每幅图片都是各自的卷积核系数
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare
        wi = wi[0]

        # 对权重进行归一化
        wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)

        # 最关键的一步卷积
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")

        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])

        # 保留f的长宽
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])

        # lyc1 = yi.eval()[0,0,0,:]

        # 如果mask所在领域有任意一个为1,代表有缺陷的值,那么这个邻域代表的filter卷积出来的跟f一样大小的feature map
        # 会乘完之后变成0
        # softmax to match
        yi *= mm  # mask
        # 对应f的每个像素点所在邻域,所有filter一起算,看那个出来的值比较大。
        yi = tf.nn.softmax(yi*scale, 3)
        yi *= mm  # mask

        # lyc1 = yi.eval()[0,:,:,0]
        # lyc2 = yi.eval()[0,:,:,1]

        # 在b中的偏移,//和%运算符是因为offset这个轴的长度是fs[1]*fs[2]
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)

        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        # lyc3 = wi_center.eval()[:,:,0,:]
        # lyc4 = wi_center.eval()[:,:,0,0]
        # print(repr(lyc3))
        # print(repr(lyc4))
        # 为什么除以4??
        yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
        # lyc2 = yi.eval()[0,:,:,0]

        y.append(yi)
        offsets.append(offset)

    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])
    # case1: visualize optical flow: minus current position
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
    # 由绝对偏移转向相对偏移
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
    return y, flow
  • wgan loss
    wgan loss只有两行,是根据下面公式计算出来的


def gan_wgan_loss(pos, neg, name='gan_wgan_loss'):
    """
    wgan loss function for GANs.

    - Wasserstein GAN: https://arxiv.org/abs/1701.07875
    """
    with tf.variable_scope(name):
        d_loss = tf.reduce_mean(neg-pos)
        g_loss = -tf.reduce_mean(neg)
        scalar_summary('d_loss', d_loss)
        scalar_summary('g_loss', g_loss)
        scalar_summary('pos_value_avg', tf.reduce_mean(pos))
        scalar_summary('neg_value_avg', tf.reduce_mean(neg))
    return g_loss, d_loss

相关文章

网友评论

      本文标题:Contextual Attention论文与源码解析与结果

      本文链接:https://www.haomeiwen.com/subject/ydddkctx.html