07 DeiT 论文精读:Training data-efficient image transformers distillation through attention
前言在前面的章节中我们已经理解了 ViT 的核心思想图像 → Patch Embedding → Token 序列 → Transformer Encoder → 分类结果ViT 原论文证明了一件非常重要的事情纯 Transformer 架构可以直接用于图像识别任务。但是 ViT 也留下了一个明显问题ViT 很强但它对大规模数据和训练资源的依赖比较明显。原始 ViT 的强性能通常依赖大规模预训练数据和较高训练成本这对于普通实验室和普通研究者并不友好。DeiT 正是围绕这个问题提出的。DeiT 论文全名为 Training>论文中明确指出DeiT 在只使用 ImageNet 的情况下训练了有竞争力的无卷积 Transformer并且引入了适合 Transformer 的 teacher-student 蒸馏策略其中 distillation token 是核心设计。2. DeiT 论文想解决什么问题DeiT 的提出背景非常明确。ViT 原论文证明了纯 Transformer 可以用于图像识别但原始 ViT 的强性能很大程度上依赖大规模预训练数据。例如ViT 原论文中经常使用 ImageNet-21k、JFT-300M 等大规模数据。对于大公司或大规模算力平台来说这种设定可以接受但对于普通研究者来说训练成本过高。所以 DeiT 想回答的问题是如果只使用 ImageNet-1K不使用额外大规模外部数据能不能训练出性能强的 Vision Transformer论文摘要中提到DeiT 的 reference vision transformer 拥有约 86M 参数在不使用外部数据的情况下达到 83.1% ImageNet Top-1 accuracy论文还报告了通过蒸馏后最高可达到 85.2% Top-1 accuracy。因此DeiT 的核心问题不是Transformer 能不能做图像分类这个问题 ViT 已经回答了。DeiT 真正要解决的是Transformer 能不能在有限数据和有限算力条件下训练好这就是 DeiT 中 “data-efficient” 的含义。3. DeiT 和 ViT 的关系DeiT 和 ViT 的关系非常紧密。ViT 的核心结构是Patch EmbeddingClass TokenPosition EmbeddingTransformer EncoderClassification HeadDeiT 并没有推翻这个框架。相反DeiT 基本沿用了 ViT 的主体结构。它的重点不是发明一种全新的视觉 Transformer 架构而是解决 ViT 的训练效率问题。可以这样理解ViT证明图像可以被看作 patch token 序列并输入 Transformer。DeiT证明在不依赖超大规模外部数据的情况下ViT 也可以通过更好的训练策略和蒸馏机制训练得很好。所以DeiT 是 ViT 之后非常关键的一篇工作。它解决的是 ViT 从“能跑通”到“更容易训练、更容易复现、更适合普通实验条件”的问题。4. DeiT 的核心贡献DeiT 的贡献可以概括为三点。4.1 提出一套更高效的 ViT 训练方案DeiT 证明只使用 ImageNet-1K也可以训练出很强的 Vision Transformer。论文强调其方法可以在单台计算机上较短时间内训练出有竞争力的模型这大大降低了 ViT 的使用门槛。这说明 ViT 的性能不仅取决于模型结构也强烈依赖训练策略。DeiT 中的 “data-efficient” 可以从两个层面理解。第一层含义是不依赖超大外部数据。ViT 原论文中大规模预训练是非常重要的。DeiT 则希望在 ImageNet-1K 这样的标准数据规模下训练 Transformer。第二层含义是在有限的数据下提高训练效果也就是说同样只使用 ImageNet-1KDeiT 通过更强的数据增强、正则化、优化策略和蒸馏机制让 ViT 学得更好。所以 DeiT 的目标不是简单地减少数据量而是提高数据使用效率。可以概括为ViT 依赖大规模数据学习视觉规律DeiT 通过训练策略和 teacher supervision 提高 ViT 的数据效率。DeiT 的另一个重要观点是对于 Vision Transformer训练 recipe 和模型结构同样重要。这里的训练 recipe 可以理解为一整套训练配置包括数据增强正则化优化器学习率策略warmuplabel smoothingmixupcutmixrandom erasingstochastic depth知识蒸馏这些细节对 ViT 尤其重要。原因是 ViT 的图像归纳偏置比 CNN 更弱。CNN 天然具有局部连接、权重共享和平移等变性而 ViT 更依赖数据和训练目标自己学习图像结构。因此如果训练策略不够强ViT 在 ImageNet-1K 上可能训练不充分。DeiT 的启发是ViT 的问题不只是结构问题也是训练问题。4.2 引入适合 Transformer 的蒸馏机制传统知识蒸馏通常是让 student 模型学习 teacher 模型的输出分布。DeiT 的创新在于不是只在 loss 上做蒸馏 而是在 Transformer 输入序列中加入一个 distillation token。这个 distillation token 会和 class token、patch token 一起进入 Transformer Encoder通过 self-attention 参与表示学习。这也是论文标题中distillation through attention的含义。要理解 DeiT必须先理解知识蒸馏。知识蒸馏的基本框架是Teacher Model一个已经训练好的强模型Student Model一个需要训练的模型普通监督学习中student 只学习真实标签image → label。而知识蒸馏中student 还要学习 teacher 的输出image → teacher prediction。teacher 的输出通常比 one-hot 标签包含更多信息。例如一张猫的图片真实标签只告诉模型cat 1其他类别 0但是 teacher 的输出可能是cat 0.82tiger 0.08dog 0.04fox 0.02...这种分布表达了类别之间的相似关系。所以蒸馏的意义是真实标签告诉 student 正确答案teacher 输出告诉 student 类别之间的关系和判断倾向。这对 ViT 很有帮助因为 ViT 缺少 CNN 那种强图像先验需要更丰富的监督信号。4.3 证明 ConvNet Teacher 对 ViT Student 特别有效DeiT 论文中指出使用 ConvNet 作为 teacher 对 Transformer student 的蒸馏尤其有效。直观上CNN 具有更强的局部视觉归纳偏置而 ViT 缺少这种先验。因此CNN teacher 可以向 ViT student 传递有用的视觉判断经验。这也从侧面说明ViT 的训练困难并不是结构无效而是它需要更好的监督信号和训练策略来学习视觉规律。5. DeiT 的关键设计Distillation TokenDeiT 最核心的结构设计就是distillation token在标准 ViT 中输入序列是[CLS], patch_1, patch_2, ..., patch_196而在 DeiT 中输入序列变成[CLS], [DIST], patch_1, patch_2, ..., patch_196其中[CLS]class token用于学习真实标签监督 [DIST]distillation token用于学习 teacher 监督这意味着 DeiT 比 ViT 多了一个特殊 token。对于 224×224 输入、patch size 为 16 的模型ViT token 数量 196 patch tokens 1 class token 197 DeiT distilled token 数量 196 patch tokens 1 class token 1 distillation token 198这个 distillation token 不是图像 patch 切出来的而是一个可学习参数和 class token 一样会参与 Transformer Encoder 中的 self-attention。5.1 Class Token 和 Distillation Token 的区别class token 和 distillation token 很像但目标不同。class token 主要用于真实标签分类。它经过 Transformer Encoder 后接分类头输出class head output。然后和真实标签计算分类损失。可以理解为CLS token 负责学习 ground-truth label。distillation token 主要用于学习 teacher。它经过 Transformer Encoder 后接另一个分类头distillation head output。然后和 teacher 输出计算蒸馏损失。可以理解为DIST token 负责学习 teacher prediction。5.2 二者会不会互相影响会。因为它们都在同一个 Transformer Encoder 中。输入序列是[CLS], [DIST], patch_1, patch_2, ..., patch_196在 self-attention 中每个 token 都可以和其他 token 交互。所以CLS token 可以关注 patch tokenDIST token 可以关注 patch tokenCLS token 和 DIST token 之间也可以互相关注。这就是 DeiT 中 “through attention” 的关键。distillation token 不是一个独立分支而是通过 attention 融入整个 Transformer 表示学习过程。6. DeiT 的模型结构解析DeiT 的整体结构可以写成Input Image↓Patch Embedding↓Add Class Token↓Add Distillation Token↓Add Position Embedding↓Transformer Encoder↓取 CLS token 和 DIST token↓Class Head Distillation Head和 ViT 相比主要变化有三个1. 多了 distillation token2. position embedding 长度增加 13. 多了 distillation head。以 DeiT-B/16 为例patch tokens: 196class token: 1distillation token: 1total tokens: 198embedding dim: 768所以 Transformer Encoder 的输入形状为[B, 198, 768]而不是 ViT-B/16 的[B, 197, 768]最后输出时x_cls x[:, 0]x_dist x[:, 1]其中x_clsclass token 输出x_distdistillation token 输出分别接不同的分类头。7. DeiT 的蒸馏损失函数解析DeiT 的训练目标由两部分组成1. 普通分类损失2. 蒸馏损失可以写成Total Loss (1 - α) × Classification Loss α × Distillation Loss其中α控制普通分类损失和蒸馏损失的权重在官方 DeiT 代码中DistillationLoss 会先计算基础分类损失 base_loss如果启用蒸馏则再用 teacher model 对原始输入进行预测然后根据 distillation_type 计算 soft 或 hard distillation loss最后按 alpha 加权组合。普通分类损失通常是Classification Loss CE(class_head_output, ground_truth_label)蒸馏损失根据蒸馏方式不同可以分为 soft distillation 和 hard distillation。7.1 Soft DistillationSoft distillation 学习 teacher 的概率分布。teacher 会输出每个类别的概率例如cat: 0.82dog: 0.08tiger: 0.04...student 的 distillation head 要尽量接近 teacher 的输出分布。通常使用 KL divergence 进行约束。形式上可以理解为Distillation Loss KL(student_distribution, teacher_distribution)在官方 DeiT 代码中soft distillation 使用 F.kl_div对 student 的 distillation 输出和 teacher 输出都进行 temperature scaling并乘以 T*T 进行尺度修正。也就是说soft distillation 学的是teacher 的完整判断分布。它不仅告诉 student 哪个类别最可能还告诉 student 类别之间的相似关系。7.2 Hard DistillationHard distillation 学习 teacher 的预测类别。teacher 输出 logits 后取最大概率对应的类别teacher_label argmax(teacher_output)。然后 student 的 distillation head 学习这个 teacher label。官方 DeiT 代码中hard distillation 使用F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim1))也就是说hard distillation 不是学习完整概率分布而是学习 teacher 给出的硬标签。可以这样理解Soft distillation学习 teacher 的判断分布Hard distillation学习 teacher 的最终答案7.3 teacher model在 DeiT 中teacher 指的是知识蒸馏中的教师模型。DeiT 本身是 student model训练时不仅学习真实标签还会学习 teacher model 的预测结果。DeiT 原论文中默认使用的 teacher 是 RegNetY-16GF这是一个卷积神经网络模型参数量约为 84M。论文中使用与 DeiT 相同的数据和数据增强方式训练该 teacher其 ImageNet Top-1 accuracy 为 82.9%。DeiT 之所以选择 CNN teacher是因为 CNN 具有更强的图像归纳偏置例如局部连接、层次化特征提取和局部纹理建模能力。ViT 的归纳偏置较弱通过 distillation token 向 CNN teacher 学习可以帮助 Transformer student 获得更好的视觉监督信号。