Non-local_pytorch code walkthrough

Author: 风之羁绊 | Published 2018-12-07 23:58

    References: https://zhuanlan.zhihu.com/p/33345791 and https://github.com/AlexHex7/Non-local_pytorch

    [Figure: overall model framework]
    The overall model framework in code:
    from torch import nn
    # from lib.non_local_concatenation import NONLocalBlock2D
    # from lib.non_local_gaussian import NONLocalBlock2D
    from lib.non_local_embedded_gaussian import NONLocalBlock2D
    # from lib.non_local_dot_product import NONLocalBlock2D
    
    
    class Network(nn.Module):
       def __init__(self):
           super(Network, self).__init__()
    
           self.convs = nn.Sequential(
               nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
               nn.BatchNorm2d(32),
               nn.ReLU(),
               nn.MaxPool2d(2),
    
               NONLocalBlock2D(in_channels=32),
               nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
               nn.BatchNorm2d(64),
               nn.ReLU(),
               nn.MaxPool2d(2),
    
               NONLocalBlock2D(in_channels=64),
               nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
               nn.BatchNorm2d(128),
               nn.ReLU(),
               nn.MaxPool2d(2),
           )
    
           self.fc = nn.Sequential(
               nn.Linear(in_features=128*3*3, out_features=256),
               nn.ReLU(),
               nn.Dropout(0.5),
    
               nn.Linear(in_features=256, out_features=10)
           )
    
       def forward(self, x):
           batch_size = x.size(0)
           output = self.convs(x).view(batch_size, -1)
           output = self.fc(output)
           return output
    
    if __name__ == '__main__':
       import torch
    
       img = torch.randn(3, 1, 28, 28)
       net = Network()
       out = net(img)
       print(out.size())
    

    The main model is a classifier for the MNIST dataset, with Non-local blocks inserted between the convolution stages; the rest of this post goes through the Non-local block code itself.
    Let's start from the framework shown in the figure above.
    non_local_embedded_gaussian (methods excerpted from the _NonLocalBlockND class):

    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
            super(_NonLocalBlockND, self).__init__()
    
            assert dimension in [1, 2, 3]
    
            self.dimension = dimension
            self.sub_sample = sub_sample
    
            self.in_channels = in_channels
            self.inter_channels = inter_channels
    
            if self.inter_channels is None:
                self.inter_channels = in_channels // 2
                if self.inter_channels == 0:
                    self.inter_channels = 1
    
            if dimension == 3:
                conv_nd = nn.Conv3d
                max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
                bn = nn.BatchNorm3d
            elif dimension == 2:
                conv_nd = nn.Conv2d
                max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
                bn = nn.BatchNorm2d
            else:
                conv_nd = nn.Conv1d
                max_pool_layer = nn.MaxPool1d(kernel_size=(2))
                bn = nn.BatchNorm1d
    
            self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
    
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                            kernel_size=1, stride=1, padding=0),
                    bn(self.in_channels)
                )
                nn.init.constant_(self.W[1].weight, 0)
                nn.init.constant_(self.W[1].bias, 0)
            else:
                self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                                 kernel_size=1, stride=1, padding=0)
                nn.init.constant_(self.W.weight, 0)
                nn.init.constant_(self.W.bias, 0)
    
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
    
            if sub_sample:
                self.g = nn.Sequential(self.g, max_pool_layer)
                self.phi = nn.Sequential(self.phi, max_pool_layer)
    

    The __init__ function does the usual setup, zero-initializes the last layer of W (the BN when bn_layer is set, otherwise the 1×1 conv), defines theta and phi as two 1×1 convolutions, and, when sub_sample is set, appends a max-pooling layer to g and phi. A quick shape-check sketch is below, followed by the main forward code.
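    The following minimal sketch assumes the lib/ modules from the AlexHex7 repo above are on the import path; it instantiates the 2D block and prints the shapes: theta keeps the spatial size, phi (and g) see a max-pooled, halved map when sub_sample is on, and the block's output matches its input thanks to the residual connection.

     import torch
     from lib.non_local_embedded_gaussian import NONLocalBlock2D

     block = NONLocalBlock2D(in_channels=32, sub_sample=True, bn_layer=True)
     x = torch.randn(2, 32, 20, 20)       # (b, c, h, w)

     print(block.theta(x).shape)          # torch.Size([2, 16, 20, 20]) -> 1x1 conv to inter_channels
     print(block.phi(x).shape)            # torch.Size([2, 16, 10, 10]) -> 1x1 conv + 2x2 max pool
     print(block(x).shape)                # torch.Size([2, 32, 20, 20]) -> residual keeps input shape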

    def forward(self, x):
            '''
            :param x: (b, c, t, h, w)
            :return:
            '''
    
            batch_size = x.size(0)  
    
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)
            g_x = g_x.permute(0, 2, 1)
    
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
            f = torch.matmul(theta_x, phi_x)
            f_div_C = F.softmax(f, dim=-1)
    
            y = torch.matmul(f_div_C, g_x)
            y = y.permute(0, 2, 1).contiguous()
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])
            W_y = self.W(y)
            z = W_y + x
            return z
    

    1. The spatio-temporal dims t·h·w (3D) or h·w (2D) are flattened into a single dimension, written W below.
    2. g_x = B×W×C, theta_x = B×W×C, phi_x = B×C×W, f = B×W×W, f_div_C = B×W×W
    3. y = B×W×C -> B×C×W -> B×C×h×w
    4. W_y is also a 1×1 convolution (BN optional), followed by a residual connection.
    From the code's point of view: three 1×1 convolutions are applied first, two of the results (theta and phi) are multiplied together, and a softmax maps the product into a W×W weight matrix, which is then multiplied with the third (g). This is very much an attention-style operation, i.e. an extra layer of learned weighting; it also resembles a fully connected layer in that the computation is just as heavy, with one more level of pairwise similarity. A shape sketch is given below.
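    To make the B/W/C bookkeeping concrete, here is a self-contained sketch with made-up tensor sizes (not tied to the repo), reproducing just the matmul + softmax + matmul core:

     import torch
     import torch.nn.functional as F

     B, C, H, Wd = 2, 16, 7, 7            # C plays the role of inter_channels
     N = H * Wd                           # the flattened spatial dimension ("W" in the notes above)

     theta_x = torch.randn(B, N, C)       # B x W x C
     phi_x   = torch.randn(B, C, N)       # B x C x W
     g_x     = torch.randn(B, N, C)       # B x W x C

     f = torch.matmul(theta_x, phi_x)     # B x W x W pairwise similarities
     f_div_C = F.softmax(f, dim=-1)       # each row sums to 1, like attention weights
     y = torch.matmul(f_div_C, g_x)       # B x W x C weighted sum of g features
     print(f.shape, y.shape)              # torch.Size([2, 49, 49]) torch.Size([2, 49, 16])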
    Next, look at how non_local_concatenation is written; only the main differences are listed:

     # (from __init__) the relation between two positions is produced by a small network
     # applied to concatenated features, instead of a dot product
     self.concat_project = nn.Sequential(
         nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
         nn.ReLU()
     )

     def forward(self, x):
            '''
            :param x: (b, c, t, h, w)
            :return:
            '''
    
            batch_size = x.size(0)
    
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)
            g_x = g_x.permute(0, 2, 1)
    
            # (b, c, N, 1)
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
            # (b, c, 1, N)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
    
            h = theta_x.size(2)
            w = phi_x.size(3)
            theta_x = theta_x.repeat(1, 1, 1, w)
            phi_x = phi_x.repeat(1, 1, h, 1)
    
            concat_feature = torch.cat([theta_x, phi_x], dim=1)
            f = self.concat_project(concat_feature)
            b, _, h, w = f.size()
            f = f.view(b, h, w)
    
            N = f.size(-1)
            f_div_C = f / N
    
            y = torch.matmul(f_div_C, g_x)
            y = y.permute(0, 2, 1).contiguous()
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])
            W_y = self.W(y)
            z = W_y + x
            return z
    

    The main difference is in how the weights are formed: here theta and phi are joined by concatenation along the channel dimension. Since their shapes differ, they are first repeated onto the same N×N grid, then concatenated, and a 1×1 convolution reduces the concatenation to a single channel. The f_div_C = f / N step seems to do little more than rescale the values, presumably to keep them from getting too large, since W is big. A small shape sketch follows.
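    A rough sketch of the pairwise-concatenation step with made-up sizes (the names here are hypothetical, mirroring the snippet above): repeat() broadcasts theta and phi onto an N×N grid, the 1×1 conv fuses each (theta_i, phi_j) pair into a scalar, and the result is scaled by 1/N:

     import torch
     from torch import nn

     B, C, N = 2, 16, 49
     theta_x = torch.randn(B, C, N, 1).repeat(1, 1, 1, N)   # (B, C, N, N)
     phi_x   = torch.randn(B, C, 1, N).repeat(1, 1, N, 1)   # (B, C, N, N)

     concat_feature = torch.cat([theta_x, phi_x], dim=1)    # (B, 2C, N, N)
     concat_project = nn.Sequential(
         nn.Conv2d(2 * C, 1, kernel_size=1, bias=False),    # fuse each channel-pair to one scalar
         nn.ReLU()
     )
     f = concat_project(concat_feature).view(B, N, N)       # (B, N, N) relation matrix
     f_div_C = f / N                                        # plain 1/N scaling instead of softmax
     print(f_div_C.shape)                                   # torch.Size([2, 49, 49])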
    Next, non_local_dot_product:

     def forward(self, x):
            '''
            :param x: (b, c, t, h, w)
            :return:
            '''
    
            batch_size = x.size(0)
    
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)
            g_x = g_x.permute(0, 2, 1)
    
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
            f = torch.matmul(theta_x, phi_x)
            N = f.size(-1)
            f_div_C = f / N
    
            y = torch.matmul(f_div_C, g_x)
            y = y.permute(0, 2, 1).contiguous()
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])
            W_y = self.W(y)
            z = W_y + x
    
            return z
    

    Compared with embedded_gaussian, this variant uses the plain dot product to form the W×W weights and divides by N instead of applying a softmax (see the short comparison below).
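    A tiny comparison of the two normalizations on a made-up similarity matrix: dividing by N just rescales the values, while softmax turns each row into weights that sum to 1.

     import torch
     import torch.nn.functional as F

     f = torch.randn(2, 49, 49)           # (B, W, W) raw similarities
     N = f.size(-1)
     f_dot = f / N                        # dot-product variant: row sums are arbitrary
     f_emb = F.softmax(f, dim=-1)         # embedded-gaussian variant: row sums are exactly 1
     print(f_dot.sum(-1)[0, 0].item(), f_emb.sum(-1)[0, 0].item())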
    Finally, non_local_gaussian:

     # (from __init__, when sub_sample is True) phi is only a max-pooling layer; no 1x1 conv
     self.phi = max_pool_layer
     def forward(self, x):
            '''
            :param x: (b, c, t, h, w)
            :return:
            '''
    
            batch_size = x.size(0)
    
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)
    
            g_x = g_x.permute(0, 2, 1)
    
            theta_x = x.view(batch_size, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
    
            if self.sub_sample:
                phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
            else:
                phi_x = x.view(batch_size, self.in_channels, -1)
    
            f = torch.matmul(theta_x, phi_x)
            f_div_C = F.softmax(f, dim=-1)
    
            y = torch.matmul(f_div_C, g_x)
            y = y.permute(0, 2, 1).contiguous()
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])
            W_y = self.W(y)
            z = W_y + x
    
            return z
    

    Compared with embedded_gaussian, the theta and phi 1×1 convolutions are dropped: theta_x is just the reshaped input x, and phi is reduced to a max-pooling layer when sub_sample is set, so the similarity is a softmax over the raw dot products of x. I am still not entirely clear on the practical difference between these two versions; I will add more after reading the paper. A minimal sketch of the similarity computation is below.
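    This sketch uses a made-up tensor: theta and phi act as identity mappings on the raw features, so f is softmax(x^T x) computed directly on x (with optional max pooling on the phi side omitted here):

     import torch
     import torch.nn.functional as F

     B, C, H, W = 2, 32, 7, 7
     x = torch.randn(B, C, H, W)
     x_flat = x.view(B, C, -1)                              # (B, C, N)
     f = torch.matmul(x_flat.permute(0, 2, 1), x_flat)      # (B, N, N), i.e. x^T x
     f_div_C = F.softmax(f, dim=-1)
     print(f_div_C.shape)                                   # torch.Size([2, 49, 49])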
