肾小球的病理图像分割

项目目标

在不同的组织制备管道中分割人类肾脏组织图像中的肾小球区域。肾小球是一种功能组织单位(FTU):以毛细血管为中心的三维细胞块,因此该块中的每个细胞与同一块中的任何其他细胞都在扩散距离之内。

项目数据

提供的数据包括11张新鲜冷冻和9张福尔马林固定石蜡包埋(FFPE)PAS肾脏图像:8 张用于训练,5+7张用于测试。每个都有大约50k像素大小,并保存为高分辨率tiff图像。为了使如此大的图像适合神经网络的训练,必须将它们切成小块。根据检测到的目标大小,此数据的适当图块大小应为 1024*1024。对此使用分辨率低4倍的256*256瓦片(tiles),可以在最终设置上运行更高分辨率的瓦片。瓦片数(8211+1893)

数据处理办法

重叠裁剪

Overlap-tile策略搭配patch(图像分块)一起使用。当内存资源有限从而无法对整张大图进行预测时,可以对图像先进行镜像padding,然后按序将padding后的图像分割成固定大小的patch。这样,能够实现对任意大的图像进行无缝分割,同时每个图像块也获得了相应的上下文信息。另外,在数据量较少的情况下,每张图像都被分割成多个patch,相当于起到了扩充数据量的作用。更重要的是,这种策略不需要对原图进行缩放,每个位置的像素值与原图保持一致,不会因为缩放而带来误差。overlap-tile策略的思想是:对图像的某一块像素点(黄框内部分)进行预测时,需要该图像块周围的像素点(蓝色框内)提供上下文信息(context),以获得更准确的预测

def make_grid(shape, window=256, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represente slices: x1,x2,y1,y2 
    """
    x, y = shape
    nx = x // (window - min_overlap) + 1
    x1 = np.linspace(0, x, num=nx, endpoint=False, dtype=np.int64)
    x1[-1] = x - window
    x2 = (x1 + window).clip(0, x)
    ny = y // (window - min_overlap) + 1
    y1 = np.linspace(0, y, num=ny, endpoint=False, dtype=np.int64)
    y1[-1] = y - window
    y2 = (y1 + window).clip(0, y)
    slices = np.zeros((nx,ny, 4), dtype=np.int64)
    
    for i in range(nx):
        for j in range(ny):
            slices[i,j] = x1[i], x2[i], y1[j], y2[j]    
    return slices.reshape(nx*ny,4)

数据增强策略

本项目用到的操作包括模糊图像、中心模糊、高斯噪声、色调饱和度值、对比度受限自适应直方图均衡、随机亮度对比度等,以及常用的翻转、旋转、仿射变换。在训练集上只使用旋转、翻转变换。

def get_aug(p=1.0):
    return Compose([
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                         border_mode=cv2.BORDER_REFLECT),
        OneOf([
            ElasticTransform(p=.3),
            GaussianBlur(p=.3),
            GaussNoise(p=.3),
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            # IAAPiecewiseAffine(p=0.3),
        ], p=0.3),
        OneOf([
            HueSaturationValue(15,25,0),
            CLAHE(clip_limit=2),
            RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.3),
    ], p=p)

badcase分析显示有些更暗更小的东西与正常肾小球不相似。它们在切片的边界上分布得更密集,并且在这些结构中似乎有更少的细胞核,是纤维状新月形肾小球。通过调整数据增强策略可以有一定的改善。

Dataset类

mean = np.array([0.7720342, 0.74582646, 0.76392896])
std = np.array([0.24745085, 0.26182273, 0.25782376])

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    def __init__(self, path, fold=0, train=True, tfms=None, seed=2020, nfolds= 4, include_pl=False):
        self.path=path
       
        if include_pl:
            ids = np.concatenate([pd.read_csv(os.path.join(self.path,'train.csv')).id.values,
                     pd.read_csv(os.path.join(self.path,'sample_submission.csv')).id.values])
        else:
            ids = pd.read_csv(os.path.join(self.path,'train.csv')).id.values      
        kf = KFold(n_splits=nfolds,random_state=seed,shuffle=True)
        ids = set(ids[list(kf.split(ids))[fold][0 if train else 1]])
        print(f"number of {'train' if train else 'val'} images is {len(ids)}")
        
        self.fnames = ['train/'+fname for fname in os.listdir(os.path.join(self.path,'train')) if int(fname.split('_')[0]) in ids]
        # +['test/'+fname for fname in os.listdir(os.path.join(self.path,'test')) if fname.split('_')[0] in ids]

        self.train = train
        self.tfms = tfms

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.cvtColor(cv2.imread(os.path.join(self.path,fname)), cv2.COLOR_BGR2RGB)

        if self.fnames[idx][:5]=='train':
            mask = cv2.imread(os.path.join(self.path,'masks',fname[6:]),cv2.IMREAD_GRAYSCALE)
        else:
            mask = cv2.imread(os.path.join(self.path,'test_masks',fname[5:]),cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']

        data={'img':img2tensor((img/255.0 - mean)/std), 'mask':img2tensor(mask)}
        return data

模型设计

使用的模型基于一个 U 形网络(UneXt50,见下图)。 Unet 架构:编码器部分创建不同级别的特征表示,而解码器将特征组合并生成预测作为分割掩码。编码器和解码器之间的跳过连接允许有效地利用编码器中间卷积层的特征,而无需信息通过整个编码器和解码器。后者对于将预测掩码链接到检测对象的特定像素特别重要。后来人们意识到 ImageNet 预训练的计算机视觉模型可以显着提高分割模型的质量,因为编码器的架构经过优化,编码器容量高(与原始 Unet 中使用的编码器相比),以及具有迁移学习的强大功能。

原始Unet
本项目设计

使用半监督 Imagenet 预训练的 ResNeXt50 模型作为主干。 在 Pytorch 中,它提供了 EfficientNet B2-B3 的性能,在计算成本上具有更快的收敛速度,以及EfficientNet B0 的 GPU RAM 要求。

对 ResNet 有效性的解释主要有三种:

  • 使网络更容易在某些层学到恒等变换(identity mapping)。在某些层执行恒等变换是一种构造性解,使更深的模型的性能至少不低于较浅的模型。这也是作者原始论文指出的动机。(ResNet解决了深网络的梯度问题,自然能学习到更多抽象特征,所以效果好还是因为够深。)
    [1512.03385] Deep Residual Learning for Image Recognition
  • 残差网络是很多浅层网络的集成(ensemble),层数的指数级那么多。主要的实验证据是:把 ResNet 中的某些层直接删掉,模型的性能几乎不下降。
    [1605.06431] Residual Networks Behave Like Ensembles of Relatively Shallow Networks
  • 残差网络使信息更容易在各层之间流动,包括在前向传播时提供特征重用在反向传播时缓解梯度信号消失

ResNeXt 同时采用 VGG 堆叠的思想Inception 的 split-transform-merge 思想。ResNeXt 提出的主要原因在于:传统的要提高模型的准确率,都是加深或加宽网络,但是随着超参数数量的增加(比如channels数,filter size等等),网络设计的难度和计算开销也会增加。因此ResNeXt 结构可以在不增加参数复杂度的前提下提高准确率,同时还减少了超参数的数量
一般增强一个CNN的表达能力有三种手段:一是增加网络层次即加深网络二是增加网络模块宽度三是改善CNN网络结构设计)。ResNeXt的做法可归为上面三种方法的第三种。它引入了新的用于构建CNN网络的模块,提出了一个cardinatity的概念,用于作为模型复杂度的另外一个度量。Cardinatity指的是一个block中所具有的相同分支的数目。作者进行了一系列对比实验,有力证明在保证相似计算复杂度及模型参数大小的前提下,提升cardinatity比提升height或width可取得更好的模型表达能力。下面三种ResNeXt网络模块的变形。它们在数学计算上是完全等价的,而第三种包含有Group convolution操作的正是最终ResNeXt网络所采用的操作。

ResNeXt的分类效果为什么比Resnet好?
ResNeXt的精妙之处在于,该思路沿用到nlp里就有了multi-head attention。
第一,ResNext中引入cardinality,实际上仍然还是一个Group的概念。不同的组之间实际上是不同的subspace,而他们的确能学到更diverse的表示。
第二,这种分组的操作或许能起到网络正则化的作用。实际上,增加一个cardinality维度之后,会使得卷积核学到的关系更加稀疏。同时在整体的复杂度不变的情况下,其中Network-in-Neuron的思想,会大大降低了每个sub-network的复杂度,那么其过拟合的风险相比于ResNet也将会大大降低。

class UneXt(nn.Module):
    def __init__(self, m, stride=1, **kwargs):
        super().__init__()
        #encoder
        # m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models',
        #                    'resnext101_32x4d_swsl')
#         m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models',
#                            'resnext50_32x4d_swsl', pretrained=False)
        #m = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=4)
        #m = torchvision.models.resnext50_32x4d(pretrained=False)
        # m = torch.hub.load(
        #     'moskomule/senet.pytorch',
        #     'se_resnet101',
        #     pretrained=True,)

        #m=torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        self.enc0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
        self.enc1 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
                            m.layer1) #256
        self.enc2 = m.layer2 #512
        self.enc3 = m.layer3 #1024
        self.enc4 = m.layer4 #2048
        #aspp with customized dilatations
        self.aspp = ASPP(2048,256,out_c=512,dilations=[stride*1,stride*2,stride*3,stride*4])
        self.drop_aspp = nn.Dropout2d(0.5)
        #decoder
        self.dec4 = UnetBlock(512,1024,256)
        self.dec3 = UnetBlock(256,512,128)
        self.dec2 = UnetBlock(128,256,64)
        self.dec1 = UnetBlock(64,64,32)
        self.fpn = FPN([512,256,128,64],[16]*4)
        self.drop = nn.Dropout2d(0.1)
        self.final_conv = ConvLayer(32+16*4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.aspp(enc4)
        dec3 = self.dec4(self.drop_aspp(enc5),enc3)
        dec2 = self.dec3(dec3,enc2)
        dec1 = self.dec2(dec2,enc1)
        dec0 = self.dec1(dec1,enc0)
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        x = F.interpolate(x,scale_factor=2,mode='bilinear')
        return x

class UnetBlock(nn.Module):
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        ni = up_in_c//2 + x_in_c
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in:Tensor, left_in:Tensor) -> Tensor:
        s = left_in
        up_out = self.shuf(up_in)
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))

也尝试使用了Efficientnet作为encoder构建Unet网络

pretrained_root = '/home/ruanshijian/hubmap/'
efficient_net_encoders = {
    "efficientnet-b0": {
        "out_channels": (3, 32, 24, 40, 112, 320),
        "stage_idxs": (3, 5, 9, 16),
        "weight_path": pretrained_root + "efficientnet-b0-08094119.pth"
    },
    "efficientnet-b1": {
        "out_channels": (3, 32, 24, 40, 112, 320),
        "stage_idxs": (5, 8, 16, 23),
        "weight_path": pretrained_root + "efficientnet-b1-dbc7070a.pth"
    },
    "efficientnet-b2": {
        "out_channels": (3, 32, 24, 48, 120, 352),
        "stage_idxs": (5, 8, 16, 23),
        "weight_path": pretrained_root + "efficientnet-b2-27687264.pth"
    },
    "efficientnet-b3": {
        "out_channels": (3, 40, 32, 48, 136, 384),
        "stage_idxs": (5, 8, 18, 26),
        "weight_path": pretrained_root + "efficientnet-b3-c8376fa2.pth"
    },
    "efficientnet-b4": {
        "out_channels": (3, 48, 32, 56, 160, 448),
        "stage_idxs": (6, 10, 22, 32),
        "weight_path": pretrained_root + "efficientnet-b4-e116e8b3.pth"
    },
    "efficientnet-b5": {
        "out_channels": (3, 48, 40, 64, 176, 512),
        "stage_idxs": (8, 13, 27, 39),
        "weight_path": pretrained_root + "efficientnet-b5-586e6cc6.pth"
    },
    "efficientnet-b6": {
        "out_channels": (3, 56, 40, 72, 200, 576),
        "stage_idxs": (9, 15, 31, 45),
        "weight_path": pretrained_root + "efficientnet-b6-c76e70fd.pth"
    },
    "efficientnet-b7": {
        "out_channels": (3, 64, 48, 80, 224, 640),
        "stage_idxs": (11, 18, 38, 55),
        "weight_path": pretrained_root + "efficientnet-b7-dcc49843.pth"
    }
}

import sys
sys.path.insert(0, '/home/ruanshijian/hubmap/EfficientNet-PyTorch')
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import get_model_params

class EfficientNetEncoder(EfficientNet):
    def __init__(self, stage_idxs, out_channels, model_name, depth=5):

        blocks_args, global_params = get_model_params(model_name, override_params=None)
        super().__init__(blocks_args, global_params)

        cfg = efficient_net_encoders[model_name]

        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        del self._fc
        self.load_state_dict(torch.load(cfg['weight_path']))

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self._conv_stem, self._bn0, self._swish),
            self._blocks[:self._stage_idxs[0]],
            self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
            self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
            self._blocks[self._stage_idxs[2]:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        block_number = 0.
        drop_connect_rate = self._global_params.drop_connect_rate

        features = []
        for i in range(self._depth + 1):

            # Identity and Sequential stages
            if i < 2:
                x = stages[i](x)

            # Block stages need drop_connect rate
            else:
                for module in stages[i]:
                    drop_connect = drop_connect_rate * block_number / len(self._blocks)
                    block_number += 1.
                    x = module(x, drop_connect)

            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("_fc.bias")
        state_dict.pop("_fc.weight")
        super().load_state_dict(state_dict, **kwargs)


class EffUnet(nn.Module):
    def __init__(self, model_name, stride=1):
        super().__init__()

        cfg = efficient_net_encoders[model_name]
        stage_idxs = cfg['stage_idxs']
        out_channels = cfg['out_channels']

        self.encoder = EfficientNetEncoder(stage_idxs, out_channels, model_name)

        # aspp with customized dilatations
        self.aspp = ASPP(out_channels[-1], 256, out_c=384,
                         dilations=[stride * 1, stride * 2, stride * 3, stride * 4])
        self.drop_aspp = nn.Dropout2d(0.5)
        # decoder
        self.dec4 = UnetBlock(384, out_channels[-2], 256)
        self.dec3 = UnetBlock(256, out_channels[-3], 128)
        self.dec2 = UnetBlock(128, out_channels[-4], 64)
        self.dec1 = UnetBlock(64, out_channels[-5], 32)
        self.fpn = FPN([384, 256, 128, 64], [16] * 4)
        self.drop = nn.Dropout2d(0.1)
        self.final_conv = ConvLayer(32 + 16 * 4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):
        enc0, enc1, enc2, enc3, enc4 = self.encoder(x)[-5:]
        enc5 = self.aspp(enc4)
        dec3 = self.dec4(self.drop_aspp(enc5), enc3)
        dec2 = self.dec3(dec3, enc2)
        dec1 = self.dec2(dec2, enc1)
        dec0 = self.dec1(dec1, enc0)
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        x = F.interpolate(x, scale_factor=2, mode='bilinear')
        return x

PixelShuffle是一种上采样方法,可以对缩小后的特征图进行有效的放大。可以替代插值或解卷积的方法实现upscale。pixelshuffle算法的实现流程如图,其实现的功能是:将一个H × W的低分辨率输入图像(Low Resolution),通过Sub-pixel操作将其变为rH*rW的高分辨率图像(High Resolution)。但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是通过卷积先得到r^2个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffing)的方法得到这个高分辨率的图像,其中r为上采样因子(upscaling factor),也就是图像的扩大倍率。

简单一句话,PixelShuffle层做的事情就是将输入feature map像素重组输出高分辨率的feature map,是一种上采样方法,具体表达为:N*(C*r*r)*W*H---->>N*C*(H*r)*(W*r)

  1. upsample是利用传统插值方法进行上采样。往往会在upsample后接一个conv,进行学习。任务:超分,目标检测。
  2. 转置卷积应该是上采样力度最大的,所以有些时候的结果看起来会不太真实。任务:GAN,分割,超分。
  3. pixel shuffle最开始也是用在超分上的,把channel通道放大r^2倍,然后再分给H,W成rH,rW,达到上采样的效果。目前超分用这个应该是主流。任务:超分。

此外,在ASPP模块中还加入了OC注意力模块

class BaseOC_Module(nn.Module):
    """
    Implementation of the BaseOC module
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: we choose 0.05 as the default value.
        size: you can apply multiple sizes. Here we only use one size.
    Return:
        features fused with Object context information.
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Module, self).__init__()
        self.stages = []
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output

class BaseOC_Context_Module(nn.Module):
    """
    Output only the context features.
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: specify the dropout ratio
        fusion: We provide two different fusion method, "concat" or "add"
        size: we find that directly learn the attention weights on even 1/8 feature maps is hard.
    Return:
        features after "concat" or "add"
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Context_Module, self).__init__()
        self.stages = []
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(context)
        return output

添加了特征金字塔网络(FPN):解码器的不同上采样块和输出层之间的附加跳过连接。因此,最终预测是基于 U-net 输出与中间层调整大小的输出串联接产生的。这些跳跃连接为梯度传导提供了捷径以提高模型性能和收敛速度。由于中间层有许多通道,它们的上采样和用作最后一层的输入会在计算时间和内存方面引入大量开销。因此,在调整大小之前应用 3*3+3*3 卷积(分解)以减少通道数。浅层的网络更关注于细节信息,高层的网络更关注于语义信息,而高层的语义信息能够帮助我们准确的检测出目标,设计思想就是同时利用低层特征和高层特征,分别在不同的层同时进行预测,这是因为一幅图像中可能具有多个不同大小的目标,区分不同的目标可能需要不同的特征,对于简单的目标仅仅需要浅层的特征就可以检测到它,对于复杂的目标就需要利用复杂的特征来检测它。整个过程就是首先在原始图像上面进行深度卷积,然后分别在不同的特征层上面进行预测。它的优点是在不同的层上面输出对应的目标,不需要经过所有的层才输出对应的目标(即对于有些目标来说,不需要进行多余的前向操作),这样可以在一定程度上对网络进行加速操作,同时可以提高算法的检测性能。它的缺点是获得的特征不鲁棒,都是一些弱特征(因为很多的特征都是从较浅的层获得的)。

class FPN(nn.Module):
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch*2, kernel_size=3, padding=1),
             nn.ReLU(inplace=True), nn.BatchNorm2d(out_ch*2),
             nn.Conv2d(out_ch*2, out_ch, kernel_size=3, padding=1))
            for in_ch, out_ch in zip(input_channels, output_channels)])

    def forward(self, xs:list, last_layer):
        hcs = [F.interpolate(c(x),scale_factor=2**(len(self.convs)-i),mode='bilinear')
               for i,(c,x) in enumerate(zip(self.convs, xs))]
        hcs.append(last_layer)
        return torch.cat(hcs, dim=1)

在编码器和解码器之间添加的 Atrous Spatial Pyramid Pooling (ASPP) 块。传统 U 形网络的缺陷是由一个小的感受野造成的。因此,如果模型需要对大对象的分割做出决定,特别是对于大图像分辨率,它可能会因为只能查看对象的一部分而感到困惑。增加感受野并实现图像不同部分之间交互的一种方法是使用具有不同扩张的卷积块组合(在 ASPP 块中具有不同速率的 Atrous 卷积)。虽然原始论文使用 6、12、18 速率,但它们可以针对特定任务和特定图像分辨率进行定制,以最大限度地提高性能。另外在 ASPP 块中使用分组卷积来减少模型参数的数量。

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        self.aspps = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)] + \
            [_ASPPModule(inplanes, mid_c, 3, padding=d, dilation=d,groups=4) for d in dilations]
        self.aspps = nn.ModuleList(self.aspps)
        self.global_pool = nn.Sequential(nn.AdaptiveMaxPool2d((1, 1)),
                        nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
                        nn.BatchNorm2d(mid_c), nn.ReLU())
        out_c = out_c if out_c is not None else mid_c
        self.out_conv = nn.Sequential(nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False),
                                    nn.BatchNorm2d(out_c), nn.ReLU(inplace=True))
        self.conv1 = nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        x0 = self.global_pool(x)
        xs = [aspp(x) for aspp in self.aspps]
        x0 = F.interpolate(x0, size=xs[0].size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x0] + xs, dim=1)
        return self.out_conv(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

添加卷积注意力模块(CBAM),这是一种用于前馈卷积神经网络的简单而有效的注意力模块。 给定一个中间特征图,CBAM模块会沿着两个独立的维度(通道和空间)依次推断注意力图,然后将注意力图与输入特征图相乘以进行自适应特征优化。 由于CBAM是轻量级的通用模块,因此可以忽略的该模块的开销而将其无缝集成到任何CNN架构中,并且可以与基础CNN一起进行端到端训练。

注意力不仅要告诉我们重点关注哪里,还要提高关注点的表示。 目标是通过使用注意机制来增加表现力,关注重要特征并抑制不必要的特征。为了强调空间和通道这两个维度上的有意义特征,依次应用通道和空间注意模块,来分别在通道和空间维度上学习关注什么、在哪里关注。此外,通过了解要强调或抑制的信息也有助于网络内的信息流动

CBAM 包含2个独立的子模块, 通道注意力模块(Channel Attention Module,CAM) 和空间注意力模块(Spartial Attention Module,SAM) ,分别进行通道与空间上的 Attention 。

通道注意力模块:通道维度不变,压缩空间维度。该模块关注输入图片中有意义的信息(分类任务就关注因为什么分成了不同类别)。
图解:将输入的feature map经过两个并行的MaxPool层和AvgPool层,将特征图从C*H*W变为C*1*1的大小,然后经过Share MLP模块,在该模块中,它先将通道数压缩为原来的1/r(Reduction,减少率)倍,再扩张到原通道数,经过ReLU激活函数得到两个激活后的结果。将这两个输出结果进行逐元素相加,再通过一个sigmoid激活函数得到Channel Attention的输出结果,再将这个输出结果乘原图,变回C*H*W的大小。

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, rotio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(in_planes // rotio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

空间注意力模块:空间维度不变,压缩通道维度。该模块关注的是目标的位置信息。
图解:将Channel Attention的输出结果通过最大池化和平均池化得到两个1*H*W的特征图,然后经过Concat操作对两个特征图进行拼接,通过7*7卷积变为1通道的特征图(实验证明7*7效果比3*3好),再经过一个sigmoid得到Spatial Attention的特征图,最后将输出结果乘原图变回C*H*W大小。

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3,7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)
class CBAM(nn.Module):
    def __init__(self, planes):
        super(cbam,self).__init__()
        self.ca = ChannelAttention(planes)# planes是feature map的通道个数
        self.sa = SpatialAttention()
     def forward(self, x):
        x = self.ca(x) * x  # 广播机制
        x = self.sa(x) * x  # 广播机制

损失和度量

在图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。我们可使用lovasz loss解决这个问题。
Lovasz loss基于子模损失(submodular losses)的凸Lovasz扩展,对神经网络的mean IoU损失进行优化。Lovasz loss根据分割目标的类别数量可分为两种:lovasz hinge loss和lovasz softmax loss. 其中lovasz hinge loss适用于二分类问题,lovasz softmax loss适用于多分类问题。
Jaccard index :


优化的IOU loss:
其定义是离散的loss,不能直接求导,所以无法直接用来作为loss function。为了克服这个离散的问题,本文将其做了光滑的延拓(smooth extensions),从而可以使得其作为分割网络的loss function。变形为:
目前想要优化的loss function,其自变量为网络分割结果和label不匹配的集合。将其做光滑的延拓不是一件简单的事情,更一般的说,对任意的离散函数找到其光滑的延拓很难。好在变化后的公式是submodular的,submodular的函数已经有成熟数学工具可以将其做光滑延拓,而且延拓后的函数总是凸的,这样就大大方便了优化。该数学工具即为lovasz extension
即转成具有凸解形式:

代码实现

  • 为什么用这么复杂,看起来也不简单的数学工具来对Jaccard loss进行smooth extension,直接像Dice loss那样计算Jaccard loss不行吗?

基于该想法的工作已经在16年发表了出来Optimizing Intersection-Over-Union in Deep Neural Networks for Image Segmentation,虽然本文没有与其进行比较,但作者在github中说本文对Jaccard loss光滑延拓得到的loss要比Dice loss那样简单的光滑化(连续画处理)效果好。

  • Dice loss与IOU loss哪个用于网络模型的训练比较好?

都不太好。两者都存在训练过程不稳定的问题,在和很小的情况下会得到较大的梯度,会影响正常的反向传播。一般情况下,使用两者对应的损失函数的原因是分割的真实目的是最大化这两个度量指标,而交叉熵是一种代理形式,利用了其在反向传播中易于最大化优化的特点。
所以,正常情况下是使用交叉熵损失函数来训练网络模型,用Dice或IOU系数来衡量模型的性能。因为,交叉熵损失函数得到的交叉熵值关于logits的梯度计算形式类似:p-g(p是softmax的输出结果,g是ground truth),这样的关系式自然在求梯度的时候容易的多。而Dice系数的可微形式,loss值为2pg/(p^2 + g^2)或2pg/(p+g),其关于p的梯度形式显然是比较复杂的,且在极端情况下(p,g的值都非常小时)计算得到的梯度值可能会非常大,进而会导致训练不稳定。

在本项目中采用了对称的lovasz损失,不仅考虑预测的分割和提供的掩码,还要考虑逆预测和逆掩码(否定情况的预测掩膜)。

def symmetric_lovasz(outputs, targets):
    return 0.5*(lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))

lovasz对分割的效果出类拔萃相比bce或者dice等loss可以提升一个档次,但是有时的效果一般,猜测是优化不同的metric,不同loss带来的效果不同,也可能是数据带来的问题。

模型推理

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    def __init__(self, data):
        self.data = data
        if self.data.count != 3:
            subdatasets = self.data.subdatasets
            self.layers = []
            if len(subdatasets) > 0:
                for i, subdataset in enumerate(subdatasets, 0):
                    self.layers.append(rasterio.open(subdataset))
        self.shape = self.data.shape
        self.mask_grid = make_grid(self.data.shape, window=WINDOW, min_overlap=MIN_OVERLAP)

        
    def __len__(self):
        return len(self.mask_grid)
        
    def __getitem__(self, idx):
        x1, x2, y1, y2 = self.mask_grid[idx]
        if self.data.count == 3:
            img = data.read([1,2,3], window=Window.from_slices((x1, x2), (y1, y2)))
            img = np.moveaxis(img, 0, -1)
        else:
            img = np.zeros((WINDOW, WINDOW, 3), dtype=np.uint8)
            for i, layer in enumerate(self.layers):
                img[:,:,i] = layer.read(window=Window.from_slices((x1, x2),(y1, y2)))

        img = cv2.resize(img, (NEW_SIZE, NEW_SIZE),interpolation = cv2.INTER_AREA)
        vetices = torch.tensor([x1, x2, y1, y2])
        return img2tensor((img/255.0 - mean)/std), vetices
def Make_prediction(img, tta = True):
    pred = None
    with torch.no_grad():
        for model in models:
            p_tta = None
            p = model(img)
            p = torch.sigmoid(p).detach()
            if p_tta is None:
                p_tta = p
            else:
                p_tta += p
            if tta:
                #x,y,xy flips as TTA
                flips = [[-1],[-2],[-2,-1]]
                for f in flips:
                    imgf = torch.flip(img, f)
                    p = model(imgf)
                    p = torch.flip(p, f)
                    p_tta += torch.sigmoid(p).detach()
                p_tta /= (1+len(flips))
            if pred is None:
                pred = p_tta
            else:
                pred += p_tta
        pred /= len(models)
    return pred
WINDOW=1024
MIN_OVERLAP=300
NEW_SIZE=256
NUM_CLASSES=1
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
names, predictions = [],[]

df_sample = pd.read_csv("../input/hubmap-kidney-segmentation/sample.csv")
# df_sample = df_sample.replace(np.nan, '', regex=True)
th = 0.4   
for idx, row in tqdm(df_sample.iterrows(),total=len(df_sample)):
    imageId = row['id']
    data = rasterio.open(os.path.join(DATA_PATH, imageId+'.tiff'), transform = identity, num_threads='all_cpus')
    preds = np.zeros(data.shape, dtype=np.uint8)
    dataset = HuBMAPDataset(data)
    dataloader = DataLoader(dataset, batch_size, num_workers=0, shuffle=False, pin_memory=True)
    for i, (img, vertices) in enumerate(dataloader):
        img = img.to(DEVICE)
        pred = Make_prediction(img)
        pred = pred.squeeze().cpu().numpy()
        vertices = vertices.numpy()
        for p, vert in zip(pred, vertices):
            x1, x2, y1, y2 = vert
            p = cv2.resize(p, (WINDOW, WINDOW))
            preds[x1:x2,y1:y2] +=  (p > th).astype(np.uint8)
    preds = (preds > th).astype(np.uint8)
    #convert to rle
    rle = rle_encode_less_memory(preds)
    names.append(imageId)
    predictions.append(rle)
    del preds, dataset, dataloader
    gc.collect()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,732评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,496评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,264评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,807评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,806评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,675评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,029评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,683评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 41,704评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,666评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,773评论 1 332
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,413评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,016评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,978评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,204评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,083评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,503评论 2 343

推荐阅读更多精彩内容