知识蒸馏技术详解:如何利用软标签压缩模型

知识蒸馏技术利用软标签和KL散度,将大型模型的知识迁移到小型模型,实现高效的模型压缩。

原文标题:知识蒸馏技术原理详解:从软标签到模型压缩的实现机制

原文作者:数据派THU

冷月清谈:

知识蒸馏技术是一种模型压缩方法,核心是将大型教师模型的知识迁移到小型学生模型。它利用软标签,即经过温度参数调整的概率分布,而非传统的独热编码标签,来训练学生模型。

软标签保留了类别间的关系信息,例如,识别马的图像时,教师模型不仅会给“马”高概率,也会给“鹿”和“牛”分配一定概率,反映了它们之间的相似性。

学生模型的训练使用复合损失函数,包含硬标签损失(学生模型预测与真实标签的交叉熵)和软标签损失(学生模型预测与教师模型软标签的KL散度)。KL散度乘以温度参数的平方,以避免梯度消失。

通过这种方式,学生模型可以学习到教师模型对不同类别之间关系的理解,从而在更小的模型规模下获得更好的性能。

怜星夜思:

1、除了文中提到的动物识别,知识蒸馏技术还能应用于哪些其他领域?
2、文章提到了温度参数,这个参数如何选择才能获得最佳的蒸馏效果?
3、除了KL散度,还有哪些其他指标可以用来衡量教师模型和学生模型之间的差异?

原文内容

来源:DeepHub imba

本文约1700字,建议阅读8分钟

这种知识迁移机制使得构建小型高效模型成为可能,为模型压缩技术提供了新的解决方案。‍‍‍


知识蒸馏是一种通过性能与模型规模的权衡来实现模型压缩的技术。其核心思想是将较大规模模型(称为教师模型)中的知识迁移到规模较小的模型(称为学生模型)中。本文将深入探讨知识迁移的具体实现机制。

知识蒸馏原理

知识蒸馏的核心目标是实现从教师模型到学生模型的知识迁移。在实际应用中,无论是大规模语言模型(LLMs)还是其他类型的神经网络模型,都会通过softmax函数输出概率分布。

Softmax输出示例分析

考虑一个输出三类别概率的神经网络模型。假设教师模型输出以下logits值:

教师模型logits: [1.1, 0.2, 0.2]

经过softmax函数转换后得到:

Softmax概率分布: [0.552, 0.224, 0.224]

此时,类别0获得最高概率,成为模型的预测输出。模型同时为类别1和类别2分配了较低的概率值。这种概率分布表明,尽管输入数据最可能属于类别0,但其特征表现出了与类别1和类别2的部分相关性。

低概率信息的利用价值

在传统分类任务中,由于最高概率(0.552)显著高于其他概率值(均为0.224),次高概率通常会被忽略。而知识蒸馏技术的创新之处在于充分利用这些次要概率信息来指导学生模型的训练过程。

分类任务实例分析:

以动物识别任务为例,当教师模型处理一张马的图像时,除了对"马"类别赋予最高概率外,还会为"鹿"和"牛"类别分配一定概率。这种概率分配反映了物种间的特征相似性,如四肢结构和尾部特征。虽然马的体型大小和头部轮廓等特征最终导致"马"类别获得最高概率,但模型捕获到的类别间相似性信息同样具有重要价值。

分析另一组教师模型输出的logits值:

教师模型logits:[2.9, 0.1, 0.23]

应用softmax函数后得到:

Softmax概率分布: [0.885, 0.054, 0.061]

在这个例子中,类别0以0.885的高概率占据主导地位,但其他类别仍保留了有效信息。为了更好地利用这些细粒度信息,我们引入温度参数T=3对分布进行软化处理。软化后的logits值为:

软化后logits: [0.967, 0.033, 0.077]

再次应用softmax函数:

温度调节后的概率分布: [0.554, 0.218, 0.228]

经过软化处理的概率分布在保留主导类别信息的同时,适当提升了其他类别的概率权重。这种被称为软标签的概率分布,相比传统的独热编码标签(如[1, 0, 0]),包含了更丰富的类别间关系信息。

学生模型训练机制

在传统的模型训练中,仅使用独热编码标签(如[1, 0, 0])会导致模型仅关注正确类别的预测。这种训练方式通常采用交叉熵损失函数。而知识蒸馏技术通过引入教师模型的软标签信息,为学生模型提供了更丰富的学习目标。

复合损失函数设计

学生模型的训练目标由两个损失分量构成:

  1. 硬标签损失: 学生模型预测值与真实标签之间的标准交叉熵损失。
  2. 软标签损失: 基于教师模型软标签计算的知识迁移损失。

这种复合损失函数可以用数学形式表示为:

Image

KL散度计算方法

为了度量教师模型软标签与学生模型预测之间的差异,采用Kullback-Leibler (KL) 散度作为度量标准:

Image

其中:

  • pi表示教师模型的软标签概率。
  • qi表示学生模型的预测概率。

数值计算示例

以下示例展示了教师模型和学生模型预测之间的KL散度计算过程:

教师模型软标签: [0.554,0.218,0.228]

学生模型预测值: [0.26,0.32,0.42]

Image

各项计算过程:

求和结果:

Image

最终损失计算方法

为了补偿温度参数带来的影响,需要将KL散度乘以温度参数的平方(T²):

Image

这种补偿机制确保了KL散度不会因温度参数的引入而过度衰减,从而避免反向传播过程中出现梯度消失问题。通过综合考虑硬标签损失和经过温度调节的KL散度,学生模型能够有效利用教师模型提供的知识,实现更高效的参数学习。

总结

与仅使用独热编码标签(如[1, 0, 0])的传统训练方法相比,知识蒸馏技术通过引入教师模型的软标签信息,显著降低了学生模型的学习难度。这种知识迁移机制使得构建小型高效模型成为可能,为模型压缩技术提供了新的解决方案。

作者:Hoyath


编辑:黄继彦‍‍‍‍‍

‍‍‍



关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。



新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

还可以考虑用MMD(Maximum Mean Discrepancy),它可以用来比较两个分布的差异,不需要计算概率密度。

找到最佳温度参数,可以尝试网格搜索或者贝叶斯优化等方法,通过比较不同参数下的学生模型性能,来选择最优值。

如果关注的是模型输出的预测结果,也可以直接用一些常用的评价指标,比如准确率、F1值等等,来看学生模型学到了多少教师模型的知识。

知识蒸馏应该也能用在医学影像分析上吧?比如训练一个轻量级的模型,用于快速诊断疾病,这样可以减轻医生的负担,也能让更多人受益。

我记得看过一些论文,说温度参数越高,软标签的分布就越平滑,低概率的信息就越突出。但温度太高,又会损失一些重要的信息,所以需要找到一个平衡点。

这个技术感觉可以应用于任何需要模型压缩的领域,比如自动驾驶,边缘计算等等,只要有大模型和小模型的需求,就可以用它来迁移知识。

我觉得在自然语言处理领域,知识蒸馏也蛮有用的,比如可以用来训练更小的对话模型,让它继承大型模型的语言理解能力,这样在资源有限的设备上也能流畅运行。

温度参数的选择感觉像个玄学,文章里也没细说,估计得根据具体任务和数据集进行调整,可能需要一些实验和经验。

可以用JS散度(Jensen-Shannon divergence)试试,它跟KL散度类似,但更对称一些。