D³ETR:用于检测Transformer的解码器蒸馏方法
简介
DETR(Detection Transformer)是一种端到端的目标检测器,利用Transformer架构实现了高精度的目标检测。然而,DETR模型通常参数量大,计算开销高,限制了其在资源受限环境中的应用。知识蒸馏(Knowledge Distillation, KD)是一种有效的模型压缩技术,旨在将大型教师模型的知识迁移到轻量级学生模型中。
然而,传统的知识蒸馏方法在DETR上效果不佳,主要原因在于教师模型和学生模型之间缺乏一致的蒸馏点(Distillation Points)。蒸馏点是指教师模型和学生模型在训练过程中用于对齐的中间表示或预测结果。
D³ETR 的核心贡献
D³ETR提出了一种通用的DETR知识蒸馏范式,通过一致的蒸馏点采样策略,提升了蒸馏效果。其主要创新包括:
1. 引入一致的蒸馏点
D³ETR通过引入一组专门的对象查询(Object Queries)来构建蒸馏点。这些查询在教师模型和学生模型之间共享,确保了蒸馏点的一致性,从而提高了蒸馏的稳定性和有效性。
2. 解耦检测与蒸馏任务
在D³ETR中,检测任务和蒸馏任务被解耦处理。通过引入专门的对象查询,模型可以分别优化检测性能和蒸馏效果,避免了任务间的干扰。
3. 通用到特定的蒸馏点采样策略
D³ETR提出了一种从通用到特定的蒸馏点采样策略。首先,随机采样一组均匀分布的对象查询,粗略扫描整个特征图,提取通用知识;然后,利用教师模型中优化良好的对象查询,提取特定知识。通过结合通用和特定的蒸馏点,学生模型能够更全面地学习教师模型的知识。
实验结果与优势
D³ETR在多个DETR架构上进行了广泛的实验,验证了其有效性和通用性。主要结果包括:
- 在MS COCO 2017数据集上,使用ResNet-18作为主干网络的DAB-DETR模型,应用D³ETR后,mAP提升了5.2%,达到41.4%;使用ResNet-50作为主干网络的DAB-DETR模型,mAP提升了3.5%,达到45.7%。
- 在某些情况下,学生模型的性能甚至超过了教师模型。例如,使用ResNet-50作为学生模型,超过了使用ResNet-101的教师模型2.2%。
- D³ETR适用于多种DETR架构,包括DAB-DETR、Deformable DETR、DINO等,展现了良好的通用性。
数学公式示例
D³ETR中的蒸馏损失函数可以表示为:
\[
\mathcal{L}_{\text{KD}} = \sum_{i=1}^{N} \lambda_i \cdot \mathcal{D}(f_i^{\text{teacher}}, f_i^{\text{student}})
\]
其中:
- $N$:蒸馏点的数量
- $\lambda_i$:第$i$个蒸馏点的权重
- $\mathcal{D}$:损失函数,例如L2损失或KL散度
- $f_i^{\text{teacher}}$ 和 $f_i^{\text{student}}$:教师模型和学生模型在第$i$个蒸馏点的特征表示
通过调整权重$\lambda_i$,可以控制不同蒸馏点对总损失的贡献,从而优化蒸馏效果。
Ref
- 论文链接:https://arxiv.org/abs/2211.09768
- IJCAI 2024 论文:https://www.ijcai.org/proceedings/2024/74