Tensorflow可以使用feed_dict的方式输入数据,但是效率比较低。Tensorflow提供了一个内置函数可以利用输入管道的方式输入数据。
tf.data.Dataset()
接收numpy和tensor类型的数据
Dataset
Dataset()可以接收多个输入,当数据由特征和标签组成时,使用起来及其方便。
image_paths = ['特征路径']
label_paths = ['标签路径']
dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths))
结果:
>>b'('特征路径', '标签路径')'
当输入为string
时,使用form_tensor_slices()
得到的结果是bytes
类型,可能需要decode('utf-8')
。
除了加载数据方便外,dataset
还可以做数据转换。
dataset.map()
接收一个函数,Dataset
中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset
。
dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths))
dataset = dataset.map(lambda image_path, label_path:
tuple(tf.py_func(input_parser, [image_path, label_path], [tf.float32, tf.float32])))
使用tf.py_func()
将input_parser()
变为一个tensorflow内置函数,第二个参数表示输入数据,第三个参数表示输出数据。
在使用dataset
时,先要创建一个迭代器,然后使用get_next()
获取数据。
iter = dataset.make_initializable_iterator()
el = iter.get_next()
with tf.Session() as sess:
sess.run(iter.initializer)
print(sess.run(el))
如果使用多个print()
时,iter
可以自动进行迭代。
加载数据时的小技巧
对于V-Net而言,当训练网络时,必须要提供一个和输入大小相等的tensor作为标签,这个可以直接加载特征和标签来完成。当为非训练状态时,可以生成一个和原特征大小相同的label进行占位。
if train:
label = read_image(label_path.decode("utf-8"))
else:
label = sitk.Image(image.GetSize(),sitk.sitkUInt32)
label.SetOrigin(image.GetOrigin())
label.SetSpacing(image.GetSpacing())
在SimpleITK中,图像作为物理对象占据一个空间有界区域,通过上述方法生成一个和image相同大小的label。
关于SimpleITk可参:http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/03_Image_Details.html
数据增强的方法
按照作者的说法:医学图像通常比较耗费内存,可以对图像进行0-255的标准化,对于较小的输入image可以进行Padding,还可以从3D图形中随机选择一个区域作为网络输入,还可以对图像添加噪声。
Normalization
resacleFilter = sitk.RescaleIntensityImageFilter()
resacleFilter.SetOutputMaximum(255)
resacleFilter.SetOutputMinimum(0)
image = resacleFilter.Execute(image)
RandomCrop
随机从输入图像中采集一个zone,通常可以用来进行数据增强(一般只用于训练阶段)。
先判断zone和image的大小,如果zone的size小于image的size,就将下标置为0~image_size-zone_size
。这里要注意的一点就是在对label进行randomCrop时,每次必须保证将包含标签的zone提取出来。
while not contain_label:
# get the start crop coordinate in ijk
if size_old[0] <= size_new[0]:
start_i = 0
else:
start_i = np.random.randint(0, size_old[0]-size_new[0])
if size_old[1] <= size_new[1]:
start_j = 0
else:
start_j = np.random.randint(0, size_old[1]-size_new[1])
if size_old[2] <= size_new[2]:
start_k = 0
else:
start_k = np.random.randint(0, size_old[2]-size_new[2])
roiFilter.SetIndex([start_i,start_j,start_k])
label_crop = roiFilter.Execute(label)
statFilter = sitk.StatisticsImageFilter()
statFilter.Execute(label_crop)
# will iterate until a sub volume containing label is extracted
# pixel_count = seg_crop.GetHeight()*seg_crop.GetWidth()*seg_crop.GetDepth()
# if statFilter.GetSum()/pixel_count<self.min_ratio:
if statFilter.GetSum()<self.min_pixel:
contain_label = self.drop(self.drop_ratio) # has some probabilty to contain patch with empty label
else:
contain_label = True
image_crop = roiFilter.Execute(image)
训练
此处作者使用了PReLU,也就是Parametric Leaky Relu,是何凯明提出的一种改进ReLU。表达式:
y = max(0, x) + a * min(0, x)
其中的a是可学习参数,当a为非零较小数时,相当于LeakyReLU;当a为零时,等价于ReLU。