论文标题:DASO: Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning
论文链接:https://arxiv.org/abs/2106.05682
代码链接:https://github.com/ytaek-oh/daso
论文来源:CVPR2022
作者单位:KAIST, South Korea, UC Berkeley / ICSI, CA.
摘要
传统的半监督学习方法和现实世界中应用的情景大有不同,主要体现在由(1)各类别分布不平衡;(2)有标注数据和无标注数据的类别分布不匹配,所造成的伪标签具有严重偏置。针对以上两个问题,本文提出了一个分布感知并面向语义的伪标注方法:Distribution-Aware Semantics-Oriented (DASO) Pseudo-label。简单来说,本文提出了混合使用(1)由传统线性分类器所生成的伪标签和(2)由基于相似性度量的分类器所生成的语义伪标签的方法,缓解伪标签中存在的偏置。此外本文还提出了一个语义对齐损失(semantic alignment loss)来生成平衡的特征表示以减少偏置。
动机
现有的类别不平衡条件下的半监督学习方法大多假设有标注数据和无标注数据服从相同的类别分布。然而无标注数据真实的类别分布并不可知,多数情况下这两个分布不是一致的。通过实验发现,一个基于相似性度量的分类器所生成的语义伪标签的偏置更倾向于尾部类(即少数类,minority classes),而传统的线性分类器生成的伪标签(本文称为线性伪标签,linear pseudo-labels)的偏置更倾向于头部类(即多数类,majority classes)。因此本文希望共同利用这两种特性互补的伪标签来提出一种新的伪标注策略。
图中的FixMatch和USADTM分别代表使用线性伪标签和语义伪标签来进行类不平衡半监督学习。从图2(a)和(b)可以发现,线性伪标签(FixMatch)在多数类的Recall高,在少数类的Recall低,但是少数类的Precision高,这说明模型对头部类(head classes)具有偏置。相反,语义伪标签(USADTM)在少数类的Precision大幅降低,Recall大量提升,说明模型对尾部类(tail classes)具有偏置。在图2(c)中,USADTM在头部类的测试精度有所下降,尾部类的测试精度提升。因此,仅使用语义伪标签只能达到次优的表现。
因此,本文设计的自适应混合两种伪标签的思路是:对于被预测为多数类的线性伪标签,应该混合更多的语义伪标签(减少线性伪标签的比例,增加语义伪标签的比例)。反之亦然。
DASO伪标注框架
概览
本文将DASO伪标注选择模块部署在FixMatch半监督学习框架之上。首先利用编码器得到弱增强无标注样本的特征表示,再分别利用线性分类器和基于相似性度量的分类器生成线性伪标签和语义伪标签。最终的伪标签由以上两种伪标签自适应融合得到。于是,Fixmatch中的无监督损失可以改写为,由融合的伪标签替换了之前的线性伪标签。对于语义对齐损失,语义伪标签和预测结果计算交叉熵损失,其中,是基于相似度度量分类器对特征的输出结果。在后文中,被记为。
平衡的原型生成(Balanced prototype generation)
为了构建基于相似度度量的分类器来生成语义伪标签,本文根据有标注集构建一组类原型。具体地,本文建立了一个memory队列的字典,其中key代表对应的类别,代表类别的固定大小为的一个memory队列。类别的原型通过对中所有特征点取平均得到。
然而类原型的表示也可能是不平衡的(由于类别不平衡的标注数据)。为了避免具有偏执的类原型,本文采用了两种方法来平衡类原型。(1)对所有的类别,固定的大小统一为。(2)采用一个动量更新的编码器来提取特征,其结构和原始的编码器完全相同,但是是的指数移动平均(EMA),即。作者给出的解释是:这种做法通过减缓网络模型参数的更新步幅,使得每个原型在特征空间中的移动更稳定。
线性和语义伪标签生成(Linear and semantic pseudo-label generation)
线性伪标签由线性分类器和softmax激活函数直接生成:。语义伪标签由基于相似度的分类器生成,这个分类器衡量每个query特征点到每个平衡过后(balanced)类原型的相似性:
其中代表余弦相似度,是温度系数超参数。其中偏置于头部类,偏置于尾部类。
分布感知融合(Distribution-aware blending)
本文认为语义伪标签在不同类别应该有不同的融合程度,比如当更倾向于头部类时,应该加入融合的比例。正式地,我们设定了一组分布感知的权重来减少潜在的偏置:
其中是的预测类别,。其中是当前伪标签归一化的类别分布(在之前几轮迭代的积累),是超参数。因此,对于线性伪标签来说,预测为少数类的伪标签被保留,预测为多数类的伪标签将可能恢复为它原始的类别。
分布感知融合是根据模型当前的偏置(伪标签的类别分布)动态调整融合比例的。这使得DASO框架在无标注集类别分布未知的条件下对不同的分布能够灵活实时调整。比如对于相同的预测结果:预测为头部类,如果模型偏置得更大则将被融合地更多。
语义对齐损失(Semantic alignment loss)
为了构建更加平衡的特征表示,本文引入一个新的语义对齐损失来对特征编码器进行正则化:
其中是将特征输入到基于相似度度量的分类器得到。
总损失
DASO是一个通用的框架来应用于其它的半监督学习算法中,DASO总目标函数如下:
实验
数据集
CIFAR-10/100,STL-10,用“LT”代表不平衡的数据集。此外还有Semi-Aves大规模鸟物种长尾分布数据集,其无标注集也是长尾分布,并包含有标注集中不存在的类别。
实验结果
有标注数据和无标注数据类别分布一致的情况下():
有标注数据和无标注数据类别分布不一致的情况下():
DASO和其它半监督学习(SSL)方法结合:
DASO在Semi-Aves数据集(open-set)上的性能表现: