示例数据为 Fetal-Head-Circumference (胎儿头围) 数据集
下载地址: https://zenodo.org/record/1322001#.YTHD2Y4zaUl
模型结构: 类 U-Net 的编码器-解码器 (注意: 下文实现的 SegNet 不含 U-Net 的跳跃连接)
扩展阅读:https://github.com/pranjalrai-iitd/Fetal-head-segmentation-and-circumference-measurement-from-ultrasound-images
引入包
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.pylab as plab
from PIL import Image, ImageDraw
import numpy as np
import pandas as pd
import os
import copy
import collections
from sklearn.model_selection import ShuffleSplit
from scipy import ndimage as ndi
from skimage.segmentation import mark_boundaries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torchvision.transforms as transforms
from torchvision import models,utils, datasets
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from albumentations import (HorizontalFlip, VerticalFlip, Compose, Resize,)
from torchsummary import summary
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Worker-process count meant for DataLoader's num_workers: 0 on Windows
# (os.name == 'nt'), 4 elsewhere.
# NOTE(review): no DataLoader in this file currently passes it.
workers = 4 if os.name != 'nt' else 0
数据初探
# Folder holding the training images and their "_Annotation" companion files.
path_train = "./data/sos/training_set/"
# Separate the PNG files into raw images and annotation images.
imgs_list = [name for name in os.listdir(path_train)
             if name.endswith('.png') and "Annotation" not in name]
annts_list = [name for name in os.listdir(path_train)
              if name.endswith('.png') and "Annotation" in name]
print("number of images:", len(imgs_list))
print("number of annotations:", len(annts_list))
"""
number of images: 999
number of annotations: 999
"""
# Draw four image names at random (with replacement, seeded) for a visual check.
np.random.seed(2019)
rnd_imgs = np.random.choice(imgs_list, 4)
print('The random images are: ', rnd_imgs)
# The random images are: ['166_2HC.png' '434_HC.png' '244_HC.png' '826_3HC.png']
# Visualisation helper.
def show_img_mask(img, mask):
    """Show *img* with the boundary of *mask* drawn over it in green.

    If *img* is a torch tensor, both *img* and *mask* are first converted with
    ``to_pil_image``; otherwise both are used as given. They are turned into
    NumPy arrays, passed to ``mark_boundaries``, and the result is drawn with
    ``plt.imshow`` on the current axes.
    """
    if torch.is_tensor(img):
        img, mask = to_pil_image(img), to_pil_image(mask)
    green = (0, 1, 0)
    overlay = mark_boundaries(np.array(img), np.array(mask),
                              color=green, outline_color=green)
    plt.imshow(overlay)
# Preview each randomly chosen training image: raw image | filled mask | boundary overlay.
for fn in rnd_imgs:
    img_path = os.path.join(path_train, fn)
    annt_path = img_path.replace(".png", "_Annotation.png")
    img, annt_edges = Image.open(img_path), Image.open(annt_path)
    # annt_edges is treated as an outline; binary_fill_holes turns it into a filled mask.
    mask = ndi.binary_fill_holes(annt_edges)
    plt.figure(figsize=(10, 10))
    for panel, picture in enumerate((img, mask), start=1):
        plt.subplot(1, 3, panel)
        plt.imshow(picture, cmap="gray")
    plt.subplot(1, 3, 3)
    show_img_mask(img, mask)
构建Dataset,Transforms,DataLoader
# Augmentation pipelines. Every image/mask pair is resized to (height h, width w).
h = 128
w = 192
# Training: resize, then independent random horizontal and vertical flips (p=0.5 each).
transform_train = Compose([
    Resize(height=h, width=w),
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
])
# Validation: deterministic resize only.
transform_val = Resize(height=h, width=w)
# Dataset definition.
class FetalDataset(Dataset):
    """Ultrasound images paired with hole-filled annotation masks.

    Every ``*.png`` file in *path_data* whose name does not contain
    "Annotation" is an image. Its annotation is read from the same path with
    ``.png`` replaced by ``_Annotation.png``.

    Args:
        path_data: directory containing the images and their annotation files.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, path_data, transform=None):
        # Fix: list *path_data* itself. The directory was previously taken
        # from the global ``path_train`` while paths were joined onto
        # path_data. Sorting makes the index -> file mapping independent of
        # the filesystem's listing order, so a seeded train/val split is
        # reproducible.
        imgs_list = sorted(
            fn for fn in os.listdir(path_data)
            if fn.endswith('.png') and "Annotation" not in fn
        )
        self.path_imgs = [os.path.join(path_data, fn) for fn in imgs_list]
        # Derive the annotation name from the file stem only, so a ".png"
        # elsewhere in the directory path is never rewritten.
        self.path_annts = [os.path.splitext(p)[0] + '_Annotation.png'
                           for p in self.path_imgs]
        self.transform = transform

    def __len__(self):
        return len(self.path_imgs)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for sample *idx*.

        The annotation is hole-filled with ``binary_fill_holes`` and cast to
        uint8 (0/1). After the optional transform, both arrays go through
        ``to_tensor``. That call scales uint8 input by 1/255, so the mask is
        multiplied back by 255 to hold 0/1 values.
        """
        # Context managers close the files instead of leaving them to the GC.
        with Image.open(self.path_imgs[idx]) as img_file:
            image = np.array(img_file)
        with Image.open(self.path_annts[idx]) as annt_file:
            mask = ndi.binary_fill_holes(np.asarray(annt_file))
        mask = mask.astype('uint8')
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        image = to_tensor(image)
        mask = 255 * to_tensor(mask)
        return image, mask
# Two views of the same training folder: augmented for training,
# resize-only for validation.
fetal_train_ds = FetalDataset(path_train, transform=transform_train)
fetal_val_ds = FetalDataset(path_train, transform=transform_val)
# Hold out 20% of the indices (seeded). With n_splits=1 the splitter yields
# exactly one (train, val) pair.
sss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices = range(len(fetal_train_ds))
train_index, val_index = next(sss.split(indices))
train_ds = Subset(fetal_train_ds, train_index)
print(len(train_ds))
val_ds = Subset(fetal_val_ds, val_index)
print(len(val_ds))
# Preview the first training sample with its mask boundary.
plt.figure(figsize=(5,5))
img, mask = train_ds[0]
show_img_mask(img, mask)
# Mini-batch loaders: shuffled batches of 8 for training, fixed-order batches
# of 16 for validation.
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False)
# Pull a single training batch to check tensor shapes and dtypes.
img, mask = next(iter(train_dl))
print(img.shape, img.dtype)
# torch.Size([8, 1, 128, 192]) torch.float32
print(mask.shape, mask.dtype)
# torch.Size([8, 1, 128, 192]) torch.float32
"""
799
200
torch.Size([8, 1, 128, 192]) torch.float32
torch.Size([8, 1, 128, 192]) torch.float32
"""
模型定义
# Encoder-decoder segmentation model. Although labelled "U-Net", it has no
# skip connections: each decoder stage sees only the upsampled output of the
# stage before it.
class SegNet(nn.Module):
    """Plain encoder-decoder producing per-pixel logits.

    ``params`` keys:
        input_shape: (C, H, W). Only C is used; H and W are unpacked but ignored.
        initial_filters: width f of the first conv. The encoder widens to
            f, 2f, 4f, 8f, 16f and the decoder narrows back to f.
        num_outputs: number of output channels (raw logits, no activation).

    There are four 2x max-pools and four 2x bilinear upsamples. The output
    therefore has the input's spatial size only when H and W are divisible
    by 16.
    """

    def __init__(self, params):
        super().__init__()
        in_ch, _height, _width = params['input_shape']
        f = params['initial_filters']
        n_out = params['num_outputs']

        def conv3x3(c_in, c_out):
            # 3x3 convolution, stride 1, padding 1: spatial size is preserved.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Encoder convolutions (channel width doubles each stage).
        self.conv1 = conv3x3(in_ch, f)
        self.conv2 = conv3x3(f, 2 * f)
        self.conv3 = conv3x3(2 * f, 4 * f)
        self.conv4 = conv3x3(4 * f, 8 * f)
        self.conv5 = conv3x3(8 * f, 16 * f)
        # Shared 2x upsampler plus decoder convolutions (width halves each stage).
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up1 = conv3x3(16 * f, 8 * f)
        self.conv_up2 = conv3x3(8 * f, 4 * f)
        self.conv_up3 = conv3x3(4 * f, 2 * f)
        self.conv_up4 = conv3x3(2 * f, f)
        self.conv_out = conv3x3(f, n_out)

    def forward(self, x):
        # Encoder: conv + ReLU, then halve the resolution.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Bottleneck.
        x = F.relu(self.conv5(x))
        # Decoder: double the resolution, then conv + ReLU.
        for conv in (self.conv_up1, self.conv_up2, self.conv_up3, self.conv_up4):
            x = F.relu(conv(self.upsample(x)))
        return self.conv_out(x)
# Model configuration: single-channel 128x192 input, 16 base filters, one logit map.
params_model = dict(
    input_shape=(1, 128, 192),
    initial_filters=16,
    num_outputs=1,
)
# Build the network and move it to the selected device.
model = SegNet(params_model).to(device)
# print(model)
# """
# SegNet(
# (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (upsample): Upsample(scale_factor=2.0, mode=bilinear)
# (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
# """
# Print per-layer output shapes and parameter counts for a 1x128x192 input.
summary(model, (1, 128, 192))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Conv2d-1 [-1, 16, 128, 192] 160
# Conv2d-2 [-1, 32, 64, 96] 4,640
# Conv2d-3 [-1, 64, 32, 48] 18,496
# Conv2d-4 [-1, 128, 16, 24] 73,856
# Conv2d-5 [-1, 256, 8, 12] 295,168
# Upsample-6 [-1, 256, 16, 24] 0
# Conv2d-7 [-1, 128, 16, 24] 295,040
# Upsample-8 [-1, 128, 32, 48] 0
# Conv2d-9 [-1, 64, 32, 48] 73,792
# Upsample-10 [-1, 64, 64, 96] 0
# Conv2d-11 [-1, 32, 64, 96] 18,464
# Upsample-12 [-1, 32, 128, 192] 0
# Conv2d-13 [-1, 16, 128, 192] 4,624
# Conv2d-14 [-1, 1, 128, 192] 145
# ================================================================
# Total params: 784,385
# Trainable params: 784,385
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.09
# Forward/backward pass size (MB): 22.88
# Params size (MB): 2.99
# Estimated Total Size (MB): 25.96
# ----------------------------------------------------------------
定义损失函数 Dice metric
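Dice系数, 根据 Lee Raymond Dice 命名, 是一种集合相似度度量函数, 通常用于计算两个样本的相似度(值范围为 [0, 1]):
$$\mathrm{Dice} = \frac{2|X \cap Y|}{|X| + |Y|}$$
其中 |X⋂Y| 表示 X 和 Y 之间交集的元素个数; |X| 和 |Y| 分别表示 X 和 Y 的元素个数. 分子中的系数 2, 是因为分母中重复计算了 X 和 Y 之间的共同元素.
Dice 系数差异函数(Dice loss):
$$L_{\mathrm{Dice}} = 1 - \frac{2|X \cap Y|}{|X| + |Y|}$$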
## Loss functions
# The Dice coefficient is a set-similarity measure with values in [0, 1].
# https://blog.csdn.net/JMU_Ma/article/details/97533768 , https://zhuanlan.zhihu.com/p/86704421
def dice_loss(pred, target, smooth = 1e-5):
    """Soft Dice over dims 2 and 3, summed over the remaining leading dims.

    Args:
        pred: probabilities, at least 4-D; dims 2 and 3 are reduced
            (presumably (N, C, H, W) — confirm with callers).
        target: binary mask with the same shape as *pred*.
        smooth: small constant that keeps the ratio defined, and equal to 1,
            when both *pred* and *target* are empty.

    Returns:
        (loss_sum, dice_sum): the sums of ``1 - dice`` and of ``dice`` over
        every (sample, channel) pair.
    """
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    # Fix: add smooth once to both numerator and denominator. The old form
    # 2 * (I + s) / (U + s) gave dice = 2 (loss = -1) for an empty pred and
    # an empty target, i.e. a metric above 1 and a negative loss.
    dice = (2.0 * intersection + smooth) / (union + smooth)
    loss = 1.0 - dice
    return loss.sum(), dice.sum()
def loss_func(pred, target):
    """Training loss: summed BCE-with-logits plus the summed Dice loss.

    *pred* holds raw logits. The BCE term consumes them directly; the Dice
    term is computed on their sigmoid.

    NOTE(review): the BCE term sums over every element, while the Dice term
    sums one value per (sample, channel), so BCE dominates the magnitude.
    """
    bce_term = F.binary_cross_entropy_with_logits(pred, target, reduction='sum')
    dice_term, _dice = dice_loss(torch.sigmoid(pred), target)
    return bce_term + dice_term
模型设计及训练
定义几个计算辅助函数
# Read the current learning rate.
def get_lr(opt):
    """Return the 'lr' of the optimizer's first parameter group, or None if there are no groups."""
    return next((group['lr'] for group in opt.param_groups), None)
# Evaluation metric.
def metrics_batch(pred, target):
    """Return the batch's summed Dice coefficient, computed from raw logits *pred*."""
    return dice_loss(torch.sigmoid(pred), target)[1]
# Per-batch loss computation.
def loss_batch(loss_func, output, target, opt=None):
    """Compute the loss and Dice metric of one batch, optionally stepping the optimizer.

    Args:
        loss_func: callable ``(output, target) -> scalar tensor``.
        output: raw model logits.
        target: ground-truth mask with the same shape as *output*.
        opt: optimizer. When given, gradients are zeroed, back-propagated and
            applied; when None, no update is made.

    Returns:
        (loss, metric) as Python floats: the batch loss and the batch's summed Dice.
    """
    loss = loss_func(output, target)
    with torch.no_grad():
        # Reuse the shared metric helper instead of duplicating its logic.
        # Fix: convert to a Python float like the loss. Returning the tensor
        # let device (e.g. CUDA) tensors accumulate into the metric history,
        # which matplotlib cannot convert for plotting.
        metric_b = metrics_batch(output, target).item()
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item(), metric_b
# Per-epoch computation.
def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
    """Run one pass over *dataset_dl* and return per-sample averages.

    Args:
        model: network mapping an input batch to logits.
        loss_func: forwarded to ``loss_batch``.
        dataset_dl: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        sanity_check: if truthy, stop after the first batch.
        opt: optimizer forwarded to ``loss_batch``; None means no weight update.

    Returns:
        (loss, metric): the values accumulated from ``loss_batch`` divided by
        the number of samples actually processed.
    """
    running_loss = 0.0
    running_metric = 0.0
    num_samples = 0
    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)
        output = model(xb)
        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if metric_b is not None:
            running_metric += metric_b
        num_samples += xb.size(0)
        if sanity_check:
            break
    # Fix: divide by the samples seen, not len(dataset_dl.dataset). With
    # sanity_check only one batch runs, so the old denominator understated
    # both averages.
    loss = running_loss / float(num_samples)
    metric = running_metric / float(num_samples)
    return loss, metric
模型定义
# Encoder-decoder segmentation model. Although labelled "U-Net", it has no
# skip connections: each decoder stage sees only the upsampled output of the
# stage before it.
# NOTE(review): this notebook defines SegNet twice with identical layers; the
# definition executed last is the one in effect.
class SegNet(nn.Module):
    """Plain encoder-decoder producing per-pixel logits.

    ``params`` keys:
        input_shape: (C, H, W). Only C is used; H and W are unpacked but ignored.
        initial_filters: width f of the first conv. The encoder widens to
            f, 2f, 4f, 8f, 16f and the decoder narrows back to f.
        num_outputs: number of output channels (raw logits, no activation).

    There are four 2x max-pools and four 2x bilinear upsamples. The output
    therefore has the input's spatial size only when H and W are divisible
    by 16.
    """

    def __init__(self, params):
        super().__init__()
        in_ch, _height, _width = params['input_shape']
        f = params['initial_filters']
        n_out = params['num_outputs']

        def conv3x3(c_in, c_out):
            # 3x3 convolution, stride 1, padding 1: spatial size is preserved.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Encoder convolutions (channel width doubles each stage).
        self.conv1 = conv3x3(in_ch, f)
        self.conv2 = conv3x3(f, 2 * f)
        self.conv3 = conv3x3(2 * f, 4 * f)
        self.conv4 = conv3x3(4 * f, 8 * f)
        self.conv5 = conv3x3(8 * f, 16 * f)
        # Shared 2x upsampler plus decoder convolutions (width halves each stage).
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up1 = conv3x3(16 * f, 8 * f)
        self.conv_up2 = conv3x3(8 * f, 4 * f)
        self.conv_up3 = conv3x3(4 * f, 2 * f)
        self.conv_up4 = conv3x3(2 * f, f)
        self.conv_out = conv3x3(f, n_out)

    def forward(self, x):
        # Encoder: conv + ReLU, then halve the resolution.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Bottleneck.
        x = F.relu(self.conv5(x))
        # Decoder: double the resolution, then conv + ReLU.
        for conv in (self.conv_up1, self.conv_up2, self.conv_up3, self.conv_up4):
            x = F.relu(conv(self.upsample(x)))
        return self.conv_out(x)
# Model configuration: single-channel 128x192 input, 16 base filters, one logit map.
params_model = dict(
    input_shape=(1, 128, 192),
    initial_filters=16,
    num_outputs=1,
)
# Build the network and move it to the selected device.
model = SegNet(params_model).to(device)
# print(model)
# """
# SegNet(
# (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (upsample): Upsample(scale_factor=2.0, mode=bilinear)
# (conv_up1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_up4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# )
# """
# Print per-layer output shapes and parameter counts for a 1x128x192 input.
summary(model, (1, 128, 192))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Conv2d-1 [-1, 16, 128, 192] 160
# Conv2d-2 [-1, 32, 64, 96] 4,640
# Conv2d-3 [-1, 64, 32, 48] 18,496
# Conv2d-4 [-1, 128, 16, 24] 73,856
# Conv2d-5 [-1, 256, 8, 12] 295,168
# Upsample-6 [-1, 256, 16, 24] 0
# Conv2d-7 [-1, 128, 16, 24] 295,040
# Upsample-8 [-1, 128, 32, 48] 0
# Conv2d-9 [-1, 64, 32, 48] 73,792
# Upsample-10 [-1, 64, 64, 96] 0
# Conv2d-11 [-1, 32, 64, 96] 18,464
# Upsample-12 [-1, 32, 128, 192] 0
# Conv2d-13 [-1, 16, 128, 192] 4,624
# Conv2d-14 [-1, 1, 128, 192] 145
# ================================================================
# Total params: 784,385
# Trainable params: 784,385
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.09
# Forward/backward pass size (MB): 22.88
# Params size (MB): 2.99
# Estimated Total Size (MB): 25.96
# ----------------------------------------------------------------
模型训练与验证
模型训练主函数
# Main training / validation loop.
def train_val(model, params):
    """Train *model* with per-epoch validation, keeping the best weights.

    ``params`` keys: num_epochs, loss_func, optimizer, train_dl, val_dl,
    sanity_check, lr_scheduler (stepped with the validation loss) and
    path2weights (where the best state_dict is saved).

    Whenever the validation loss improves, the weights are copied in memory
    and saved to path2weights. If the scheduler changes the learning rate,
    the best weights so far are reloaded.

    Returns:
        (model, loss_history, metric_history). The model is restored to its
        best weights; each history is a dict of per-epoch lists under
        "train" and "val".
    """
    keys = ("num_epochs", "loss_func", "optimizer", "train_dl", "val_dl",
            "sanity_check", "lr_scheduler", "path2weights")
    (num_epochs, loss_func, opt, train_dl, val_dl,
     sanity_check, lr_scheduler, path2weights) = [params[k] for k in keys]

    loss_history = {phase: [] for phase in ("train", "val")}
    metric_history = {phase: [] for phase in ("train", "val")}

    def record(phase, loss_value, metric_value):
        # Append one epoch's results for the given phase.
        loss_history[phase].append(loss_value)
        metric_history[phase].append(metric_value)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        lr_before = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, lr_before))

        # Training pass: the optimizer is passed, so weights are updated.
        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        record("train", train_loss, train_metric)

        # Validation pass: eval mode, gradients off, no optimizer.
        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        record("val", val_loss, val_metric)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        # A learning-rate drop triggers a rollback to the best weights so far.
        lr_scheduler.step(val_loss)
        if get_lr(opt) != lr_before:
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        print("train loss: %.6f, accuracy: %.2f" %(train_loss, 100*train_metric))
        print("val loss: %.6f, accuracy: %.2f" %(val_loss, 100*val_metric))
        print("-"*10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
模型训练
# Optimizer and learning-rate schedule: Adam at 3e-4; halve the LR after 20
# epochs without validation-loss improvement.
# NOTE(review): with num_epochs=10 below, patience=20 means the LR never drops.
opt = optim.Adam(model.parameters(), lr=3e-4)
# `verbose` is deprecated for PyTorch LR schedulers; train_val already prints
# the current LR at the start of each epoch.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20)
path_models = "./models/sos/"
# Fix: os.mkdir raised FileNotFoundError when "./models" did not exist yet.
# makedirs creates missing parents and is a no-op if the folder exists.
os.makedirs(path_models, exist_ok=True)
params_train = {
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": loss_func,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": False,
    "lr_scheduler": lr_scheduler,
    "path2weights": os.path.join(path_models, "weights.pt"),
}
model, loss_hist, metric_hist = train_val(model, params_train)
可视化结果
# Plot the per-epoch curves recorded by train_val.
def _plot_history(history, title, ylabel):
    """Plot the "train" and "val" series of *history* against 1-based epoch numbers."""
    epochs = range(1, len(history["train"]) + 1)
    plt.title(title)
    # float() also accepts 0-d tensors (copying them to the host), so device
    # tensors in the history do not break matplotlib.
    plt.plot(epochs, [float(v) for v in history["train"]], label="train")
    plt.plot(epochs, [float(v) for v in history["val"]], label="val")
    plt.ylabel(ylabel)
    plt.xlabel("Training Epochs")
    plt.legend()
    plt.show()


num_epochs = params_train["num_epochs"]
_plot_history(loss_hist, "Train-Val Loss", "Loss")
# Fix: the tracked metric is the Dice coefficient (from dice_loss via
# loss_batch), not accuracy, so label it accordingly.
_plot_history(metric_hist, "Train-Val Dice", "Dice coefficient")
部署测试
# Deployment check: run the trained network on random test-set images.
# The `model` object built earlier is reused, so no new SegNet is created here.
np.random.seed(2019)
path_test = './data/sos/test_set/'
# Keep only PNG images, as for the training set; any other file in the folder
# could otherwise be drawn and fail to open as an image.
imgs_list = [pp for pp in os.listdir(path_test) if "Annotation" not in pp and pp.endswith('.png')]
rnd_imgs = np.random.choice(imgs_list, 4)
print(rnd_imgs)
model_weights_path = './models/sos/weights.pt'
# map_location lets weights saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.eval()
for fn in rnd_imgs:
    path_img = os.path.join(path_test, fn)
    img = Image.open(path_img)
    img = img.resize((w, h))  # PIL's resize takes (width, height)
    img_t = to_tensor(img).unsqueeze(0).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = torch.sigmoid(model(img_t))[0]
    # Threshold the first output channel's probabilities at 0.5.
    mask_pred = (pred[0] >= 0.5)
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap="gray")
    plt.subplot(1, 3, 2)
    plt.imshow(mask_pred.cpu(), cmap="gray")
    plt.subplot(1, 3, 3)
    show_img_mask(img, mask_pred.cpu())