- 论文:Queried Unlabeled Data Improves and Robustifies Class-Incremental Learning
- 期刊:TMLR 2022
- 作者:德州大学奥斯汀分校等
本文在类别增量(CIL)场景的简单模型 LwF 基础上做了改进,并使用了三个机制,提升了模型的效果:无标签查询数据(QUD)、辅助分类器平衡训练、对抗样本训练。本质上持续学习的重演方法和正则化方法。本文用的几个机制其实是独立的、平行的,作者将其堆叠到持续学习场景中,有点缝合怪行为。
无标签查询数据(QUD)
持续学习重演方法最大的问题是受记忆容量限制,重演数据量不够导致的训练样本不均衡的问题。本文的特色是利用了外部数据库(例如 Google 图片)的数据帮助防止遗忘。查询数据(query data)是一个数据库概念,是指从数据库中按照一定的查询条件抽取的一些数据。
具体来说,在
获得了大量无标签查询数据后,最常用的是知识蒸馏或知识迁移方法,通过在损失函数中加入以下正则项:
- 知识蒸馏(KD):
, 为无标签查询数据, 为分类器最后输出的结果(概率值),此正则项让模型在查询数据(代表了旧任务数据)上预测结果尽量向旧模型靠近; - 知识迁移(KT):
。它与 KD 的差别在只是让共有网络 输出结果靠近,整个损失函数不会更新旧任务的输出头。
这里与 LwF 类似,因为都用了知识蒸馏,但是不一样。LwF 手中只有等当前任务的数据,将其当作旧任务数据作蒸馏。这样不适合 CIL 场景,因为旧模型还没有新任务类别的输出头,无法完成新任务。
关于此方法,我认为可能存在的问题:每次查询的数据是不能存下来的,而每个旧任务 anchor (即查询的依据)是固定的,所以会重复查询相同的数据,查询量也是线性增长的,其实将重演记忆的空间代价转化成了查询的时间代价。
辅助分类器平衡训练
这个机制也是为了解决重演的训练样本不均衡的问题。先不管查询数据,将 anchor 看作重演数据,最简单的重演方法是将新数据和重演数据混合,随机采样 batch 拿来训练。这些重演的 anchor (旧类别)占新数据(新类别)的比例是悬殊的。而有其他的采样方式可以使采样的 batch 类别是均衡的,称为 class-balanced batch,具体见论文中引述的工作。
然后,并不是直接使用 class-balanced batch,随机采样的 random batch 也要用,但单独给它开一个分类头。训练时二者同等重要。这样做的目的是防止 class-balanced batch 过分突出不平衡的那部分少量的数据使其过拟合(起到了隐式的正则化的作用)。损失函数:
但在测试阶段并不参与到分类结果中,即测试阶段的输出只用 class-balanced batch 对应的分类头,称为主分类器(primary classifier),random batch 对应的分类头称为辅助分类器(auxiliary classifier)。个人认为这样做训练与测试阶段不一致,合理性有待讨论,但实际上很多论文都有过这种现象,例如上次的 CAT。
将此机制结合到 QUD 机制,得到了本文的 CIL-QUD 模型:
最终的损失函数:
对抗样本训练
对抗训练是用来提高模型鲁棒性的一种手段,通过设计在原数据
加到上面的损失函数中,得到本文的 Robust 版模型 RCIL-QUD :
其中查询数据构造的正则项 $ {}^{}(, )$ (注意这里把
- Robust 版知识蒸馏(RKD):
- Robust 版知识迁移(RFT):
最后又额外引入了一个正则项,能让无标签数据在增强鲁棒性上发挥更大作用(无标签数据已经用在了
道理很简单,无论有无标签,对抗训练都希望扰动之后预测值不变:对有标签数据,不变的是已知的预测标签