Post

为什么蒸馏后的扩散模型反而更好量化?

为什么蒸馏后的扩散模型反而更好量化?

最近在量化文生图模型时,我碰到一个很反我直觉的现象:在相近的量化配置下,蒸馏模型比 base model 画质下降更不明显。

我原本觉得蒸馏把原本几十步的采样压缩成几步,看起来像是把同样的事情做得更激进了;如果一步里要完成更多事情,内部表示似乎应该更“sharp”,也就更容易出现 outlier,更难被低精度近似。

这个直觉实际上是把两件事混在了一起:

  • 采样轨迹是不是更短
  • 权重和激活是不是更容易被量化近似

蒸馏做了什么

如果先看普通的 diffusion training,模型做的事情其实很朴素:给它在第 $t$ 个时间步上的带噪样本 $x_t$($x_0$ 是原始图像或原始 latent)和当前的 timestep,让它学会预测噪声 $\epsilon$、$v$,或者别的等价目标。也就是说,base model 学的是单步怎么去噪

最常见的写法里:

\[x_t = \alpha_t x_0 + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

这里的 $\mathcal{N}(0, I)$ 表示标准高斯噪声,$I$ 是单位矩阵。

如果采用这类参数化,通常会有 $\alpha_t^2 + \sigma_t^2 = 1$。

也就是说,$x_t$ 可以看成“原始信号 $x_0$”和“噪声 $\epsilon$”在第 $t$ 步的混合结果。

$v$ 则通常是 $x_0$ 和 $\epsilon$ 的另一种参数化。一种常见写法是:

\[v_t = \alpha_t \epsilon - \sigma_t x_0\]

所以不管模型预测的是噪声 $\epsilon$、原图 $x_0$,还是 $v$,本质上都是在学同一件事的不同表达。

而蒸馏多出来的关键角色是 teacher

训练时,teacher 先按原来的多步采样过程跑出结果,student 再去学这个结果。可以很粗糙地理解成:

  • teacher 负责告诉student,原来走很多小步之后,大概会走到哪里
  • student 负责学习,能不能用更少的步数直接走到一个差不多的位置

所以从流程上看,蒸馏不是把原模型简单“加速”了一下,而是把一部分原本属于推理阶段的多步过程,提前搬到了训练阶段里。

最后训练出来的 student,学的也不再是“在很多 timestep 上都做局部修正”,而是“在少数几个采样点上,直接逼近 teacher 多步迭代后的效果”。

这么理解的话,蒸馏至少做了几件比较明确的事情:

  • 把监督信号从真实构造的 target,部分换成了 teacher 的输出
  • 把原来很长的采样轨迹,压缩成了少数几个更关键的 sampling nodes
  • student 重点优化少步数场景下的结果,而不是覆盖整条长轨迹

这样一来,后面真正需要问的问题就不是:蒸馏之后是不是一定更 sharp。

而是:当模型学的目标从“每一步都做一点局部修正”变成“用少数几步逼近原来很多步的结果”时,它的内部统计会更尖锐,还是反而更规整。

“少步数 = 更 Sharp”?

我最开始是这么想的:

  • base model 用很多步慢慢修
  • 蒸馏模型用很少几步直接到位
  • 那蒸馏模型每一步承担的”变化量”更大
  • 所以它内部的激活和参数应该更极端,也就更难量化

这个直觉不能说完全没道理,但它其实偷换了一个概念。

这里的 “sharp” 可能在说两件完全不同的事:

  • 生成行为上更果断:更少的步数去噪成近似的图像
  • 数值统计上更尖锐:更大的 dynamic range、更多 outlier、更高的量化敏感性

前者经常成立,但前者并不能直接推出后者。

少步数最多只能说明:student 学会了一条更短的去噪路径。它不一定意味着:

  • 权重分布更散
  • activation range 更大
  • 某些通道更依赖极端值
  • 对低精度噪声更敏感

所以蒸馏并不意味着模型会更好量化还是更难量化。

ai 给出的一个合理的解释

1. 蒸馏学到的更像是 teacher 诱导出来的一类稳定映射

蒸馏时,student 学的首先不是直接从数据分布里自己摸索所有去噪路径,而是在少数几个采样点上,尽量逼近 teacher 已经给出来的一类输出映射。

这和从头训练一个 base model 不太一样。base model 需要自己去覆盖整条采样轨迹;而蒸馏模型更多是在少数几个采样点上,尽量逼近 teacher 已经算出来的结果。

如果 teacher 给出的目标本身比较稳定,那么 student 学出来的表示就更有可能更规整一些。这里想表达的不是“它一定会更集中”,而是蒸馏这件事本身会把训练目标变得更受约束。

从部署视角看,这种“更受约束”的训练目标,可能会带来几件量化友好的事情:

  • 参数和激活分布更集中
  • 对少数极端通道的依赖下降
  • 对小数值扰动更不敏感
  • calibration 样本更容易覆盖真实使用区间

2. 蒸馏把长采样链路里的“过渡态”折叠掉了

典型的 base model 会在 20 到 50 步里逐步修正噪声,内部会经过很多中间态。对非量化模型来说,这些中间态没问题;但对量化模型来说,它们往往意味着更复杂的统计变化:

  • 不同步的 activation range 差异更大
  • 某些 step 对高幅值通道特别敏感
  • layer 的分布会随 timestep 明显漂移

蒸馏模型把原本长链路中的一部分中间过程折叠掉之后,未必会让每一步都更“剧烈”,也可能只是减少了很多原本必须经过的中间态。这样一来,不同 step 之间的统计变化也更容易变小。

3. 量化误差在 base model 里更容易累计

base model 的优势在于可以多次微调,但代价是量化误差也会被反复注入。哪怕每一步只多出一点 bias,累积二三十步之后,最后的 latent 偏移也可能已经足够影响画质。

蒸馏模型虽然单步任务更重,但总共只走少数几步,误差注入次数本身就少得多。于是最终结果经常是:

  • 单步误差不一定更小
  • 但总误差更不容易滚大

这里更像是在说整条采样链对量化误差更不敏感,而不是单次网络调用本身就一定更容易量化。

模型类型步数量化误差累计
base model20 到 50 步每步很小的 bias,累计多次
蒸馏模型1 到 4 步只累计少数几次

所以在最终图像质量上,蒸馏模型往往会显得更抗量化。

4. 蒸馏模型通常只需要在更窄的操作域里表现稳定

很多蒸馏模型并不是在“所有可能配置”上都追求最优,而是在一个更明确的部署域里训练出来的,比如:

  • 固定或很少的采样步数
  • 有限的 CFG 区间
  • 特定分辨率或模型版本
  • 更少的 scheduler 变化

操作域越窄,模型越不需要为很多极端 case 预留表达能力。对量化来说,这意味着:

  • 激活支持集更小
  • 极端分布更少出现
  • calibration 数据更容易真正覆盖部署分布

这也是为什么很多蒸馏模型在“目标场景里更稳”,但并不代表它们在更宽广的设置下也同样稳。

结尾

所以,真正需要警惕的不是“蒸馏模型是不是更 sharp”,而是我们是不是把两种不同的 sharpness 混在了一起:

  • 一种是生成轨迹上的“更快、更果断”
  • 一种是数值表示上的“更尖、更难量化”

前者确实常常成立,但它并不自动推出后者。

更准确地说,“蒸馏模型量化后往往更稳”应该被理解成一个部署层面的现象:它不一定说明蒸馏模型本身天然更容易量化,却很可能说明它们在目标采样路径上更规整、更少犯错,也更不容易把量化噪声一步步放大。

This post is licensed under CC BY 4.0 by the author.