美文网首页
Pytorch-UNet实现

Pytorch-UNet实现

作者: WAI_f | 来源:发表于2020-06-07 21:24 被阅读0次

简介

UNet结构很简单,模型代码实现也不复杂,我选择使用Pytorch来搭建模型,本着能偷懒就偷懒的原则,我在github上找了一个现成的程序,在这个程序的基础上做了一点调整。

代码分析

1. DoubleConv
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            # nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            # nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

卷积之后可以尝试添加一个BatchNorm,然后再用Relu激活函数处理。Relu函数里面的inplace参数是用来设置输入tensor是否复用,就是relu直接在input tensor上计算,设置为true可以降低显存的占用。
Conv2d参数设置还是有一点想要说明的,这里因为卷积核大小都为3x3,所以为了保证卷积之后feature map和输入的tensor shape相等,需要padding一下输入tensor,在原图的外围添加一圈像素。Pytorch和tensorflow在卷积函数参数设计上有很大的区别:如果是tensorflow,一般直接设置padding='same'就可以了;但是pytorch需要自己计算padding的大小,简化的计算公式:

N= floor((M-ksize+2*padding)/stride)+1

M是输入tensor的长或者宽,N是对应的输出tensor的长或者宽,ksize是卷积核的大小,stride就是步长。根据上面的公式和我们的参数设置,可以验证:N=M。

2. Up
class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)

这部分结构就是反卷积、padding、cat和DoubleConv四个小模块。
我用的是反卷积实现上采样,但是也有用插值进行上采样的实现。反卷积的参数设置也让我困惑了一段时间,搞不懂怎么设置参数,根据反卷积输入输出尺寸计算公式:

N=(M-1)*stride-2*padding+ksize+output_padding

根据上面的公式,如果想要实现两倍上采样,首先需要设置stride=2,然后就是让ksize+output_padding-2*padding-stride=0就可以了。把我们的参数代入,可以验证N=2*M。
为了能够将两个tensor沿channel方向合并,需要先把反卷积得到的tensor做一个padding,为什么要padding呢,因为在下采样过程中,tensor的长宽可能出现奇数,再次下采样的时候因为取整边缘就损失了一个像素,但是上采样没有这个问题,所以两个tensor的尺寸可能就不一样了,需要专门处理一下。

3. UNet model

最后就是根据UNet的模型实现一下代码:

class UNetModel(BaseModel):
    def __init__(self, in_channels, n_labels):
        super(UNetModel, self).__init__()
        self.conv1 = DoubleConv(in_channels, 32)
        self.conv2 = DoubleConv(32, 64)
        self.conv3 = DoubleConv(64, 128)
        self.conv4 = DoubleConv(128, 256)
        self.conv5 = DoubleConv(256, 512)

        self.up6 = Up(512, 256)
        self.up7 = Up(256, 128)
        self.up8 = Up(128, 64)
        self.up9 = Up(64, 32)

        self.conv10 = nn.Conv2d(32, n_labels, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        pool1 = F.max_pool2d(x1, 2)
        x2 = self.conv2(pool1)
        pool2 = F.max_pool2d(x2, 2)
        x3 = self.conv3(pool2)
        pool3 = F.max_pool2d(x3, 2)
        x4 = self.conv4(pool3)
        pool4 = F.max_pool2d(x4, 2)

        x = self.conv5(pool4)
        x = self.up6(x, x4)
        x = self.up7(x, x3)
        x = self.up8(x, x2)
        x = self.up9(x, x1)

        output = self.conv10(x)
        return output

本节就这么多东西,有些问题可能写的也不是很清楚,或者有错误的地方,望大家能够指教。
下一节我会专门讲一下我用UNet训练的过程,同时也会整理一下代码,放到github。

相关文章

  • Pytorch-UNet实现

    简介 UNet结构很简单,模型代码实现也不复杂,我选择使用Pytorch来搭建模型,本着能偷懒就偷懒的原则,我在g...

  • Pytorch-UNet训练

    简介 本节主要是介绍我怎么用上一节实现的UNet进行训练,一共分成3部分进行说明。需要强调的是,本节中的数据集以及...

  • Pytorch-UNet介绍

    简介 UNet网络主要用在医学图像分割任务上,网络的结构特点就是: 全卷积网络,没有全连接层,训练参数少,模型体积...

  • call.apply.bind实现

    call实现 apply实现 bind实现 new实现

  • 【call apply bind】源码实现

    call方法的实现 apply方法实现 bind方法实现 new方法实现 reduce实现

  • 实现一个简单的String类

    一、确定需求 需要实现如下功能 实现构造函数 实现析构函数 实现复制构造函数 实现赋值操作符 实现字符串连接 实现...

  • Android ImageView+Glide实现加载长图片(适

    实现效果: 布局实现: activity实现

  • splice实现应用

    实现push方法 实现pop方法 实现shift方法 实现unshift方法

  • 链表

    单链表 C实现 Java实现 双链表 C实现 Java实现

  • MD5的几种实现

    PHP的实现 Nodejs的实现 Python的实现 Golang的实现

网友评论

      本文标题:Pytorch-UNet实现

      本文链接:https://www.haomeiwen.com/subject/zaomzhtx.html