DeiT:数据高效的图像 Transformer
https://github.com/facebookresearch/deit https://github.com/FrancescoSaverioZuppichini/DeiT
摘要
DeiT(Data-efficient Image Transformers)是一种用于图像分类任务的神经网络模型,基于 Transformer 架构。它旨在通过数据高效的训练策略和知识蒸馏技术,在参数较少的情况下实现高效的图像分类。与传统的卷积神经网络(CNN)相比,DeiT 采用了 Transformer 的注意力机制,能够更好地捕捉图像中的全局关系。通过引入 Distillation Token 和采用硬蒸馏(Hard Distillation)方法,DeiT 在仅使用 ImageNet 数据集的情况下,达到了与最先进的 CNN 相当的性能,同时显著减少了训练所需的计算资源。
1. 引言
Transformer 架构在自然语言处理领域取得了巨大成功,但将其应用于图像分类任务时,如 ViT(Vision Transformers),通常需要在大规模数据集(如 JFT-300M,包含 3 亿张图片)上进行预训练,才能达到与 CNN 相当的性能。这不仅限制了 ViT 方法的广泛应用,还增加了训练成本。DeiT 通过采用数据高效的训练策略和知识蒸馏技术,仅使用 ImageNet 数据集,就能训练出性能优异的 Transformer 模型。
2. 知识蒸馏
知识蒸馏是一种模型训练技术,通过将一个大型教师模型的知识传递给小型学生模型,以提升学生模型的性能。在 DeiT 中,教师模型是一个性能良好的分类器,学生模型则是一个参数较少的 Transformer 模型。蒸馏过程包括两个阶段:训练教师模型和将教师模型的知识蒸馏到学生模型中。
2.1 硬蒸馏(Hard Distillation)
硬蒸馏使用教师模型的硬标签(即类别标签)作为目标,训练学生模型。具体来说,学生模型的输出 \( Z_s \) 与教师模型的输出 \( Z_t \) 之间的交叉熵损失被加入到总损失函数中: [ \mathcal{L}{\text{global}}^{\text{hardDistill}} = \frac{1}{2} \mathcal{L}(\psi(Z_s), y_t) ] 其中,}}(\psi(Z_s), y) + \frac{1}{2} \mathcal{L}_{\text{CE}\( y \) 是真实标签,\( y_t = \text{argmax}_c Z_t(c) \) 是教师模型的预测标签,\( \psi \) 是 softmax 函数。
2.2 软蒸馏(Soft Distillation)
软蒸馏使用教师模型的软标签(即类别概率分布)作为目标,训练学生模型。具体来说,学生模型的输出 \( Z_s \) 与教师模型的输出 \( Z_t \) 之间的 KL 散度被加入到总损失函数中: [ \mathcal{L}{\text{global}} = (1 - \lambda) \mathcal{L}(\psi(Z_s / \tau), \psi(Z_t / \tau)) ] 其中,}}(\psi(Z_s), y) + \lambda \tau^2 \text{KL\( \lambda \) 和 \( \tau \) 是超参数,控制蒸馏损失的权重和温度。
3. DeiT 的训练策略
DeiT 的训练策略包括优化器选择、数据增强和正则化等。
3.1 优化器
DeiT 使用 AdamW 优化器,这是一种带权重衰减的 Adam 优化器。实验表明,AdamW 比 SGD 优化器表现更好。
3.2 数据增强
DeiT 使用了多种数据增强技术,包括 Rand-Augment、Mixup 和 CutMix。这些技术通过随机选择和组合图像,增加了数据的多样性,有助于模型更好地泛化。
3.3 正则化
DeiT 使用了随机擦除(Random Erasing)和随机深度(Stochastic Depth)等正则化技术。这些技术有助于防止模型过拟合,尤其是在使用较深的模型时。
4. Distillation Token
Distillation Token 是 DeiT 的一个关键创新。它与 ViT 中的 Class Token 一起加入到 Transformer 中,通过自注意力机制与其他嵌入一起计算。Distillation Token 的输出目标是教师模型的预测标签,这使得学生模型能够从教师模型中学习。在训练时,Distillation Token 的输出用于计算蒸馏损失,而在测试时,Distillation Token 的输出与 Class Token 的输出平均,以提高分类性能。
5. 实验结果
DeiT 在 ImageNet 数据集上的实验结果表明,仅使用 ImageNet 数据集进行训练,DeiT 就能达到与最先进的 CNN 相当的性能。具体来说,DeiT-B(Base 版本)在 224×224 分辨率下达到了 83.1% 的 Top-1 准确率,在 384×384 分辨率下达到了 84.2% 的 Top-1 准确率。这些结果表明,DeiT 在数据效率和性能之间取得了良好的平衡。
6. 结论
DeiT 通过采用数据高效的训练策略和知识蒸馏技术,显著提高了 Transformer 模型在图像分类任务中的性能和数据效率。Distillation Token 和硬蒸馏方法的引入,使得 DeiT 能够从教师模型中学习到更多的知识,从而在仅使用 ImageNet 数据集的情况下,达到与最先进的 CNN 相当的性能。DeiT 的训练策略和蒸馏方法为 Transformer 在计算机视觉领域的应用提供了新的思路。
Ref
- DeiT: Data-efficient Image Transformers
- DeiT: Data-efficient Image Transformers - Implementation
- Deit:知识蒸馏与vit的结合 学习笔记(附代码)-CSDN博客
- Training data-efficient image transformers & distillation through attention-CSDN博客
- DeiT:注意力Attention也能蒸馏
- Transformer学习(四)---DeiT https://blog.csdn.net/PLANTTHESON/article/details/135547163 https://blog.csdn.net/yideqianfenzhiyi/article/details/113444303 https://zhuanlan.zhihu.com/p/543730203 https://zhuanlan.zhihu.com/p/443710545 https://blog.csdn.net/shizheng_Li/article/details/146441742 https://zhuanlan.zhihu.com/p/471384477 https://blog.csdn.net/abc13526222160/article/details/132339050