单步生成高质量图像:分布匹配蒸馏(DMD)技术解析

分布匹配蒸馏(DMD)技术将多步扩散模型简化为单步生成,显著提升图像生成速度和效率。

原文标题:分布匹配蒸馏:扩散模型的单步生成优化方法研究

原文作者:数据派THU

冷月清谈:

扩散模型因其生成高质量图像的能力而备受关注,但其多步迭代去噪过程计算成本高昂。分布匹配蒸馏(DMD)技术提供了一种解决方案,它将多步扩散过程简化为单步生成,从而显著提高了图像生成速度。

DMD的核心思想是训练一个单步生成器,使其能够直接从噪声图像生成高质量的目标图像。它通过分布匹配损失函数和对抗生成网络损失来实现这一点,迫使生成器学习逼近真实数据分布。

DMD的实现流程主要分为五个阶段:系统初始化、噪声到图像的生成、高斯噪声注入、双重网络处理和损失计算。其中,双重网络处理阶段使用了两个关键网络:一个预训练的“教师”网络(real_unet)和一个学习生成器分布的网络(fake_unet)。这两个网络协同工作,通过对比它们的输出来量化真实分布和生成分布之间的差异,从而指导生成器的训练。

DMD技术巧妙地利用fake_unet来捕捉生成器分布的动态变化。fake_unet的训练目标是将生成器输出的噪声版本还原为当前生成器输出,而非像传统unet那样与真实图像进行比较。这种设计使fake_unet能够更有效地指导生成器的训练。

怜星夜思:

1、DMD 技术中提到的分布匹配损失函数,除了 KL 散度之外,还有其他更合适的损失函数吗?实际应用中如何选择?
2、文中提到的 fake_unet 的作用是什么?如果直接用 real_unet 进行训练会有什么问题?
3、DMD 技术如何应用于实际场景?例如,在图像编辑、超分辨率等任务中,DMD 的优势和局限性是什么?

原文内容

来源:DeepHub IMBA

本文约2000字,建议阅读8分钟

本文深入探讨了分布匹配蒸馏(DMD)的技术原理和实现机制,着重阐述了其在图像生成领域的应用价值。


扩散模型在生成高质量图像领域具有显著优势,但其迭代去噪过程导致计算开销较大。分布匹配蒸馏(Distribution Matching Distillation,DMD)通过将多步扩散过程精简为单步生成器来解决这一问题。该方法结合分布匹配损失函数和对抗生成网络损失,实现从噪声图像到真实图像的高效映射,为快速图像生成应用提供了新的技术路径。

分布匹配机制

与传统扩散模型不同,单步生成器并不直接学习完整的数据分布,而是通过强制对齐的方式逼近目标分布。这种方法摒弃了逐步近似的过程,直接建立噪声样本到目标分布的映射关系。
在此过程中,蒸馏机制起到关键作用。预训练模型作为教师网络,提供目标分布的高精度中间表征。

DMD 技术实现流程

阶段 0:系统初始化
  1. 单步生成器基于预训练扩散 unet 进行初始化,时间步设定为 T-1;

  2. real_unet 作为固定权重的教师网络,表征真实数据分布;

  3. fake_unet 用于对生成器的数据分布进行建模。
阶段 1:噪声到图像的生成
生成器接收随机噪声图作为输入,通过单步去噪操作生成图像 x,此时生成的图像 x 符合生成器的概率密度分布 p_fake
阶段 2:高斯噪声注入
对生成图像 x 施加高斯噪声,获得噪声图像 xt,在 0.2T 到 0.98T 范围内均匀采样时间步 t(避开极端噪声状态),噪声注入操作促进 p_fake 与 p_real 分布的重叠,为后续分布比较创造条件
阶段 3:双重网络处理
  1. real_unet 生成 pred_real_image,作为清晰图像的参考近似;

  2. fake_unet 生成 pred_fake_image,反映当前时间步的生成器分布特征。
通过对比 pred_real_image 和 pred_fake_image 量化真实分布与生成分布的差异。
阶段 4:损失计算
计算 x 与 x — grad 之间的均方误差(MSE)作为损失度量。其中 x — grad 表示经过梯度校正的输出,用于减小与真实数据分布的偏差。
阶段 5:假分布更新机制
fake_unet 通过 x 和 pred_fake_image 之间的扩散损失进行参数更新。这一过程使 fake unet 能够追踪生成器分布的动态变化。与传统 unet 使用 xt-1_pred 和 xt-1_gt 计算损失不同,这里采用 xt-1_pred 和 x 之间的损失,使 fake UNet 能够将生成器输出的噪声版本(xt)还原为当前生成器输出 x。

核心问题解析

问题 1: 为何 fake_unet 采用 xt-1_pred 和 x0 之间的散度作为损失度量,而非采用 xt-1_pred 和 xt-1_gt 的比较?
选择 xt-1_pred 和 x 之间的散度是基于 fake_unet 的核心功能考虑。其目标是将生成器输出的噪声版本(xt)映射回生成器的当前输出(x)。这种设计确保了 fake_unet 能够准确捕获生成器的动态分布特征,从而提供有效的梯度信息来优化生成器输出。
问题 2:fake_unet 的必要性何在?是否可以直接利用预训练的 real_unet 输出与生成器输出计算 KL 散度?
生成器的设计目标是实现单步完全去噪,而预训练的 real_unet 在相同时间步内仅能实现部分去噪。这种本质差异导致 real_unet 输出无法提供有效的 KL 散度用于生成器训练。相比之下,fake_unet 通过持续学习生成器的动态分布,能够准确approximation当前生成器输出的特征。通过比较 real_unet 和 fake_unet 的输出,可以获得用于优化生成器概率分布的有效梯度方向,从而提升单步图像合成的质量。# 分布匹配损失机制
训练过程中,通过 KL 散度定量评估生成器分布与真实分布之间的差异。
图片
其中 Preal 代表真实数据的概率密度函数,Pfake 表示生成器 Gθ 产生的假分布概率密度函数。
对于高维数据集,直接计算概率密度在计算复杂度上存在显著挑战。例如,对于 32×32 像素的灰度图像,其维度空间为 256¹⁰²⁴,直接计算在实际应用中不可行。
因此,采用分数函数对真实分布和生成分布进行特征表征。
图片
这种方法使得 KL 散度的计算成为可能:Sreal 引导 x 向 Preal 的模态靠近,而 −Sfake 则促使其远离真实分布。
图片
其中 Sreal(x) 为真实数据分布的分数函数,Sfake(x) 为生成数据分布的分数函数,∇θ Gθ(z) 表示生成器输出 x 对参数的梯度。
Sreal(x)−Sfake(x) 表征了真实分数与生成分数的差异。对于生成样本 x,由于其 Sreal 接近零,需要引入扰动以支持扩散模型从 xt 进行去噪。
Sfake 和 Sreal 的定义参考自论文 "Song et al. — Score-based generative modeling through stochastic differential equations"
最终损失函数

技术原理剖析

在时间步 t−1,利用 real_unet 和 fake_unet 的输出构建梯度,引导生成器的当前输出 x 向 real_unet 在 t=0 时刻的输出收敛。随后计算生成器原始输出与梯度校正后输出的均方误差(MSE)。这一校正机制确保 x 能够逐步对齐真实数据分布。
图片
损失函数的代码实现:
该图展示了不同时间步的损失函数变化,详细说明了多步生成器对单步生成器的训练过程。注意: 图中未详细展示 weighting_factor 相关细节,并对底层分布作出了特定假设。
核心思想在于利用 xfake 和 xreal 之间的差异产生的梯度,将生成器输出引导至 real_unet 在 t=0 时刻的目标输出。随着训练进行,生成器输出逐步向真实分布靠近,同时带动 fake_unet 输出的优化。最终,校正后的图像 ∥x−grad∥ 收敛至真实分布。

总结

本文深入探讨了分布匹配蒸馏(DMD)的技术原理和实现机制,着重阐述了其在图像生成领域的应用价值。欢迎学术界同仁就相关技术细节提供建议和讨论,以促进该领域的持续发展。
作者:Om Rastogi
编辑:黄继彦


关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。



新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

DMD 可以用来做一些轻量级的图像编辑任务,比如快速生成一些表情包或者简单的插画。但如果要做一些精细的修改,比如P图,DMD 就有点力不从心了。

fake_unet 是动态更新的,它追踪的是生成器的学习过程。而 real_unet 是静态的,它代表的是最终的学习目标。如果只用 real_unet,就好比只给学生看最终的答案,却不告诉他解题步骤,学生很难学会。

DMD 在图像编辑中可以用于快速生成修改后的图像,例如风格迁移、图像修复等。它的优势在于速度快,但局限性在于对细节的控制能力可能不如传统的迭代方法。在超分辨率任务中,DMD 可以用于快速生成高清图像,但对于纹理细节丰富的图像,其生成质量可能不如基于GAN的方法。

在实际应用中,DMD 的优势在于快速生成,可以用于一些对实时性要求较高的场景,例如游戏、VR/AR 等。但它的局限性在于生成图像的质量和多样性可能不如 GAN,尤其是在处理复杂场景时。

我觉得可以试试交叉熵损失函数,它在分类任务里表现很好,说不定在分布匹配上也能有奇效。当然,具体效果还得看实验结果。

fake_unet 的作用在于模拟生成器的分布,而 real_unet 则代表真实数据分布。直接用 real_unet 训练会导致生成器过度拟合 real_unet 的输出,而不是学习真实的数据分布,最终限制生成器的泛化能力。

除了KL散度和楼上提到的Wasserstein距离和MMD,还可以考虑JS散度。不过JS散度有个缺点,当两个分布完全没有重叠时,梯度会消失,所以使用的时候要注意。

可以理解为 real_unet 是老师,fake_unet 是学生模仿老师的笔记(生成器),最终目的是让学生自己学会写笔记(生成图像)。如果学生直接抄老师的答案,考试的时候就不会了。