Mastering U2Net in One Article: Image Segmentation (Semantic Segmentation)


Author: 小黄不头秃 | Published: 2023-06-14 02:30

In the previous article we introduced UNet. Today we look at U2Net, a strengthened version of UNet. Its structure is shown in the figure below:

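Compared with UNet, every small cube in U2Net is itself a UNet (the U²-Net paper calls these residual U-blocks, RSU). A few points to note:

  1. The white block at the bottom of each U-shaped structure is a dilated (atrous) convolution.
  2. In the final fusion step of each block, the block's input features (after its first convolution) are added element-wise to the decoder output, in the spirit of ResNet. Everywhere else, features are fused with torch.cat().
  3. Unlike UNet, the number of channels does not grow with depth inside a block: every level uses the same number of channels (mid_c).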
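To make points 2 and 3 concrete, the sketch below tracks only tensor shapes (NCHW tuples), not values. The helpers `cat_shape` and `add_shape` are hypothetical; they merely mimic what `torch.cat(..., dim=1)` and element-wise `+` do to shapes. The numbers assume mid_c=32 and out_c=64, as in the first stage of the model later in this article.

```python
def cat_shape(a, b):
    """Shape of torch.cat((a, b), dim=1) for two NCHW shapes: channels add up."""
    if a[0] != b[0] or a[2:] != b[2:]:
        raise ValueError("N, H and W must match for a channel concatenation")
    return (a[0], a[1] + b[1]) + tuple(a[2:])


def add_shape(a, b):
    """Shape of a + b (residual addition): both shapes must be identical."""
    if tuple(a) != tuple(b):
        raise ValueError("residual addition needs identical shapes")
    return tuple(a)


# Inside a block every level keeps mid_c channels, so a skip concatenation
# always yields 2 * mid_c channels (hence decoder convs with mid_c*2 inputs).
mid_c, out_c = 32, 64
skip = (1, mid_c, 14, 14)
upsampled = (1, mid_c, 14, 14)
print(cat_shape(skip, upsampled))  # (1, 64, 14, 14)

# The final fusion adds the block's input features instead of concatenating,
# so the channel count stays at out_c.
print(add_shape((1, out_c, 224, 224), (1, out_c, 224, 224)))  # (1, 64, 224, 224)
```

Concatenation keeps both feature sets side by side at the cost of more channels; addition keeps the channel count fixed, which is what lets the residual connection close each block.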

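So first we build the structure inside each small cube, which we call UNet_Block_x. As in UNet, it is built from a convolution block, a downsampling block, and an upsampling block. The code for these is: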

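  • Convolution block
import torch
import torch.nn as nn

"""
Convolution block: 3x3 conv -> BatchNorm -> ReLU
"""
class Conv_Block(nn.Module):
    def __init__(self, in_c, out_c, dilation=1) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            # padding = dilation keeps H and W unchanged for a 3x3 kernel with stride 1
            nn.Conv2d(in_c, out_c, 3, 1, padding=dilation, dilation=dilation),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x)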
  • Downsampling
"""
Downsampling block: 2x2 max pooling, halves H and W
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        )

    def forward(self, x):
        return self.layers(x)
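  • Upsampling
"""
Upsampling block: doubles H and W of x, then concatenates the skip feature along the channel dimension
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, scale_factor=2) -> None:
        super().__init__()
        # Upsampling method 1: bilinear with align_corners=True (the one used in forward)
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
        # Upsampling method 2: bilinear with align_corners=False by default (not used in forward)
        self.upsample2 = nn.Upsample(scale_factor=scale_factor, mode='bilinear')

    def forward(self, x, feature):
        # Method 3: x = torch.nn.functional.interpolate(input=x, scale_factor=2, mode="nearest")
        x = self.upsample1(x)
        # If x and feature could differ in size (e.g. odd input sizes), these two lines would
        # resize feature to match x (requires: from torchvision.transforms import Resize)
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        # Concatenate along channels: C(x) + C(feature)
        res = torch.cat((x, feature), dim=1)
        return res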
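Each of the three building blocks changes the spatial size in a fixed way. The sketch below applies the output-size formula given in the PyTorch `Conv2d`/`MaxPool2d` documentation (one spatial dimension, square inputs) to show why `padding=dilation` keeps a 3×3 convolution size-preserving, while pooling halves the size and upsampling doubles it. The function names `conv_out` and `effective_kernel` are illustrative and not part of the article's code.

```python
def conv_out(h, kernel=3, stride=1, padding=0, dilation=1):
    """Output size of Conv2d / MaxPool2d along one spatial dimension."""
    return (h + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def effective_kernel(kernel=3, dilation=1):
    """Span covered by a dilated kernel (gaps are inserted between the taps)."""
    return dilation * (kernel - 1) + 1


# Conv_Block uses padding=dilation with a 3x3 kernel, so H stays the same
for d in (1, 2, 4, 8):
    print(d, effective_kernel(3, d), conv_out(14, 3, 1, padding=d, dilation=d))
# 1 3 14
# 2 5 14
# 4 9 14
# 8 17 14

# Down_Sample_Block: MaxPool2d(2, 2) halves the size
print(conv_out(224, kernel=2, stride=2))  # 112

# Up_Sample_Block: scale_factor=2 doubles it again
print(112 * 2)  # 224
```

Because the convolutions never change the size, only the pooling and upsampling layers determine the resolution at each level of a block.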

I. UNet_Block_x implementation

The code for UNet_Block1 is as follows:

class UNet_Block1(nn.Module):
    """UNet_Block1: 5 pooling levels, then a dilated conv at the bottom (RSU-7 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 5 encoder levels (reusing one conv would share weights)
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(5)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 5 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(5)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 224x224 input)
        out1 = self.conv1(x) # out_c [1, 5, 224, 224]
        out2 = self.conv2(out1) # mid_c [1, 3, 224, 224]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 112, 112]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 56, 56]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 28, 28]
        out6 = self.conv3[3](self.down(out5)) # mid_c [1, 3, 14, 14]
        out7 = self.conv3[4](self.down(out6)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out7) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out7,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder: upsample, concatenate the skip feature, conv
        out10 = self.conv5[1](self.up(out9,out6)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out5)) # [1, 3, 28, 28]
        out12 = self.conv5[3](self.up(out11,out4)) # [1, 3, 56, 56]
        out13 = self.conv5[4](self.up(out12,out3)) # [1, 3, 112, 112]
        out14 = self.conv6(self.up(out13,out2)) # out_c [1, 5, 224, 224]
        # Residual connection: add (not concatenate) the block's input features
        out = out14 + out1 # [1, 5, 224, 224]
        return out
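UNet_Block1 pools five times before its dilated bottom layer, and every `self.up(...)` must produce exactly the size of the skip feature it is concatenated with. The sketch below (hypothetical helpers `level_sizes` and `skips_match`) lists the size at each level and suggests why the input side length should be divisible by 2**5 = 32: with 224 the sizes halve cleanly to 7 and double back, whereas with 225 the first upsampled map (224) no longer matches the skip feature (225), so `torch.cat` would fail unless the commented-out `Resize` lines are used.

```python
def level_sizes(h, n_pools):
    """Spatial size at each encoder level; MaxPool2d(2, 2) floors odd sizes."""
    sizes = [h]
    for _ in range(n_pools):
        sizes.append(sizes[-1] // 2)
    return sizes


def skips_match(h, n_pools):
    """True if doubling at every decoder level reproduces each skip feature's size."""
    sizes = level_sizes(h, n_pools)
    return all(2 * sizes[i + 1] == sizes[i] for i in range(n_pools))


print(level_sizes(224, 5))  # [224, 112, 56, 28, 14, 7]
print(skips_match(224, 5))  # True
print(level_sizes(225, 5))  # [225, 112, 56, 28, 14, 7]
print(skips_match(225, 5))  # False: 2 * 112 = 224 != 225
```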

The code for UNet_Block2:

class UNet_Block2(nn.Module):
    """UNet_Block2: 4 pooling levels, then a dilated conv at the bottom (RSU-6 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 4 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(4)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 4 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(4)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 112x112 input)
        out1 = self.conv1(x) # out_c [1, 5, 112, 112]
        out2 = self.conv2(out1) # mid_c [1, 3, 112, 112]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 56, 56]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 28, 28]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 14, 14]
        out6 = self.conv3[3](self.down(out5)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out6) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out6,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out5)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out4)) # [1, 3, 28, 28]
        out12 = self.conv5[3](self.up(out11,out3)) # [1, 3, 56, 56]
        out13 = self.conv6(self.up(out12,out2)) # out_c [1, 5, 112, 112]
        # Residual connection
        out = out13 + out1 # [1, 5, 112, 112]
        return out

The code for UNet_Block3:

class UNet_Block3(nn.Module):
    """UNet_Block3: 3 pooling levels, then a dilated conv at the bottom (RSU-5 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 3 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(3)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 3 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(3)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 56x56 input)
        out1 = self.conv1(x) # out_c [1, 5, 56, 56]
        out2 = self.conv2(out1) # mid_c [1, 3, 56, 56]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 28, 28]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 14, 14]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out5) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out5,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out4)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out3)) # [1, 3, 28, 28]
        out12 = self.conv6(self.up(out11,out2)) # out_c [1, 5, 56, 56]
        # Residual connection
        out = out12 + out1 # [1, 5, 56, 56]
        return out

The code for UNet_Block4:

class UNet_Block4(nn.Module):
    """UNet_Block4: 2 pooling levels, then a dilated conv at the bottom (RSU-4 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 2 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(2)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 2 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(2)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 28x28 input)
        out1 = self.conv1(x) # out_c [1, 5, 28, 28]
        out2 = self.conv2(out1) # mid_c [1, 3, 28, 28]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 14, 14]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out4) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out4,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out3)) # [1, 3, 14, 14]
        out11 = self.conv6(self.up(out10,out2)) # out_c [1, 5, 28, 28]
        # Residual connection
        out = out11 + out1 # [1, 5, 28, 28]
        return out
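UNet_Block1 to UNet_Block4 differ only in how many times they pool (5, 4, 3 and 2 times). In the full model later in this article each block runs at half the resolution of the previous stage, so with a 224×224 input all four reach the same 7×7 bottom. The snippet below simply recomputes this from the shape comments above; the `blocks` list is illustrative.

```python
# (block name, input side length in the full model, number of MaxPool2d(2, 2) layers)
blocks = [
    ("UNet_Block1", 224, 5),
    ("UNet_Block2", 112, 4),
    ("UNet_Block3", 56, 3),
    ("UNet_Block4", 28, 2),
]

bottom = {}
for name, size, n_pools in blocks:
    bottom[name] = size // 2 ** n_pools  # ** binds tighter than //
    print(f"{name}: {size}x{size} -> bottom {bottom[name]}x{bottom[name]}")
# UNet_Block1: 224x224 -> bottom 7x7
# UNet_Block2: 112x112 -> bottom 7x7
# UNet_Block3: 56x56 -> bottom 7x7
# UNet_Block4: 28x28 -> bottom 7x7
```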

The code for UNet_Block5:

class UNet_Block5(nn.Module):
    """UNet_Block5: no pooling; dilation rates 2, 4, 8 replace downsampling (RSU-4F in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=4)
        self.conv5 = Conv_Block(mid_c,mid_c,dilation=8)
        self.conv6 = Conv_Block(mid_c*2,mid_c,dilation=4)
        self.conv7 = Conv_Block(mid_c*2,mid_c,dilation=2)
        self.conv8 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Every layer keeps the input size, e.g. out_c [1, 5, 14, 14] for a 14x14 input
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)
        out4 = self.conv4(out3)

        # Bottom: largest dilation rate
        out5 = self.conv5(out4)

        # Decoder side: concatenate with the matching encoder feature (no upsampling needed)
        out6 = self.conv6(torch.cat((out4,out5),dim=1))
        out7 = self.conv7(torch.cat((out3,out6),dim=1))
        out8 = self.conv8(torch.cat((out2,out7),dim=1))
        # Residual connection
        out = out8 + out1
        return out
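UNet_Block5 replaces pooling with growing dilation rates, so the resolution never drops while the receptive field still grows. For stride-1 layers, each 3×3 convolution with dilation d widens the receptive field by 2·d. The sketch below (hypothetical helper `receptive_field`) sums this along the deepest path conv1 → conv5. At the 14×14 and 7×7 resolutions where the model uses this block, a 33-pixel receptive field already spans the whole map, which is presumably why further pooling is not needed there.

```python
def receptive_field(dilations, kernel=3):
    """Receptive field of a stack of stride-1 convolutions with the given dilations."""
    rf = 1
    for d in dilations:
        rf += d * (kernel - 1)
    return rf


# Deepest path in UNet_Block5: conv1 (d=1), conv2 (d=1), conv3 (d=2), conv4 (d=4), conv5 (d=8)
print(receptive_field([1, 1, 2, 4, 8]))  # 33

# For comparison: five ordinary 3x3 convolutions without dilation
print(receptive_field([1, 1, 1, 1, 1]))  # 11
```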

II. U2Net model implementation

model2.py

import torch
import torch.nn as nn 

"""
卷积块
"""
class Conv_Block(nn.Module):
    def __init__(self, in_c, out_c, dilation=1) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c,3,1,padding=1*dilation,dilation=1*dilation),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x) 
    
"""
下采样块
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        ) 

    def forward(self, x):
        return self.layers(x)
"""
上采样块
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, scale_factor=2) -> None:
        super().__init__()
        # 上采样方法1:
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=scale_factor) 
        # 上采样方法2:
        self.upsample2 = nn.Upsample(scale_factor=scale_factor, mode='bilinear')

    def forward(self,x,feature):
        # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
        x = self.upsample1(x)
        # 下面两行代码是将
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        res = torch.cat((x,feature),dim=1)
        return res
    
"""
输出模块:
"""
class Output(nn.Module):
    def __init__(self,in_c, out_c) -> None:
        super().__init__()
        self.layers = self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, 1, 1,bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

class UNet_Block1(nn.Module):
    """UNet_Block1: 5 pooling levels, then a dilated conv at the bottom (RSU-7 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 5 encoder levels (reusing one conv would share weights)
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(5)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 5 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(5)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 224x224 input)
        out1 = self.conv1(x) # out_c [1, 5, 224, 224]
        out2 = self.conv2(out1) # mid_c [1, 3, 224, 224]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 112, 112]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 56, 56]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 28, 28]
        out6 = self.conv3[3](self.down(out5)) # mid_c [1, 3, 14, 14]
        out7 = self.conv3[4](self.down(out6)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out7) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out7,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder: upsample, concatenate the skip feature, conv
        out10 = self.conv5[1](self.up(out9,out6)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out5)) # [1, 3, 28, 28]
        out12 = self.conv5[3](self.up(out11,out4)) # [1, 3, 56, 56]
        out13 = self.conv5[4](self.up(out12,out3)) # [1, 3, 112, 112]
        out14 = self.conv6(self.up(out13,out2)) # out_c [1, 5, 224, 224]
        # Residual connection: add (not concatenate) the block's input features
        out = out14 + out1 # [1, 5, 224, 224]
        return out
    
class UNet_Block2(nn.Module):
    """UNet_Block2: 4 pooling levels, then a dilated conv at the bottom (RSU-6 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 4 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(4)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 4 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(4)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 112x112 input)
        out1 = self.conv1(x) # out_c [1, 5, 112, 112]
        out2 = self.conv2(out1) # mid_c [1, 3, 112, 112]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 56, 56]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 28, 28]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 14, 14]
        out6 = self.conv3[3](self.down(out5)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out6) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out6,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out5)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out4)) # [1, 3, 28, 28]
        out12 = self.conv5[3](self.up(out11,out3)) # [1, 3, 56, 56]
        out13 = self.conv6(self.up(out12,out2)) # out_c [1, 5, 112, 112]
        # Residual connection
        out = out13 + out1 # [1, 5, 112, 112]
        return out
    
class UNet_Block3(nn.Module):
    """UNet_Block3: 3 pooling levels, then a dilated conv at the bottom (RSU-5 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 3 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(3)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 3 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(3)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 56x56 input)
        out1 = self.conv1(x) # out_c [1, 5, 56, 56]
        out2 = self.conv2(out1) # mid_c [1, 3, 56, 56]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 28, 28]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 14, 14]
        out5 = self.conv3[2](self.down(out4)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out5) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out5,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out4)) # [1, 3, 14, 14]
        out11 = self.conv5[2](self.up(out10,out3)) # [1, 3, 28, 28]
        out12 = self.conv6(self.up(out11,out2)) # out_c [1, 5, 56, 56]
        # Residual connection
        out = out12 + out1 # [1, 5, 56, 56]
        return out
    
class UNet_Block4(nn.Module):
    """UNet_Block4: 2 pooling levels, then a dilated conv at the bottom (RSU-4 in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        # A separate conv for each of the 2 encoder levels
        self.conv3 = nn.ModuleList([Conv_Block(mid_c,mid_c) for _ in range(2)])
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        # A separate conv for each of the 2 decoder levels
        self.conv5 = nn.ModuleList([Conv_Block(mid_c*2,mid_c) for _ in range(2)])
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Encoder (example shapes for mid_c=3, out_c=5 and a 28x28 input)
        out1 = self.conv1(x) # out_c [1, 5, 28, 28]
        out2 = self.conv2(out1) # mid_c [1, 3, 28, 28]
        out3 = self.conv3[0](self.down(out2)) # mid_c [1, 3, 14, 14]
        out4 = self.conv3[1](self.down(out3)) # mid_c [1, 3, 7, 7]

        # Bottom: dilated conv, no pooling
        out8 = self.conv4(out4) # mid_c [1, 3, 7, 7]
        out9 = self.conv5[0](torch.cat((out4,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # Decoder
        out10 = self.conv5[1](self.up(out9,out3)) # [1, 3, 14, 14]
        out11 = self.conv6(self.up(out10,out2)) # out_c [1, 5, 28, 28]
        # Residual connection
        out = out11 + out1 # [1, 5, 28, 28]
        return out
    
class UNet_Block5(nn.Module):
    """UNet_Block5: no pooling; dilation rates 2, 4, 8 replace downsampling (RSU-4F in the U²-Net paper)"""
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=4)
        self.conv5 = Conv_Block(mid_c,mid_c,dilation=8)
        self.conv6 = Conv_Block(mid_c*2,mid_c,dilation=4)
        self.conv7 = Conv_Block(mid_c*2,mid_c,dilation=2)
        self.conv8 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # Every layer keeps the input size, e.g. out_c [1, 5, 14, 14] for a 14x14 input
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)
        out4 = self.conv4(out3)

        # Bottom: largest dilation rate
        out5 = self.conv5(out4)

        # Decoder side: concatenate with the matching encoder feature (no upsampling needed)
        out6 = self.conv6(torch.cat((out4,out5),dim=1))
        out7 = self.conv7(torch.cat((out3,out6),dim=1))
        out8 = self.conv8(torch.cat((out2,out7),dim=1))
        # Residual connection
        out = out8 + out1
        return out
    
class U2NET(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        # Nearest-neighbour upsampling: x2 inside the decoder, x2 to x32 for the side outputs
        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.up2 = nn.UpsamplingNearest2d(scale_factor=4)
        self.up3 = nn.UpsamplingNearest2d(scale_factor=8)
        self.up4 = nn.UpsamplingNearest2d(scale_factor=16)
        self.up5 = nn.UpsamplingNearest2d(scale_factor=32)

        # Encoder, arguments are (in_c, mid_c, out_c); 1 input channel = grayscale image
        self.unet1 = UNet_Block1(1,32,64)
        self.unet2 = UNet_Block2(64,32,128)
        self.unet3 = UNet_Block3(128,64,256)
        self.unet4 = UNet_Block4(256,128,512)
        self.unet5 = UNet_Block5(512,256,512)

        self.unet6 = UNet_Block5(512,256,512)

        # Decoder: in_c = encoder channels + upsampled decoder channels
        self.de_unet1 = UNet_Block1(128,16,64)
        self.de_unet2 = UNet_Block2(256,32,64)
        self.de_unet3 = UNet_Block3(512,64,128)
        self.de_unet4 = UNet_Block4(1024,128,256)
        self.de_unet5 = UNet_Block5(1024,256,512)

        # One side-output head per stage; out5 and out6 both read 512-channel maps
        # but are kept separate so the two stages do not share head weights
        self.out1 = Output(64,1)
        self.out2 = Output(64,1)
        self.out3 = Output(128,1)
        self.out4 = Output(256,1)
        self.out5 = Output(512,1)
        self.out6 = Output(512,1)

    def forward(self, x):
        # Downsampling (encoder)
        conv1 = self.unet1(x) # [1, 64, 224, 224]
        en1 = self.down(conv1)
        conv2 = self.unet2(en1) # [1, 128, 112, 112]
        en2 = self.down(conv2)
        conv3 = self.unet3(en2) # [1, 256, 56, 56]
        en3 = self.down(conv3)
        conv4 = self.unet4(en3) # [1, 512, 28, 28]
        en4 = self.down(conv4)
        conv5 = self.unet5(en4) # [1, 512, 14, 14]
        en5 = self.down(conv5)

        conv6 = self.unet6(en5) # [1, 512, 7, 7]

        # Upsampling (decoder): upsample, concatenate with the encoder feature, decode
        de1 = self.up1(conv6) # [1, 512, 14, 14]
        conv7 = self.de_unet5(torch.cat((conv5,de1),dim=1)) # [1, 512, 14, 14]
        de2 = self.up1(conv7) # [1, 512, 28, 28]
        conv8 = self.de_unet4(torch.cat((conv4,de2),dim=1)) # [1, 256, 28, 28]
        de3 = self.up1(conv8) # [1, 256, 56, 56]
        conv9 = self.de_unet3(torch.cat((conv3,de3),dim=1)) # [1, 128, 56, 56]
        de4 = self.up1(conv9) # [1, 128, 112, 112]
        conv10 = self.de_unet2(torch.cat((conv2,de4),dim=1)) # [1, 64, 112, 112]
        de5 = self.up1(conv10) # [1, 64, 224, 224]
        # Last decoder stage (de_unet1 was defined but never called before)
        conv11 = self.de_unet1(torch.cat((conv1,de5),dim=1)) # [1, 64, 224, 224]

        # Side outputs, all upsampled to the input size
        out1 = self.up5(self.out6(conv6)) # [1, 1, 224, 224]
        out2 = self.up4(self.out5(conv7)) # [1, 1, 224, 224]
        out3 = self.up3(self.out4(conv8)) # [1, 1, 224, 224]
        out4 = self.up2(self.out3(conv9)) # [1, 1, 224, 224]
        out5 = self.up1(self.out2(conv10)) # [1, 1, 224, 224]
        out6 = self.out1(conv11) # [1, 1, 224, 224]

        # Fuse by averaging the six probability maps
        out = (out1 + out2 + out3 + out4 + out5 + out6)/6

        return out
    
if __name__ == "__main__":
    x = torch.randn((1,1,224,224))
    conv = U2NET()
    y = conv(x)
    print(y.shape)
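The decoder's in_c values in `__init__` follow from concatenating each encoder feature with the upsampled deeper feature, and the side outputs, produced at different resolutions, are upsampled by 224 / size. The sketch below re-derives both from the channel counts and sizes in the shape comments; it mirrors, but does not import, model2.py, and the dictionary names are illustrative.

```python
# Encoder outputs: (channels, side length) for a 1x224x224 input
encoder = {"conv1": (64, 224), "conv2": (128, 112), "conv3": (256, 56),
           "conv4": (512, 28), "conv5": (512, 14), "conv6": (512, 7)}

# Each decoder stage: torch.cat((encoder_feature, upsampled_deeper_feature), dim=1)
decoder_inputs = [
    ("de_unet5", "conv5", 512),  # upsampled conv6 has 512 channels
    ("de_unet4", "conv4", 512),  # upsampled output of de_unet5 has 512
    ("de_unet3", "conv3", 256),  # upsampled output of de_unet4 has 256
    ("de_unet2", "conv2", 128),  # upsampled output of de_unet3 has 128
    ("de_unet1", "conv1", 64),   # upsampled output of de_unet2 has 64
]
in_channels = {name: encoder[skip][0] + up_c for name, skip, up_c in decoder_inputs}
print(in_channels)
# {'de_unet5': 1024, 'de_unet4': 1024, 'de_unet3': 512, 'de_unet2': 256, 'de_unet1': 128}

# Side outputs come from maps of size 7, 14, 28, 56, 112 and 224, all upsampled to 224
scales = [224 // s for s in (7, 14, 28, 56, 112, 224)]
print(scales)  # [32, 16, 8, 4, 2, 1]
```

The scale factors match `up5` down to `up1` plus the full-resolution head, which is why all six side outputs can be averaged element-wise.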


III. Dataset and experimental results

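(1) Dataset: fundus vessel images, stored under ./data/FundusVessels/ (JPEGImages holds the .jpg inputs, Annotations the .png masks).
(2) Training code: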
  • train2.py
import os

import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image

from model2 import U2NET

IMG_PATH = "./data/FundusVessels/JPEGImages/"
TARGET_PATH = "./data/FundusVessels/Annotations/"
DST_DIR = "./img2"

class EYE_Dataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.img_list = os.listdir(IMG_PATH)

    def __len__(self) -> int:
        return len(self.img_list)

    def __getitem__(self, index):
        name = os.path.splitext(self.img_list[index])[0]
        # Input: grayscale, resized to 224x224, scaled to [0, 1], shape [1, 224, 224]
        img = Image.open(IMG_PATH + f"{name}.jpg").convert("L").resize((224, 224))
        # Mask: nearest-neighbour resize so no interpolated in-between values appear
        label = Image.open(TARGET_PATH + f"{name}.png").resize((224, 224), Image.NEAREST)
        img = np.array(img, dtype=np.float32) / 255
        img = img[np.newaxis, :]
        # BCELoss needs targets in [0, 1]: treat every non-zero mask pixel as foreground
        # (assumes background is stored as 0 in the annotation PNGs)
        label = (np.array(label) > 0).astype(np.float32)
        return img, label


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(DST_DIR, exist_ok=True)  # save_image does not create the directory
    train_dataset = EYE_Dataset()
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

    net = U2NET()
    if os.path.exists("./U2Net.pt"):
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
        params = torch.load("./U2Net.pt", map_location=device)
        net.load_state_dict(params)

    net.to(device)
    optim = torch.optim.Adam(net.parameters())
    loss_fn = nn.BCELoss()  # the network already outputs Sigmoid probabilities
    epoch = 1
    net.train()
    while True:
        for i, (img, target) in enumerate(train_loader):
            img, target = img.to(device), target.to(device)
            y = net(img)  # [B, 1, 224, 224]
            # target is [B, 224, 224]; insert the channel dimension at dim=1 to match y
            loss = loss_fn(y, target.unsqueeze(dim=1))
            optim.zero_grad()
            loss.backward()
            optim.step()

            if i % 1 == 0:  # always true: log and save every step; raise the modulus to save less often
                # Save the input next to the binarized prediction (threshold 0.6)
                img2 = (y[0] > 0.6).float()
                res = torch.stack([img[0], img2], dim=0)
                save_image(res.cpu(), DST_DIR + f"/epoch{epoch}_{i}.jpg", nrow=2)
                print(f"epoch {epoch}, loss: {loss.item()}")
                torch.save(net.state_dict(), "./U2Net.pt")
        epoch += 1  # one full pass over the dataset
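The script trains with `nn.BCELoss` on the Sigmoid probabilities and visualizes predictions by thresholding at 0.6. Below is a pure-Python sketch of both on a few hypothetical pixel values (BCELoss's default reduction is the mean over all elements). It also shows what happens if the probabilities can never drop below 0.5, as with a ReLU placed before the Sigmoid: every background pixel then costs at least ln 2 ≈ 0.693, no matter how long training runs.

```python
import math


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy, like nn.BCELoss with reduction='mean'."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(probs)


def binarize(probs, threshold=0.6):
    """Same rule as (y > 0.6).float() in train2.py."""
    return [1.0 if p > threshold else 0.0 for p in probs]


targets = [1.0, 0.0, 0.0, 1.0]  # hypothetical vessel / background pixels
good = [0.9, 0.1, 0.2, 0.7]     # probabilities from an unconstrained Sigmoid
print(round(bce(good, targets), 4))  # 0.1976
print(binarize(good))                # [1.0, 0.0, 0.0, 1.0]

# Sigmoid(ReLU(z)) >= 0.5, so a background pixel can never cost less than ln 2
print(round(bce([0.5], [0.0]), 4))   # 0.6931
```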
(3) Training results

The model was trained for about 1,500 rounds; training for longer should improve the results further.


Article link: https://www.haomeiwen.com/subject/lnrpedtx.html