一文搞定U2Net——图像分割(语义分割)

上文中,我们介绍了UNet,今天我们来了解一下U2Net。这个网络是 UNet的加强版。其结构如下图所示:

与UNet相比,U2Net中的每一个小立方体里面都是一个UNet。但是需要注意的是:

  1. U型结构的最下面一层白色的块为空洞卷积。
  2. 在最后一层做通道融合的时候,是用类似ResNet的思想,将特征变量进行相加。其他时候仍然用的torch.cat()操作。
  3. 与UNet不同,特征矩阵的通道数没有增加。

所以首先我们构建小立方体中的结构,我们称之为Unet_Blockx。和UNet中的一样其包括了卷积层块、下采样、上采样结构。其具体代码如下:

  • 卷积块
"""
卷积块
"""
class Conv_Block(nn.Module):
    def __init__(self, in_c, out_c, dilation=1) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c,3,1,padding=1*dilation,dilation=1*dilation),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x) 
  • 下采样
"""
下采样块
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        ) 

    def forward(self, x):
        return self.layers(x)
  • 上采样
"""
上采样块
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, scale_factor=2) -> None:
        super().__init__()
        # 上采样方法1:
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=scale_factor) 
        # 上采样方法2:
        self.upsample2 = nn.Upsample(scale_factor=scale_factor, mode='bilinear')

    def forward(self,x,feature):
        # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
        x = self.upsample1(x)
        # 下面两行代码是将
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        res = torch.cat((x,feature),dim=1)
        return res

一、UNet_Block_x 代码实现

所以Unet_Block1的代码如下:

class UNet_Block1(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 224, 224]
        out2 = self.conv2(out1) # mid_c [1, 3, 224, 224]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 112, 112]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 56, 56]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 28, 28]
        out6 = self.conv3(self.down(out5)) # mid_c [1, 3, 14, 14]
        out7 = self.conv3(self.down(out6)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out7) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out7,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out6)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out5)) # [1, 3, 28, 28]
        out12 = self.conv5(self.up(out11,out4)) # [1, 3, 56, 56]
        out13 = self.conv5(self.up(out12,out3)) # [1, 3, 112, 112]
        out14 = self.conv6(self.up(out13,out2)) # [1, 3, 224, 224]
        out = out14 + out1 # [1, 5, 224, 224]
        return out

Unet_Block2的代码:

class UNet_Block2(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 112, 112]
        out2 = self.conv2(out1) # mid_c [1, 3, 112, 112]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 56, 56]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 28, 28]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 14, 14]
        out6 = self.conv3(self.down(out5)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out6) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out6,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out5)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out4)) # [1, 3, 28, 28]
        out12 = self.conv5(self.up(out11,out3)) # [1, 3, 56, 56]
        out13 = self.conv6(self.up(out12,out2)) # [1, 3, 112, 112]
        out = out13 + out1 # [1, 5, 112, 112]
        return out

Unet_Block3的代码:

class UNet_Block3(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 56, 56]
        out2 = self.conv2(out1) # mid_c [1, 3, 56, 56]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 28, 28]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 14, 14]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out5) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out5,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out4)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out3)) # [1, 3, 28, 28]
        out12 = self.conv6(self.up(out11,out2)) # [1, 3, 56, 56]
        out = out12 + out1 # [1, 5, 56, 56]
        return out

Unet_Block4的代码:

class UNet_Block4(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 28, 28]
        out2 = self.conv2(out1) # mid_c [1, 3, 28, 28]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 14, 14]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out4) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out4,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out3)) # [1, 3, 14, 14]
        out11 = self.conv6(self.up(out10,out2)) # [1, 3, 28, 28]
        out = out11 + out1 # [1, 5, 28, 28]
        return out

Unet_Block5的代码:

class UNet_Block5(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=4)
        self.conv5 = Conv_Block(mid_c,mid_c,dilation=8)
        self.conv6 = Conv_Block(mid_c*2,mid_c,dilation=4)
        self.conv7 = Conv_Block(mid_c*2,mid_c,dilation=2)
        self.conv8 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        out1 = self.conv1(x) # out_c [1, 5, 14, 14]
        out2 = self.conv2(out1) 
        out3 = self.conv3(out2) 
        out4 = self.conv4(out3) 

        out5 = self.conv5(out4) 

        out6 = self.conv6(torch.cat((out4,out5),dim=1)) 
        out7 = self.conv7(torch.cat((out3,out6),dim=1)) 
        out8 = self.conv8(torch.cat((out2,out7),dim=1)) 
        out = out8 + out1
        return out

二、U2Net模型代码实现

model2.py

import torch
import torch.nn as nn 

"""
卷积块
"""
class Conv_Block(nn.Module):
    def __init__(self, in_c, out_c, dilation=1) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c,3,1,padding=1*dilation,dilation=1*dilation),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x) 
    
"""
下采样块
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        ) 

    def forward(self, x):
        return self.layers(x)
"""
上采样块
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, scale_factor=2) -> None:
        super().__init__()
        # 上采样方法1:
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=scale_factor) 
        # 上采样方法2:
        self.upsample2 = nn.Upsample(scale_factor=scale_factor, mode='bilinear')

    def forward(self,x,feature):
        # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
        x = self.upsample1(x)
        # 下面两行代码是将
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        res = torch.cat((x,feature),dim=1)
        return res
    
"""
输出模块:
"""
class Output(nn.Module):
    def __init__(self,in_c, out_c) -> None:
        super().__init__()
        self.layers = self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, 1, 1,bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

class UNet_Block1(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 224, 224]
        out2 = self.conv2(out1) # mid_c [1, 3, 224, 224]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 112, 112]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 56, 56]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 28, 28]
        out6 = self.conv3(self.down(out5)) # mid_c [1, 3, 14, 14]
        out7 = self.conv3(self.down(out6)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out7) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out7,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out6)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out5)) # [1, 3, 28, 28]
        out12 = self.conv5(self.up(out11,out4)) # [1, 3, 56, 56]
        out13 = self.conv5(self.up(out12,out3)) # [1, 3, 112, 112]
        out14 = self.conv6(self.up(out13,out2)) # [1, 3, 224, 224]
        out = out14 + out1 # [1, 5, 224, 224]
        return out
    
class UNet_Block2(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 112, 112]
        out2 = self.conv2(out1) # mid_c [1, 3, 112, 112]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 56, 56]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 28, 28]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 14, 14]
        out6 = self.conv3(self.down(out5)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out6) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out6,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out5)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out4)) # [1, 3, 28, 28]
        out12 = self.conv5(self.up(out11,out3)) # [1, 3, 56, 56]
        out13 = self.conv6(self.up(out12,out2)) # [1, 3, 112, 112]
        out = out13 + out1 # [1, 5, 112, 112]
        return out
    
class UNet_Block3(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 56, 56]
        out2 = self.conv2(out1) # mid_c [1, 3, 56, 56]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 28, 28]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 14, 14]
        out5 = self.conv3(self.down(out4)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out5) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out5,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out4)) # [1, 3, 14, 14]
        out11 = self.conv5(self.up(out10,out3)) # [1, 3, 28, 28]
        out12 = self.conv6(self.up(out11,out2)) # [1, 3, 56, 56]
        out = out12 + out1 # [1, 5, 56, 56]
        return out
    
class UNet_Block4(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up = Up_Sample_Block()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv5 = Conv_Block(mid_c*2,mid_c)
        self.conv6 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        # 下采样过程
        out1 = self.conv1(x) # out_c [1, 5, 28, 28]
        out2 = self.conv2(out1) # mid_c [1, 3, 28, 28]
        out3 = self.conv3(self.down(out2)) # mid_c [1, 3, 14, 14]
        out4 = self.conv3(self.down(out3)) # mid_c [1, 3, 7, 7]

        out8 = self.conv4(out4) # mid_c [1, 3, 7, 7]
        out9 = self.conv5(torch.cat((out4,out8),dim=1)) # mid_c [1, 3, 7, 7]

        # 上采样
        out10 = self.conv5(self.up(out9,out3)) # [1, 3, 14, 14]
        out11 = self.conv6(self.up(out10,out2)) # [1, 3, 28, 28]
        out = out11 + out1 # [1, 5, 28, 28]
        return out
    
class UNet_Block5(nn.Module):
    def __init__(self, in_c, mid_c, out_c) -> None:
        super().__init__()

        self.conv1 = Conv_Block(in_c,out_c)
        self.conv2 = Conv_Block(out_c,mid_c)
        self.conv3 = Conv_Block(mid_c,mid_c,dilation=2)
        self.conv4 = Conv_Block(mid_c,mid_c,dilation=4)
        self.conv5 = Conv_Block(mid_c,mid_c,dilation=8)
        self.conv6 = Conv_Block(mid_c*2,mid_c,dilation=4)
        self.conv7 = Conv_Block(mid_c*2,mid_c,dilation=2)
        self.conv8 = Conv_Block(mid_c*2,out_c)

    def forward(self,x):
        out1 = self.conv1(x) # out_c [1, 5, 14, 14]
        out2 = self.conv2(out1) 
        out3 = self.conv3(out2) 
        out4 = self.conv4(out3) 

        out5 = self.conv5(out4) 

        out6 = self.conv6(torch.cat((out4,out5),dim=1)) 
        out7 = self.conv7(torch.cat((out3,out6),dim=1)) 
        out8 = self.conv8(torch.cat((out2,out7),dim=1)) 
        out = out8 + out1
        return out
    
class U2NET(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.down = Down_Sample_Block()
        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.up2 = nn.UpsamplingNearest2d(scale_factor=4)
        self.up3 = nn.UpsamplingNearest2d(scale_factor=8)
        self.up4 = nn.UpsamplingNearest2d(scale_factor=16)
        self.up5 = nn.UpsamplingNearest2d(scale_factor=32)

        self.unet1 = UNet_Block1(1,32,64) 
        self.unet2 = UNet_Block2(64,32,128) 
        self.unet3 = UNet_Block3(128,64,256) 
        self.unet4 = UNet_Block4(256,128,512) 
        self.unet5 = UNet_Block5(512,256,512) 
        
        self.unet6 = UNet_Block5(512,256,512) 

        self.de_unet1 = UNet_Block1(128,16,64) 
        self.de_unet2 = UNet_Block2(256,32,64) 
        self.de_unet3 = UNet_Block3(512,64,128) 
        self.de_unet4 = UNet_Block4(1024,128,256) 
        self.de_unet5 = UNet_Block5(1024,256,512) 

        self.out1 = Output(64,1)
        self.out2 = Output(64,1)
        self.out3 = Output(128,1)
        self.out4 = Output(256,1)
        self.out5 = Output(512,1)

    def forward(self, x):
        # 下采样,编码
        conv1 = self.unet1(x)
        en1 = self.down(conv1)
        conv2 = self.unet2(en1)
        en2 = self.down(conv2)
        conv3 = self.unet3(en2)
        en3 = self.down(conv3)
        conv4 = self.unet4(en3)
        en4 = self.down(conv4)
        conv5 = self.unet5(en4)
        en5 = self.down(conv5)

        conv6 = self.unet6(en5)

        # 上采样,解码
        de1 = self.up1(conv6) # [1, 512, 14, 14]
        conv7 = self.de_unet5(torch.cat((conv5,de1),dim=1)) # [1, 512, 14, 14]
        de2 = self.up1(conv7) # [1, 512, 28, 28]
        conv8 = self.de_unet4(torch.cat((conv4,de2),dim=1)) # [1, 256, 28, 28]
        de3 = self.up1(conv8) # [1, 256, 56, 56]
        conv9 = self.de_unet3(torch.cat((conv3,de3),dim=1)) # [1, 128, 56, 56]
        de4 = self.up1(conv9) # [1, 128, 112, 112]
        conv10 = self.de_unet2(torch.cat((conv2,de4),dim=1)) # [1, 64, 112, 112]
        de5 = self.up1(conv10) # [1, 64, 224, 224]

        # 输出
        out1 = self.up5(self.out5(conv6)) # [1, 1, 224, 224]
        out2 = self.up4(self.out5(conv7)) # [1, 1, 224, 224]
        out3 = self.up3(self.out4(conv8)) # [1, 1, 224, 224]
        out4 = self.up2(self.out3(conv9)) # [1, 1, 224, 224]
        out5 = self.up1(self.out2(conv10)) # [1, 1, 224, 224]
        out6 = self.out1(de5) # [1, 1, 224, 224]

        out = (out1 + out2 + out3 + out4 + out5 + out6)/6

        return out
    
if __name__ == "__main__":
    x = torch.randn((1,1,224,224))
    conv = U2NET()
    y = conv(x)
    print(y.shape)


三、数据集和实验结果

(1)数据集:
(2)训练代码:
  • train2.py
import PIL.Image as Image
import numpy as np 
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import os 
from model2 import U2NET
from torchvision.utils import save_image

IMG_PATH = "./data/FundusVessels/JPEGImages/"
TARGET_PATH = "./data/FundusVessels/Annotations/"
DST_DIR = "./img2"

class EYE_Dataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.img_list =  os.listdir(IMG_PATH)
        self.target_list = os.listdir(TARGET_PATH)

    def __len__(self) -> int:
        return len(self.img_list)
    
    def __getitem__(self, index):
        name = self.img_list[index][0:-4]
        img = Image.open(IMG_PATH+f"{name}.jpg").convert("L").resize((224,224))
        lable = Image.open(TARGET_PATH+f"{name}.png").resize((224,224))
        img = np.array(img, dtype=np.float32)/255
        img = img[np.newaxis,:]
        lable = np.array(lable, dtype=np.float32)
        return img, lable
    


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"
    train_dataset = EYE_Dataset()
    train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)

    net = U2NET()
    if os.path.exists("./U2Net.pt"):
        params = torch.load("./U2Net.pt")
        net.load_state_dict(params)

    net.to(device)
    optim = torch.optim.Adam(net.parameters())
    loss_fn = nn.BCELoss()
    epoch=1
    net.train()
    while True:
        for i,(img,target) in enumerate(train_loader):
            img, target = img.to(device), target.to(device)
            y = net(img)
            loss = loss_fn(y, target.unsqueeze(dim=0))
            optim.zero_grad()
            loss.backward()
            optim.step()

            if i%1 == 0:
                # 保存测试结果
                img2 =  (y[0]>0.6).float() *255
                res = torch.stack([img[0],img2],dim=0)
                save_image(res.cpu(), DST_DIR + f"/epoch{epoch}_{i}.jpg", nrow=2)
                print(f"epoch {epoch},loss: {loss.item()}")
                torch.save(net.state_dict(),"./U2Net.pt")
            epoch += 1
(3)训练结果

大约训练了1500轮,多训练一会儿,效果还可以更好。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,236评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,867评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,715评论 0 340
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,899评论 1 278
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,895评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,733评论 1 283
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,085评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,722评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,025评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,696评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,816评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,447评论 4 322
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,057评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,009评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,254评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,204评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,561评论 2 343

推荐阅读更多精彩内容