为什么蒸馏后的扩散模型反而更好量化?
最近在量化文生图模型时,我碰到一个很反我直觉的现象:在相近的量化配置下,蒸馏模型比 base model 画质下降更不明显。
我原本觉得蒸馏把原本几十步的采样压缩成几步,看起来像是把同样的事情做得更激进了;如果一步里要完成更多事情,内部表示似乎应该更“sharp”,也就更容易出现 outlier,更难被低精度近似。
这个直觉实际上是把两件事混在了一起:
- 采样轨迹是不是更短
- 权重和激活是不是更容易被量化近似
蒸馏做了什么
如果先看普通的 diffusion training,模型做的事情其实很朴素:给它在第 $t$ 个时间步上的带噪样本 $x_t$($x_0$ 是原始图像或原始 latent)和当前的 timestep,让它学会预测噪声 $\epsilon$、$v$,或者别的等价目标。也就是说,base model 学的是单步怎么去噪。
最常见的写法里:
\[x_t = \alpha_t x_0 + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]这里的 $\mathcal{N}(0, I)$ 表示标准高斯噪声,$I$ 是单位矩阵。
如果采用这类参数化,通常会有 $\alpha_t^2 + \sigma_t^2 = 1$。
也就是说,$x_t$ 可以看成“原始信号 $x_0$”和“噪声 $\epsilon$”在第 $t$ 步的混合结果。
$v$ 则通常是 $x_0$ 和 $\epsilon$ 的另一种参数化。一种常见写法是:
\[v_t = \alpha_t \epsilon - \sigma_t x_0\]所以不管模型预测的是噪声 $\epsilon$、原图 $x_0$,还是 $v$,本质上都是在学同一件事的不同表达。
而蒸馏多出来的关键角色是 teacher。
训练时,teacher 先按原来的多步采样过程跑出结果,student 再去学这个结果。可以很粗糙地理解成:
teacher负责告诉student,原来走很多小步之后,大概会走到哪里student负责学习,能不能用更少的步数直接走到一个差不多的位置
所以从流程上看,蒸馏不是把原模型简单“加速”了一下,而是把一部分原本属于推理阶段的多步过程,提前搬到了训练阶段里。
最后训练出来的 student,学的也不再是“在很多 timestep 上都做局部修正”,而是“在少数几个采样点上,直接逼近 teacher 多步迭代后的效果”。
这么理解的话,蒸馏至少做了几件比较明确的事情:
- 把监督信号从真实构造的 target,部分换成了
teacher的输出 - 把原来很长的采样轨迹,压缩成了少数几个更关键的 sampling nodes
- 让
student重点优化少步数场景下的结果,而不是覆盖整条长轨迹
这样一来,后面真正需要问的问题就不是:蒸馏之后是不是一定更 sharp。
而是:当模型学的目标从“每一步都做一点局部修正”变成“用少数几步逼近原来很多步的结果”时,它的内部统计会更尖锐,还是反而更规整。
“少步数 = 更 Sharp”?
我最开始是这么想的:
base model用很多步慢慢修- 蒸馏模型用很少几步直接到位
- 那蒸馏模型每一步承担的”变化量”更大
- 所以它内部的激活和参数应该更极端,也就更难量化
这个直觉不能说完全没道理,但它其实偷换了一个概念。
这里的 “sharp” 可能在说两件完全不同的事:
- 生成行为上更果断:更少的步数去噪成近似的图像
- 数值统计上更尖锐:更大的 dynamic range、更多 outlier、更高的量化敏感性
前者经常成立,但前者并不能直接推出后者。
少步数最多只能说明:student 学会了一条更短的去噪路径。它不一定意味着:
- 权重分布更散
- activation range 更大
- 某些通道更依赖极端值
- 对低精度噪声更敏感
所以蒸馏并不意味着模型会更好量化还是更难量化。
ai 给出的一个合理的解释
1. 蒸馏学到的更像是 teacher 诱导出来的一类稳定映射
蒸馏时,student 学的首先不是直接从数据分布里自己摸索所有去噪路径,而是在少数几个采样点上,尽量逼近 teacher 已经给出来的一类输出映射。
这和从头训练一个 base model 不太一样。base model 需要自己去覆盖整条采样轨迹;而蒸馏模型更多是在少数几个采样点上,尽量逼近 teacher 已经算出来的结果。
如果 teacher 给出的目标本身比较稳定,那么 student 学出来的表示就更有可能更规整一些。这里想表达的不是“它一定会更集中”,而是蒸馏这件事本身会把训练目标变得更受约束。
从部署视角看,这种“更受约束”的训练目标,可能会带来几件量化友好的事情:
- 参数和激活分布更集中
- 对少数极端通道的依赖下降
- 对小数值扰动更不敏感
- calibration 样本更容易覆盖真实使用区间
2. 蒸馏把长采样链路里的“过渡态”折叠掉了
典型的 base model 会在 20 到 50 步里逐步修正噪声,内部会经过很多中间态。对非量化模型来说,这些中间态没问题;但对量化模型来说,它们往往意味着更复杂的统计变化:
- 不同步的 activation range 差异更大
- 某些 step 对高幅值通道特别敏感
- layer 的分布会随
timestep明显漂移
蒸馏模型把原本长链路中的一部分中间过程折叠掉之后,未必会让每一步都更“剧烈”,也可能只是减少了很多原本必须经过的中间态。这样一来,不同 step 之间的统计变化也更容易变小。
3. 量化误差在 base model 里更容易累计
base model 的优势在于可以多次微调,但代价是量化误差也会被反复注入。哪怕每一步只多出一点 bias,累积二三十步之后,最后的 latent 偏移也可能已经足够影响画质。
蒸馏模型虽然单步任务更重,但总共只走少数几步,误差注入次数本身就少得多。于是最终结果经常是:
- 单步误差不一定更小
- 但总误差更不容易滚大
这里更像是在说整条采样链对量化误差更不敏感,而不是单次网络调用本身就一定更容易量化。
| 模型类型 | 步数 | 量化误差累计 |
|---|---|---|
base model | 20 到 50 步 | 每步很小的 bias,累计多次 |
| 蒸馏模型 | 1 到 4 步 | 只累计少数几次 |
所以在最终图像质量上,蒸馏模型往往会显得更抗量化。
4. 蒸馏模型通常只需要在更窄的操作域里表现稳定
很多蒸馏模型并不是在“所有可能配置”上都追求最优,而是在一个更明确的部署域里训练出来的,比如:
- 固定或很少的采样步数
- 有限的
CFG区间 - 特定分辨率或模型版本
- 更少的 scheduler 变化
操作域越窄,模型越不需要为很多极端 case 预留表达能力。对量化来说,这意味着:
- 激活支持集更小
- 极端分布更少出现
- calibration 数据更容易真正覆盖部署分布
这也是为什么很多蒸馏模型在“目标场景里更稳”,但并不代表它们在更宽广的设置下也同样稳。
结尾
所以,真正需要警惕的不是“蒸馏模型是不是更 sharp”,而是我们是不是把两种不同的 sharpness 混在了一起:
- 一种是生成轨迹上的“更快、更果断”
- 一种是数值表示上的“更尖、更难量化”
前者确实常常成立,但它并不自动推出后者。
更准确地说,“蒸馏模型量化后往往更稳”应该被理解成一个部署层面的现象:它不一定说明蒸馏模型本身天然更容易量化,却很可能说明它们在目标采样路径上更规整、更少犯错,也更不容易把量化噪声一步步放大。