李飞飞团队提出Grafting技术:无需重新训练,高效编辑Diffusion Transformer架构

李飞飞团队提出Grafting技术,通过编辑预训练DiT模型架构,无需重新训练即可实现模型加速和性能提升。

原文标题:李飞飞团队新作:DiT不训练直接改架构,模型深度减半,质量还提高了

原文作者:机器之心

冷月清谈:

本文介绍了一种名为Grafting(嫁接)的技术,用于在计算资源有限的情况下,通过编辑预训练的Diffusion Transformers(DiT)来探索新的模型架构设计。该方法允许研究者在不从头开始训练模型的情况下,通过替换模型中的算子(如MLP)来创建混合架构,从而在保持模型质量的同时减少计算量。Grafting包含激活蒸馏和轻量级调优两个阶段,前者用于将原始算子的功能迁移至新算子,后者用于减轻集成多个新算子导致的误差传播。研究者通过嫁接技术,用门控卷积、局部注意力和线性注意力取代Softmax注意力,用可变扩展率和卷积变体取代MLP,在保持模型质量的同时,实现了模型深度减半和计算加速。实验结果表明,Grafting技术能够有效地探索新的扩散模型设计,并在图像生成任务中取得良好的效果。

怜星夜思:

1、文章中提到Grafting技术可以通过替换DiT模型中的算子来探索新的架构,那么在实际应用中,如何选择合适的算子进行替换?有没有一些通用的原则或者方法?
2、文章中提到嫁接技术可以将模型深度减少一半,这对实际应用有什么意义?除了减少计算量,还可能带来哪些好处或坏处?
3、文章提到Grafting技术在文本转图像模型中实现了加速,但在一些纹理区域观察到了局部性的失真,这说明了什么问题?有什么方法可以解决这个问题?

原文内容

机器之心报道

编辑:欣东、陈陈

本文介绍了一种名为「嫁接」的技术,用于在小计算预算下通过编辑预训练 Diffusion Transformers(简称 DiTs)来探索新的模型架构设计。这种方法允许研究者在不从头开始训练模型的情况下,通过替换模型中的某些算子(如 MLP)来创建新的混合架构,从而在保持模型质量的同时减少计算量。


模型架构设计在机器学习中扮演着核心角色,与数据、算法、算力和基准测试一样重要。它定义了模型函数、算子选择(如注意力机制、卷积)和配置设定(如模型深度、宽度)等等模型要素。


尽管如此,由于从头训练模型的成本过高 —— 尤其人们难以获得关于架构设计的深刻洞见(即哪些方案有效、哪些无效)。因此,研究新架构仍是一项挑战,对生成模型而言尤为如此。


在本文中,来自斯坦福大学、 Liquid AI 等机构的研究者探索了这一问题,即对预训练模型进行架构编辑来研究新架构。



  • 论文链接:https://arxiv.org/pdf/2506.05340v1

  • 论文主页:https://grafting.stanford.edu/

  • 论文标题: Exploring Diffusion Transformer Designs via Grafting 


具体而言,该研究提出了一种编辑预训练扩散 transformer(DiT)的简单方法,即 Grafting(嫁接),该方法可以在较小的计算预算下实现新的架构。


嫁接过程如下:


(i)激活蒸馏:此阶段通过回归目标(regression objective)蒸馏原始算子的激活特征,将其功能迁移至新算子。该阶段核心在于实现算子间的功能传递。

(ii)轻量级调优:此阶段通过使用有限的数据进行调优,减轻了由于集成多个新算子而导致的误差传播。


此外,架构编辑还涵盖多种策略,如添加、删除和替换算子。



本文还基于 DiT-XL/2 构建了一个测试平台,以研究嫁接对模型质量的影响。


利用该测试平台,本文通过嫁接技术开发了一系列混合设计:用门控卷积、局部注意力和线性注意力取代 Softmax 注意力,用可变扩展率和卷积变体取代 MLP。


值得注意的是,许多混合设计使用不到 2% 的预训练计算资源就实现了良好的质量(FID:2.38–2.64,而 DiT-XL/2 为 2.27)。然后,本文嫁接了一个文本转图像模型 (PixArt-Σ),实现了 1.43 倍的加速,而 GenEval 分数下降不到 2%。


最后,本文展示了一个案例研究,该研究通过嫁接技术将每对序列 Transformer 模块转换为并行模块,从而重构了 DiT-XL/2。这将模型深度减少到原来一半,并获得了比其他同等深度模型更高的质量(FID:2.77)。


总而言之,该研究展示了可以通过预训练 DiT 来探索新的扩散模型设计,其修改范围涵盖从算子替换到架构重构。


嫁接扩散 Transformer 


两阶段嫁接方法


嫁接旨在通过编辑预训练模型的计算图来实现新架构。由于该研究专注于用替代方案替换现有算子,这引出了两个问题:


问题 1:在将新算子集成到计算图之前,应该如何初始化?


对应第一阶段:通过激活蒸馏进行初始化。由于 DiT 的激活是连续且平滑的,这可以被视为一个回归问题:


image.png


问题 2:当多个算子集成到计算图时,如何减轻误差传播?


对应第二阶段:轻量级调优。随着更多算子被替换,初始化误差会不断传播,导致与预训练模型的行为出现偏差。


本文采用端到端微调来缓解阶段 1 的累积误差。微调目标函数如公式 1 所示。


实践中,本文发现,即使替换 DiT-XL/2 中的所有 MHA 或 MLP 层,仅使用 10% 的训练数据也能恢复竞争性能。


image.png


自嫁接基准


在研究新的架构设计之前,该研究引入了自嫁接(self-grafting),这是一种简单的对照设置:将现有算子(如 MHA、MLP)替换为相同类型但权重随机初始化的算子。这样可以保持计算图的结构 —— 包括算子类型和参数数量 —— 但改变了具体的计算过程。自嫁接有三方面作用:(1)评估在不改变架构的情况下嫁接流程本身的效果;(2)为比较不同的替换方案提供一个性能基准;(3)研究影响性能的因素,如数据规模、回归目标和超参数。


激活行为分析以及自嫁接结果


本文首先分析了 DiT-XL/2 层中的 MHA 和 MLP 算子激活行为。在这两种情况下,本文观察到激活值存在较大差异,尤其是在较深的层中(表 1 (i, ii))。



经过分析,本文得出通过选择特定于算子的回归目标,可以实现高质量的初始化。


如表 1 (iii,iv) 所示,回归目标的选择会影响性能。对于 MHA,L1 实现了最佳 FID(2.51),其次是 Huber(2.55)和 L2(2.58)。对于 MLP,L2 表现最佳(2.33),而 L1 表现不佳(2.83);值得注意的是,MLP 的参数量是 MHA 的 2 倍。


这表明高质量的初始化需要量身定制的、激活感知的策略。


研究还发现,使用 10% 的数据进行完全自嫁接可实现接近基线的性能。表明在适度的数据和计算预算下完全自嫁接是可行的。



实验


实验 I:通过嫁接实现混合架构


本节实验围绕这个问题进行:当现有算子被高效的替代方案取代时,我们能否保持模型质量?


为了探究这个问题,本文研究了以下嫁接过程:


1. 待替换算子的类型 ——MHA 或 MLP;

2. 替换算子的类型 —— 例如卷积;

3. 层选择策略 —— 替换所有层中的算子或使用启发式选择;

4. 替换率 —— 全部替换或部分替换。


为了实验,该研究构建了一个测试平台,并提出两种层选择策略:完全替换和交错替换。测试平台详见表 3。



此外,该研究还引入了 Hyena-X 和 Hyena-Y 两种新的高效门控卷积算子,并设计为 MHA 的直接替代品。Figure 3 展示了它们的结构。



MHA 结果。通过嫁接替换 DiT-XL/2 中的 MHA 算子,获得了良好的质量 - 效率权衡。主要发现如下:


在交错嫁接下,较小的感受野表现出惊人的效果。实验发现,在 50% 交错替换比例下,滑动窗口注意力(SWA)、Hyena-X/Y 和 Mamba-2 等替代方案均能保持 FID 分数与基线(2.27)差距在 0.5 以内。尤其值得注意的是,尽管 SWA 和 Hyena 变体的感受野有限(卷积核 K=4 / 窗口 w=4),其 FID 下降幅度却极小。


替换策略:交错替换 vs. 完全替换。将交错替换比例从 50% 提升至 75% 时,性能通常下降,但 SWA 在 75% 交错替换下仍有效(FID=3.09)。100% 替换时,性能急剧恶化(所有 FID > 75),这与局部性分析一致,表明只有部分层是局部且适合嫁接的。


数据规模和层选择的消融实验结果。



MLP 结果显示通过嫁接的方式替换 MLP 算子是有效的。


经过实验,得出要点 1:嫁接对于在较小的计算预算下构建具有良好生成质量的高效混合架构非常有效。交错设计尤其有效。


实验 II:通过嫁接改进文本到图像的扩散 Transformers 


结果。嫁接模型在实时计算速度(wall-clock time)上实现了 1.43 倍的提升,同时生成评估分数(GenEval)仅出现小幅下降(47.78 vs. 49.75)。特定属性的指标(Attribute-specific metrics)基本保持可比,并且定性样本也展现出良好的对齐度和质量。在一些纹理区域观察到了局部性的失真(artifacts),这可能是由于 LoRA 的适应能力以及所使用的合成数据质量不高所致(失败案例详见图 D.3,D.4)




要点 2:在文生图 DiTs 中成功应用嫁接技术,构建的混合架构在实现显著加速的同时,生成质量损失极小。


了解更多内容,请参考原论文。


© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:[email protected]

我觉得还可以从正则化的角度考虑。更浅的模型天然具有更强的正则化效果,可以防止过拟合,提高模型的泛化能力。当然,前提是模型本身具有足够的表达能力。

算子选择确实是个 tricky 的问题。我觉得可以从两个角度考虑:一是算子的计算效率,比如用卷积替代注意力,理论上能降低计算复杂度;二是算子的特性,比如有些算子可能更擅长捕捉局部信息,而有些则更擅长全局信息,要结合任务需求选择。

我理解作者的思路是实验驱动的。他们先用自嫁接(self-grafting)做基准测试,看看不同回归目标对性能的影响,然后才尝试不同的算子替换方案。所以我觉得可以先小范围实验,观察不同算子组合的效果,再逐步扩大范围。

减少模型深度最直接的意义当然是降低计算成本,部署到边缘设备或者移动端会更方便。而且,更浅的模型可能更容易训练,收敛速度更快,调参也更容易。

也有可能是训练数据的问题。文章提到使用了合成数据,如果合成数据在纹理细节上不够逼真,就可能导致模型在真实图像上出现失真。可以考虑增加真实数据的比例,或者使用更先进的合成技术。

局部性失真可能说明新引入的算子在处理纹理细节方面存在不足。或者,嫁接过程中对纹理信息的迁移不够充分。可以尝试调整激活蒸馏的参数,或者引入专门处理纹理的算子。

我觉得可以尝试在损失函数中加入纹理相关的惩罚项,引导模型更加关注纹理细节。例如,可以使用基于图像梯度的损失函数,或者引入预训练的纹理特征提取器。