知识蒸馏

知识蒸馏

Hinton 2015年正式在论文Distilling the knowledge in a neural network中提出了知识蒸馏的概念,这算是知识蒸馏的开山之作,也是绝对的经典之作。将已经训练好的模型包括的知识,蒸馏提取到另一个模型中,目的是将一个大模型或者多个模型集成学到的知识迁移到另一个轻量化模型。简而言之,就是模型压缩的一种方法,一种基于“教师-学生网络思想”的训练方法。

理论依据

名词解释

  • Teacher:大而笨重的模型
  • Student:小而紧凑的模型
  • transfer set:用于小模型训练的数据,也是获得Teacher模型soft target输出的输入数据集
  • hard target:样本原始标签
  • soft target:Teacher模型输出的预测结果
  • temperature:softmax函数中的超参数
  • knowledge:可以理解为从输入向量到输出向量学习到的映射

符号定义

  • $z$:Logits,模型去除输出层的输出。对于一般的分类问题,比如图片分类,输入一张图片后,经过深度神经网络各种非线性变换,在网络最后的Softmax层之前,会得到这张图片属于各个类别的大小数值$z_i$,某个类别的$z_i$数值越大,模型认为输入图片属于这个类别的可能性就越大。那么什么是Logits?这些汇总了网络内部各种信息后,得出的属于各个类别的汇总分值$z_i$就是Logits,i代表第i 个类别,$z_i$代表属于第i类的可能性。因为Logits不是概率值,所以一般在Logits数值上会用Softmax函数进行变换,得出的概率值作为最终分类结果概率。Softmax一方面把Logits数值在各类别之间进行概率归一,使得各个类别归属数值满足概率分布;另外一方面,它会放大Logits数值之间的差异,使得Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。
  • $p$ :probality,每个类的概率

Teacher-Student Model

将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力.

需要注意的是,这里蒸馏的目的是小网络的概率分布趋近于大网络,而非单纯的正确率趋近于大网络

下图为知识蒸馏的基本架构,主要由知识、蒸馏算法和teacher-student三个部分组成

知识

分为response-based knowledge, feature-based knowledge and relation-based knowledge三类

response-based knowledge

Response-Based Knowledge一般是指teacher模型最后一层的响应,即logits。主要思想就是让student直接模仿teacher最后的预测。

feature-based knowledge

所谓的基于特征的知识其实就是中间层的输出

relation-based knowledge

Relation-Based Knowledge是指不同层或样本之间的关系,这里的对象我们可以自由构造,比如两个层之间的关系、三个层之间的关系、两个样本之间的关系、三个样本之间的关系等等。具体的,比如我们面对的是一个文本相关性任务,可以让student去学习teacher输出的不同样本对的相关性,即teacher认为两个文本的相关性是多少,则student就去学习尽量做到同样的相关性。

Distillation Schemes

知识蒸馏一般可以分为离线蒸馏、在线蒸馏和自蒸馏三种架构

Offline Distillation

工业界用的比较多的是离线蒸馏,学生向预先训练好的老师学习,简单易于实现。离线蒸馏的主要问题是大的teacher和小的student之间存在着model capacity gap,可能小的student就没有办法学得特别好,因为可能能力确实有限。

Online Distillation

师生同时更新,整个框架端到端可训练(如相互学习)

Self-Distillation

指不通过新增一个大模型的方式找到一个teacher模型。因能学到增益信息而收益

Teacher-Student Architecture

Distillation Algorithms

  • Adversarial Distillation 对抗蒸馏
  • Multi-Teacher Distillation 多教师蒸馏
  • Cross-Modal Distillation 跨模态蒸馏
  • Graph-Based Distillation 基于图的蒸馏
  • Attention-Based Distillation 基于注意力的蒸馏
  • Data-Free Distillation 无数据蒸馏
  • Quantized Distillation 量化蒸馏
  • Lifelong Distillation 终身蒸馏
  • NAS-Based Distillation 基于NAS的蒸馏
打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2021-2024 John Doe
  • 访问人数: | 浏览次数:

让我给大家分享喜悦吧!

微信