1. 什么是 One-hot 编码
最直观的理解就是,比如说现在有三个类别 A、B、C,它们对应的标签值分别为 [1, 2, 3],如果对这三个类别使用One-hot编码,得到的结果则是,[[1, 0, 0], [0, 1, 0], [0, 0, 1]],相当于:
- 1 被编码为 1 0 0
- 2 被编码为 0 1 0
- 3 被编码为 0 0 1
2. 为什么要对数据进行 One-hot 编码
在分割任务中,网络模型最后的输出shape为[N, C, H, W] (以pytoch为例, 其中N为batch_size, C为预测的类别数),而我们给的的gt(ground truth)的shape一般为[H, W, 3](彩色图或rgb图)或[H, W](灰度图)。
假设我们现在的分割任务里面有5个目标需要分割,给定的gt是彩色的。则网络模型最后的输出shape为 [N, 5, H, W],这和gt的shape不匹配,在训练的时候它们两者之间不能进行损失值计算。因此,就需要使用One-hot编码对gt进行编码,将其编码为[H, W, 5],最后再对维度进行transpose即可。
编码前和编码后的变化类似图中所示(上图对应编码前,下图对应编码后)。
3.代码实现
3.1 方法一
mask_to_onehot
用来将标签进行one-hot,onehot_to_mask
用来恢复one-hot,在可视化的时候使用。
def mask_to_onehot(mask, palette):
"""
Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
hot encoding vector, C is usually 1 or 3, and K is the number of class.
"""
semantic_map = []
for colour in palette:
equality = np.equal(mask, colour)
class_map = np.all(equality, axis=-1)
semantic_map.append(class_map)
semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
return semantic_map
def onehot_to_mask(mask, palette):
"""
Converts a mask (H, W, K) to (H, W, C)
"""
x = np.argmax(mask, axis=-1)
colour_codes = np.array(palette)
x = np.uint8(colour_codes[x.astype(np.uint8)])
return x
方法一在使用的时候需要先定义好颜色表palette(根据自己的数据集来定义就行了)。下面演示两个例子。
假设gt是灰度图,需要分割两个目标(正常器官和肿瘤)(加上背景就是3分类任务),正常器官的灰度值为128,肿瘤的灰度值为255, 背景的灰度值为0。
palette = [[0], [128], [255]] # 里面值的顺序不是固定的,可以按自己的要求来
# 注意:灰度图的话要确保 gt的 shape = [H, W, 1],该函数实在最后的通道维上进行映射
# 如果加载后的gt的 shape = [H, W],则需要进行通道的扩维
gt_onehot = mask_to_onehot(gt, palette) # one-hot 后 gt的shape=[H, W, 3]
假设gt彩色图,需要分割5个目标(加上背景就是6分类任务),颜色值如下。 和灰度图的处理方法类似。
palette = [[0, 0, 0], [192, 224, 224], [128, 128, 64], [0, 192, 128], [128, 128, 192], [128, 128, 0]]
gt_onehot = mask_to_onehot(gt, palette) # one-hot 后 gt的shape=[H, W, 6]
3.1 方法二
为了以示区别,名字不要起的一样。
def mask2onehot(mask, num_classes):
"""
Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
hot encoding vector
"""
_mask = [mask == i for i in range(num_classes)]
return np.array(_mask).astype(np.uint8)
def onehot2mask(mask):
"""
Converts a mask (K, H, W) to (H,W)
"""
_mask = np.argmax(mask, axis=0).astype(np.uint8)
return _mask
用法:如果gt是灰度图,如上面的例子,用起来就比较简单。
# 需要先指定每个类别的颜色值对应的标签
# 注意: 第一类从0开始,而不是从1开始
label2trainid = {0: 0, 128: 1, 255: 2}
gt_copy = gt.copy()
# 这一步相当于把
for k, v in label2trainid.items():
gt_copy[gt == k] = v
gt_with_trainid = gt_copy.astype(np.uint8)
gt_onehot = mask2onehot(gt_with_trainid, 3) # one-hot 后 gt的shape=[3, H, W]
如果gt是彩色图,要先把rgb颜色值映射为标签,再进行one-hot编码,相对来说就比较繁琐了。直接用方法一就行了。
二分类和多分类基本差不多,二分类的标签图像像素值处理成0和1组成的矩阵,多分类(N类)的标签图像处理成N层0和1组成的矩阵,即one-hot编码。二分类最后一层的激活函数activation是sigmoid函数,多分类的则是softmax函数。然后对应的损失函数loss分别是binary_crossentropy和categorical_crossentropy。其他的包括基本原理是相同的。
作者:馨意
链接:https://www.zhihu.com/question/319894290/answer/650175752
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。