美文网首页
[PyTorch] non-local的patch版本实现

[PyTorch] non-local的patch版本实现

作者: csdongxian | 来源:发表于2018-08-19 22:19 被阅读0次

    MM 18有篇文章《Non-locally Enhanced Encoder-Decoder Network for Single Image De-raining》,这篇文章把两大吃显存利器——non-local 和 densely connection一起给用了(绝望脸),如果要控制显存使用的话,就对实现要求比较高了。里面提到了一个控制 non-local 运算的方法,将feature map划分成patch,在patch内进行non-local操作,而不是原来的全局non-local。

    思路

    这里实现一下这个patch版本的non-local。一般的non-local实现,请参考:Github 传送门
    思路很简单,就是把patch的索引作为一个特殊的batch索引,原来的non-local运算会逐batch中的样本进行,现在就是逐batch中的样本、逐每个样本中的patch进行了。

    图1. 新的索引

    上图中,索引的第1个位置表示batch,第2、3个位置表示patch,将1~3个位置的索引看成整体,作为这个特殊的“batch”的索引。此时,patch版本的 non-local 就和一般的 non-local 没有太大区别了。

    实现

    那么,如何实现这样的一个新索引呢?
    假如有输入图像 (B, C, H, W),首先需要将最后两个表示位置的索引分解成四个索引,两个表示块的位置,两个表示块中元素的位置,例如要将行分解成m块、列分解成n块,就得到(B, C, m, H/m, n, W/n),使用 view 方法就能实现。如果对这个实现有疑惑,可以参考附录中的例子。
    然后进行转置(或者说是交换索引的位置),得到(B, m, n, C, H/m, W/n)。这里使用前三个索引,表示具体某个patch(batch中某个feature map的某个patch)。最后,使用这个新的Tensor来进行non-local的操作即可,方法类似,仅仅是在前面多了两个索引。

    # implementation in PyTorch
    # x=>(b, c, m, h/m, n, w/n)
    # e.g. nb_patches = [2, 2]
    b, c, h, w = x.size()
    x = x.view(b, c,
               nb_patches[0], h / nb_patches[0],
               nb_patches[1], w / nb_patches[1])
    # x=>(b, m, n, h/m, w/n, c)->(b, m, n, h/m*w/n, c)
    x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
    x = x.view(b, nb_patches[0], nb_patches[1], -1, c)
    

    附录

    图2. view 用法举例

    如果不确定view的使用,可以举个简单的例子,如上图的4 \times 4的一个方块 X,显然大小就是(1, 1, 4, 4)。上图里中间一列,左边表示元素,右边表示原来的元素索引。如果使用 view 就是在新的 Tensor 中按顺序排列旧的 Tensor 中的元素。按如下代码重新排列元素。图中右边一列,就是使用了 view 后的新索引。

    X = X.view(1, 1, 2, 4/2, 2, 4/2)
    

    所以,对于左上角的patch,其索引是(0,0),右上角则是(0, 1),也就是使用view得到新的 Tensor的第3、5个位置的索引。

    待填坑: Tensor中的contiguous方法

    相关文章

      网友评论

          本文标题:[PyTorch] non-local的patch版本实现

          本文链接:https://www.haomeiwen.com/subject/ouehiftx.html