名称:Sub-Image Anomaly Detection with Deep Pyramid Correspondences
SPADE是一种通过特征对比的方法进行异常检测的算法,主要核心是通过K近邻进行检索(kNN)。基于KNN的异常检测一般只能区分整体特性,无法精确得到缺陷的位置。本文提出了一种利用KNN和多尺度特征的方法来进行异常的缺陷检测与定位。
- 优点
不需要训练,只要有正样本图像就行。 - 缺点
需要存储所有训练集的特征,对于内存的需求很高。
SPADE整个过程分为3部分:图像深度特征提取、K近邻正常图像检索和特征金字塔像素对齐。
1.图像深度特征提取
就是使用一个在imagenet上预训练过的模型进行特征提取,论文中使用的是pytorch框架自带的wide_resnet50_2,对layer1,layer2,layer3,avepool层的结果进行了输出。
这个步骤会把训练集中的所有图像都做一遍特征提取,然后分别把各个层提取出来的特征存储起来,等需要用的时候再全部载入内存(所以如果数据集很大的话,对内存的需求就很高)。
2.K近邻正常图像检索
这个步骤是在整图层面上判定这个图像有没有异常,但不会告诉你异常具体在那个位置。主要是使用上一步avepool层的输出特征,分别测试图像的avepool层的输出特征分别和训练集中avepool层的输出特征计算欧式近距离,然后再取距离最近的K个图像,作为训练集中与测试图像最接近的K个图像。代码中使用topk进行选取,参数largest=False代表降序排列。
# calculate distance matrix
dist_matrix = calc_dist_matrix(torch.flatten(test_outputs['avgpool'], 1),
torch.flatten(train_outputs['avgpool'], 1))
# select K nearest neighbor and take average
topk_values, topk_indexes = torch.topk(dist_matrix, k=args.top_k, dim=1, largest=False)
3.特征金字塔像素对齐
这个步骤是用来确定异常在图像的具体哪个位置,以测试图像的layer1特征为例,layer1上的各个像素都会与步骤2中筛选出来的K个图像的layer1上的像素做欧式距离计算,然后输出2者之间最短的距离,遍历整张特征图就能得到测试图像的layer1上的特征与筛选出来的K个图像的layer1上的特征在像素层面的最短距离,然后layer2,layer3特分别做相同的计算,由于layer1,layer2,layer3他们的特征图尺寸不一样,所以会将他们reszie到一样的尺寸再通道拼接在一起,之后在通道层面求平均,就能得到mask图,作者还对这个mas图做了一个高斯滤波用于平滑图像。代码里面除以100,主要是用来分段计算,所有像素的特征一起计算欧氏距离,内存容易溢出。
# construct a gallery of features at all pixel locations of the K nearest neighbors
topk_feat_map = train_outputs[layer_name][topk_indexes[t_idx]]
test_feat_map = test_outputs[layer_name][t_idx:t_idx + 1]
feat_gallery = topk_feat_map.transpose(3, 1).flatten(0, 2).unsqueeze(-1).unsqueeze(-1)
# calculate distance matrix
dist_matrix_list = []
for d_idx in range(feat_gallery.shape[0] // 100):
dist_matrix = torch.pairwise_distance(feat_gallery[d_idx * 100:d_idx * 100 + 100], test_feat_map)
dist_matrix_list.append(dist_matrix)
dist_matrix = torch.cat(dist_matrix_list, 0)
# k nearest features from the gallery (k=1)
score_map = torch.min(dist_matrix, dim=0)[0]
score_map = F.interpolate(score_map.unsqueeze(0).unsqueeze(0), size=224,
mode='bilinear', align_corners=False)
score_maps.append(score_map)
# average distance between the features
score_map = torch.mean(torch.cat(score_maps, 0), dim=0)
# apply gaussian smoothing on the score map
score_map = gaussian_filter(score_map.squeeze().cpu().detach().numpy(), sigma=4)