Unet
一、原理:
Unet网络分为两个部分:
第一部分:特征提取。上图中的左侧,有点类似VGG网络。由简单的卷积、池化下采样组成。图中采用的是3*3和1*1的卷积核进行卷积操作,3*3用于提取特征,1*1用于改变维度。另外每经过一次池化,就变成另一个尺度,包括input的图像总计5个尺度。
第二部分:上采样及特征融合。上图中的右侧。此处的上采样即通过转置卷积进行。然后进行特征融合,但是此处的特征融合和FCN的方法不一样(见下方)。但是融合之前要将其crop。这里的融合也是拼接。
特征融合:
1.Unet:拼接。采用将特征在channel维度拼接在一起,形成更厚的特征。对应于 TensorFlow的tf.concat()函数,比较占显存。
2. FCN:对应点相加,并不形成更厚的特征,对应于TensorFlow中的tf.add()函数。
Unet网络的输入与输出部分:
Unet最初是为医学图像中的细胞分割而设计的,但是分割时不可能将整张原图输入网络,所以必须切成一张一张的小patch。在切成小patch的时候,Unet由于网络结构原因适合有overlap的切图,可以看图,红框是要分割的区域,但是在切图时要包含周围区域。overlap的另一个重要原因是周围overlap部分可以为分割区域边缘部分提供纹理等信息。可以看黄框的边缘,分割结果并没有因为切成小patch而变差。在后续使用的时候,由于本人使用的是512*512大小的图片,所以这步就不需要进行。
优点:
(1):多次下采样,提供多个尺度,实现了网络对图像特征的多尺度特征识别;
(2):在上采样部分,进行了特征融合,并且是将多个不同的尺度特征融合。这一层的转置卷积后与上一层同一个尺度的特征提取卷积的输出进行融合。相比之下,FCN仅在最后一层进行融合。
代码分享:
#导入相应模块
from __future__ import print_function

import datetime
import os

import cv2
import numpy as np
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, AveragePooling2D, Dropout, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU, ReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
# ---------------------------------------------------------------------------
# Run configuration.
# The right-hand sides below are fill-in placeholders, not real values:
# replace each one before running the script.
# ---------------------------------------------------------------------------
# Side length, in pixels, of the square images (placeholder: "image size").
PIXEL =图片大小
# Number of samples per batch (placeholder).
BATCH_SIZE = batch_size
# Learning rate (placeholder: "learning rate").
lr =学习率
# Number of training epochs (placeholder: "training epochs").
EPOCH =训练epoch
# Channel count of the training images (placeholder: "training image channels").
train_img_CHANNEL =训练图片的纬度
# Channel count of the training masks (placeholder: "training mask channels").
train_mask_CHANNEL =训练图片mask的纬度
# Number of training images (placeholder: "number of training images").
train_NUM =训练图片数量
# Directory of training images (placeholder string: "training-set image path").
train_img ='训练集image路径'
# Directory of training masks (placeholder string: "training-set mask path").
train_mask ='训练集mask路径'
# Directory of test images (placeholder string: "test-set image path").
test_img ='测试集image路径'
# Directory of test masks (placeholder string: "test-set mask path").
test_mask ='测试集mask路径'
# Training generator: the yielded X and Y arrays are 4-dimensional.
def train_generator(train_img, train_mask, BATCH_SIZE):
    """Yield random (image, mask) training batches forever.

    Args:
        train_img: Directory that contains the training images.
        train_mask: Directory that contains the training masks.
        BATCH_SIZE: Number of samples in every yielded batch.

    Yields:
        A tuple ``(X, Y)`` of NumPy arrays shaped
        ``(BATCH_SIZE, PIXEL, PIXEL, train_img_CHANNEL)`` and
        ``(BATCH_SIZE, PIXEL, PIXEL, train_mask_CHANNEL)``.

    Raises:
        ValueError: If a directory is empty, the two directories hold a
            different number of files, a file cannot be decoded, or a
            decoded image does not fit the configured shape.
    """
    # os.listdir returns entries in arbitrary order, so both listings are
    # sorted to make position i in one list correspond to position i in the
    # other.  NOTE(review): this assumes an image and its mask share a file
    # name (or at least the same sort position) -- confirm for your dataset.
    # The listings are loop-invariant, so they are built once.
    image_files = sorted(os.listdir(train_img))
    mask_files = sorted(os.listdir(train_mask))
    if not image_files:
        raise ValueError('no files found in {!r}'.format(train_img))
    if len(image_files) != len(mask_files):
        raise ValueError(
            'image/mask count mismatch: {} files in {!r} vs {} files in {!r}'
            .format(len(image_files), train_img, len(mask_files), train_mask))
    num_samples = len(image_files)

    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            # Sample over every file, index 0 included.
            index = np.random.randint(num_samples)

            # os.path.join works whether or not the directory string ends
            # with a path separator.
            img_path = os.path.join(train_img, image_files[index])
            img = cv2.imread(img_path, 1)  # flag 1: load as a colour image
            if img is None:
                raise ValueError('cannot read image {!r}'.format(img_path))
            X.append(img.reshape(PIXEL, PIXEL, train_img_CHANNEL))

            mask_path = os.path.join(train_mask, mask_files[index])
            mask = cv2.imread(mask_path, 1)
            if mask is None:
                raise ValueError('cannot read mask {!r}'.format(mask_path))
            Y.append(mask.reshape(PIXEL, PIXEL, train_mask_CHANNEL))

        yield np.array(X), np.array(Y)
# Test/validation generator: the yielded X and Y arrays are 4-dimensional.
def test_generator(test_img, test_mask, BATCH_SIZE):
    """Yield random (image, mask) test batches forever.

    Args:
        test_img: Directory that contains the test images.
        test_mask: Directory that contains the test masks.
        BATCH_SIZE: Number of samples in every yielded batch.

    Yields:
        A tuple ``(X, Y)`` of NumPy arrays shaped
        ``(BATCH_SIZE, PIXEL, PIXEL, train_img_CHANNEL)`` and
        ``(BATCH_SIZE, PIXEL, PIXEL, train_mask_CHANNEL)``.  The training
        channel constants are reused for the test data.

    Raises:
        ValueError: If a directory is empty, the two directories hold a
            different number of files, a file cannot be decoded, or a
            decoded image does not fit the configured shape.
    """
    # os.listdir returns entries in arbitrary order, so both listings are
    # sorted to make position i in one list correspond to position i in the
    # other.  NOTE(review): this assumes an image and its mask share a file
    # name (or at least the same sort position) -- confirm for your dataset.
    # The listings are loop-invariant, so they are built once.
    image_files = sorted(os.listdir(test_img))
    mask_files = sorted(os.listdir(test_mask))
    if not image_files:
        raise ValueError('no files found in {!r}'.format(test_img))
    if len(image_files) != len(mask_files):
        raise ValueError(
            'image/mask count mismatch: {} files in {!r} vs {} files in {!r}'
            .format(len(image_files), test_img, len(mask_files), test_mask))
    # The index range comes from the test directory itself rather than from
    # the size of the training set.
    num_samples = len(image_files)

    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            # Sample over every file, index 0 included.
            index = np.random.randint(num_samples)

            # os.path.join works whether or not the directory string ends
            # with a path separator.
            img_path = os.path.join(test_img, image_files[index])
            img = cv2.imread(img_path, 1)  # flag 1: load as a colour image
            if img is None:
                raise ValueError('cannot read image {!r}'.format(img_path))
            X.append(img.reshape(PIXEL, PIXEL, train_img_CHANNEL))

            mask_path = os.path.join(test_mask, mask_files[index])
            mask = cv2.imread(mask_path, 1)
            if mask is None:
                raise ValueError('cannot read mask {!r}'.format(mask_path))
            Y.append(mask.reshape(PIXEL, PIXEL, train_mask_CHANNEL))

        yield np.array(X), np.array(Y)
# Build the model.
#
# Encoder: five 2x2 average-pooling steps (PIXEL -> PIXEL/32).
# Decoder: five 2x up-sampling steps back to PIXEL, with the pooled encoder
# tensors concatenated on the channel axis at the PIXEL/16, PIXEL/8 and
# PIXEL/4 scales.  Each pooling halves and each up-sampling doubles the
# spatial size, so PIXEL presumably has to be a multiple of 32 for the
# concatenated tensors to line up -- confirm before using other sizes.
#
# NOTE(review): the previous version also built `conv4`/`conv5` stages
# (256 and 512 filters) and assigned `pool4` three times.  Only the last
# assignment (average pooling of `pool3`) was ever consumed, so those stages
# were never connected to the output.  They are removed here; the set of
# layers between `inputs` and `conv12` is unchanged.  If the deeper stages
# were meant to feed `pool4`/`pool5`, that wiring has to be restored
# deliberately.

# Keyword arguments shared by every convolution in the network.
_CONV_KW = dict(activation='relu', padding='same', kernel_initializer='he_normal')

inputs = Input((PIXEL, PIXEL, 3))

# ---- Encoder ---------------------------------------------------------------
conv1 = Conv2D(8, 3, **_CONV_KW)(inputs)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1)  # PIXEL / 2

conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, **_CONV_KW)(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, **_CONV_KW)(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2)  # PIXEL / 4

conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, **_CONV_KW)(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, **_CONV_KW)(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3)  # PIXEL / 8

# The two deepest scales are reached by pooling alone (no convolutions).
pool4 = AveragePooling2D(pool_size=(2, 2))(pool3)  # PIXEL / 16
pool5 = AveragePooling2D(pool_size=(2, 2))(pool4)  # PIXEL / 32

# ---- Bottleneck ------------------------------------------------------------
conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, **_CONV_KW)(conv6)

# ---- Decoder ---------------------------------------------------------------
conv7 = Conv2D(256, 3, **_CONV_KW)(conv6)
up7 = UpSampling2D(size=(2, 2))(conv7)  # PIXEL / 16
conv7 = Conv2D(256, 3, **_CONV_KW)(up7)
merge7 = concatenate([pool4, conv7], axis=3)

conv8 = Conv2D(128, 3, **_CONV_KW)(merge7)
up8 = UpSampling2D(size=(2, 2))(conv8)  # PIXEL / 8
conv8 = Conv2D(128, 3, **_CONV_KW)(up8)
merge8 = concatenate([pool3, conv8], axis=3)

conv9 = Conv2D(64, 3, **_CONV_KW)(merge8)
up9 = UpSampling2D(size=(2, 2))(conv9)  # PIXEL / 4
conv9 = Conv2D(64, 3, **_CONV_KW)(up9)
merge9 = concatenate([pool2, conv9], axis=3)

# The last two decoder stages have no skip connection.
conv10 = Conv2D(32, 3, **_CONV_KW)(merge9)
up10 = UpSampling2D(size=(2, 2))(conv10)  # PIXEL / 2
conv10 = Conv2D(32, 3, **_CONV_KW)(up10)

conv11 = Conv2D(16, 3, **_CONV_KW)(conv10)
up11 = UpSampling2D(size=(2, 2))(conv11)  # PIXEL
conv11 = Conv2D(8, 3, **_CONV_KW)(up11)

# 1x1 convolution down to the 3 output channels.
conv12 = Conv2D(3, 1, **_CONV_KW)(conv11)

# Keras 2 names these keyword arguments `inputs`/`outputs`; the singular
# `input`/`output` spellings are the Keras 1 API.
model = Model(inputs=inputs, outputs=conv12)
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only adds a stray "None" line.
model.summary()

# Use the learning rate from the configuration section instead of a
# hard-coded value.
model.compile(optimizer=Adam(lr=lr), loss='mse', metrics=['accuracy'])

# `epochs` and `validation_steps` are the Keras 2 names of the Keras 1
# arguments `nb_epoch` and `nb_val_samples`.  Both generators loop forever,
# so the step counts are what delimit an epoch and a validation pass.
history = model.fit_generator(
    train_generator(train_img, train_mask, BATCH_SIZE),
    steps_per_epoch=600,
    epochs=EPOCH,
    validation_data=test_generator(test_img, test_mask, BATCH_SIZE),
    validation_steps=20)

# Timestamp taken when training finishes (not used further in this section).
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

# The path strings below are placeholders ("model save path, h5 format" and
# "loss-history save path, npy format"); replace them before running.
model.save('模型保存路径,h5格式')

# Persist the per-epoch training loss (MSE, per the compile() call above).
mse = np.array(history.history['loss'])
np.save('历史loss保存路径,npy格式', mse)