您现在的位置是:首页 >技术杂谈 >扩散模型加速采样算法DDIM网站首页技术杂谈

扩散模型加速采样算法DDIM

next_travel 2025-03-24 00:01:03
简介扩散模型加速采样算法DDIM

摘要

扩散模型(Diffusion Models, DM)近年来成为生成模型研究的热点,其中去噪扩散概率模型(DDPM)利用马尔科夫链逐步去噪以生成高质量样本。然而,DDIM(去噪扩散隐式模型)通过非马尔科夫链方法优化了采样过程,提高了生成效率。本周学习了DDIM的非马尔科夫链前向扩散过程,并分析了其后验分布与DDPM的对比,展示了DDIM如何通过引入超参数 σ t sigma_t σt实现更灵活更一般的采样方式。此外,讨论了加速采样技巧respacing,它通过重新分配时间步长来减少采样步骤,同时保持生成样本质量。此外,本周还阅读了CatV2TON论文,该方法基于扩散变换器(DiT),提出了时间连接策略和自适应剪辑归一化(AdaCN),在虚拟试穿任务上实现了高效且连贯的视频生成。

abstract

Diffusion Models (DM) have become the focus of generative model research in recent years, in which denoising diffusion probability model (DDPM) uses Markov chain to gradually denoise to generate high-quality samples. However, DDIM (Denoising Diffusion Implicit Model) optimizes the sampling process through non-Markov chain methods and improves the generation efficiency. This week we studied the non-Markov chain forward diffusion process of DDIM and analyzed the posterior distribution compared to DDPM, showing how DDIM can achieve a more flexible and general sampling method by introducing the hyperparameter σ t sigma_t σt. In addition, the accelerated sampling technique respacing is discussed, which reduces sampling steps by redistributing time steps while maintaining the generated sample quality. In addition, I read the CatV2TON paper this week, which is based on a diffusion converter (DiT) and proposes a time-connected strategy and adaptive clip normalization (AdaCN) for efficient and coherent video generation on virtual fitting tasks.

1.扩散模型加速采样算法(DDIM)

回顾:马尔科夫链过程在扩撒模型中起到了核心作用,主要体现在数据的逐步去噪和重建过程中。上上周学习DDPM是一类基于马尔科夫链的生成模型,用于高质量的数据生成,如图像生成。
DDIM相对于DDPM在训练方法上是相同的,主要区别在于上采样过程(使用非马尔科夫链过程)。

1.1 非马尔科夫链的前向扩散过程

公式: q σ ( x 1 : T ∣ x 0 ) : = q σ ( x T ∣ x 0 ) ∏ t = 2 T q σ ( x t − 1 ∣ x t , x 0 ) q_sigma(oldsymbol{x}_{1:T}|oldsymbol{x}_0):=q_sigma(oldsymbol{x}_T|oldsymbol{x}_0)prod_{t=2}^Tq_sigma(oldsymbol{x}_{t-1}|oldsymbol{x}_t,oldsymbol{x}_0) qσ(x1:Tx0):=qσ(xTx0)t=2Tqσ(xt1xt,x0)
q σ ( x t − 1 ∣ x t , x 0 ) = N ( α t − 1 x 0 + 1 − α t − 1 − σ t 2 ⋅ x t − α t x 0 1 − α t , σ t 2 I ) . q_sigma(oldsymbol{x}_{t-1}|oldsymbol{x}_t,oldsymbol{x}_0)=mathcal{N}left(sqrt{alpha_{t-1}}oldsymbol{x}_0+sqrt{1-alpha_{t-1}-sigma_t^2}cdotfrac{oldsymbol{x}_t-sqrt{alpha_t}oldsymbol{x}_0}{sqrt{1-alpha_t}},sigma_t^2oldsymbol{I} ight). qσ(xt1xt,x0)=N(αt1 x0+1αt1σt2 1αt xtαt x0,σt2I).
在马尔科夫链过程中的公式为:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(mathbf{x}_{1:T}|mathbf{x}_0)=prod_{t=1}^Tq(mathbf{x}_t|mathbf{x}_{t-1}) q(x1:Tx0)=t=1Tq(xtxt1)

需要证明在DDIM中的 q ( x T ∣ x 0 ) q(mathbf{x}_{T}|{x}_{0}) q(xTx0)也满足在DDPM算法中的边缘分布相同的分布。(证明结果表明DDIM可以使用和DDPM相同的损失函数Lsample

1.2 对比非马尔科夫链扩散后验分布与DDPM马尔科夫链扩散的后验分布

DDPM马尔科夫链扩散的后验分布如下
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) , β ~ t I ) , w h e r e μ ~ t ( x t , x 0 ) : = α ˉ t − 1 β t 1 − α ˉ t x 0 + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t a n d β ~ t : = 1 − α ˉ t − 1 1 − α ˉ t β t egin{aligned} q(mathbf{x}_{t-1}|mathbf{x}_{t},mathbf{x}_{0}) & =mathcal{N}(mathbf{x}_{t-1}; ilde{oldsymbol{mu}}_{t}(mathbf{x}_{t},mathbf{x}_{0}), ilde{eta}_{t}mathbf{I}), \ mathrm{where}quad ilde{mu}_{t}(mathbf{x}_{t},mathbf{x}_{0}) & :=frac{sqrt{ar{alpha}_{t-1}}eta_{t}}{1-ar{alpha}_{t}}mathbf{x}_{0}+frac{sqrt{alpha_{t}}left(1-ar{alpha}_{t-1} ight)}{1-ar{alpha}_{t}}mathbf{x}_{t}quadmathrm{and}quad ilde{eta}_{t}:=frac{1-ar{alpha}_{t-1}}{1-ar{alpha}_{t}}eta_{t} end{aligned} q(xt1xt,x0)whereμ~t(xt,x0)=N(xt1;μ~t(xt,x0),β~tI),:=1αˉtαˉt1 βtx0+1αˉtαt (1αˉt1)xtandβ~t:=1αˉt1αˉt1βt
DDIM非马尔科夫链扩散后验分布如下
q σ ( x t − 1 ∣ x t , x 0 ) = N ( α t − 1 x 0 + 1 − α t − 1 − σ t 2 ⋅ x t − α t x 0 1 − α t , σ t 2 I ) . q_{sigma}(oldsymbol{x}_{t-1}|oldsymbol{x}_{t},oldsymbol{x}_{0})=mathcal{N}left(sqrt{alpha_{t-1}}oldsymbol{x}_{0}+sqrt{1-alpha_{t-1}-sigma_{t}^{2}}cdotfrac{oldsymbol{x}_{t}-sqrt{alpha_{t}}oldsymbol{x}_{0}}{sqrt{1-alpha_{t}}},sigma_{t}^{2}oldsymbol{I} ight). qσ(xt1xt,x0)=N(αt1 x0+1αt1σt2 1αt xtαt x0,σt2I).
上述两个式子的对比,在非马尔科夫链过程中多了一个超参数 σ t 2 sigma_{t}^{2} σt2,其取值的不同也决定着两者后验概率的不同。相当于上采样过程的不同。
在DDIM中xt-1与xt的关系有如下式子:
在这里插入图片描述
(1)当 σ t sigma_{t} σt取值为
σ t = ( 1 − α t − 1 ) / ( 1 − α t ) 1 − α t / α t − 1 sigma_{t}=sqrt{(1-alpha_{t-1})/(1-alpha_{t})}sqrt{1-alpha_{t}/alpha_{t-1}} σt=(1αt1)/(1αt) 1αt/αt1 通过变换发现退化成了基于马尔科夫链的扩散过程。
(2)当 σ t = 0 sigma_{t}=0 σt=0 时,非马尔科夫链扩散过程时是确定性的。也就是DDIM的来源。

1.3 加速采样的技巧respacing

其核心思想时通过调整扩散过程中的时间步长分布,减少采样所需的步骤,同时保证生成的样本质量。
(1)时间步长的重新分配

  1. 在DDPM中,扩散过程通常被划分为固定的时间步长,每一步对应特定的噪声水平。
  2. respacing 通过选择一部分关键时间步长,跳过中间的非关键步骤,从而加速采样。

(2)非均匀时间步长选择
Respacing 允许选择非均匀的时间步长,而不是均匀分布的步长。可以选择在噪声水平较高的区域(早期扩散阶段)使用更多步长,而在噪声水平较低的区域(后期扩散阶段)使用较少步长。

1.4 DDIM采样函数

返回生成的样本和预测的初始数据 x_start

import torch
import torch as th

def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_t[-1] from the model using DDIM.

    Same usage as p_sample().
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    
    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
    
    alpha_bar = self._extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = self._extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    
    sigma = (
        eta
        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    )
    
    # Equation 12.
    noise = th.randn_like(x)
    mean_pred = (
        out["pred_xstart"] * th.sqrt(alpha_bar_prev)
        + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    )
    
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    
    sample = mean_pred + nonzero_mask * sigma * noise
    
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}

def p_mean_variance(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.
    """
    if model_kwargs is None:
        model_kwargs = {}

    B, C = x.shape[:2]
    assert t.shape == (B,)
    
    model_output = model(x, t, **model_kwargs)
    
    if self.model_var_type in ["learned", "learned_range"]:
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = th.split(model_output, C, dim=1)
        if self.model_var_type == "learned":
            model_log_variance = model_var_values
            model_variance = th.exp(model_log_variance)
        else:
            min_log = _extract_into_tensor(
                self.posterior_log_variance_clipped, t, x.shape
            )
            max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
    else:
        model_variance, model_log_variance = {
            # for fixedlarge, we set the initial (log-)variance like so
            # to get a better decoder log likelihood.
            "fixed_large": (
                th.cat([self.posterior_variance[1].unsqueeze(0), self.betas[1:]], 0),
                th.log(th.cat([self.posterior_variance[1].unsqueeze(0), self.betas[1:]], 0)),
            ),
            "fixed_small": (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }[self.model_var_type]
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

    def process_xstart(x):
        if denoised_fn is not None:
            x = denoised_fn(x)
        if clip_denoised:
            return x.clamp(-1, 1)
        return x

    pred_xstart = process_xstart(
        self._predict_xstart_from_eps(x, t, eps=model_output)
    )
    model_mean, _, _ = self.q_posterior_mean_variance(
        pred_xstart, x, t, model_variance, model_log_variance
    )
    
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
    }

def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    return (
        _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
        - pred_xstart
    ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

def _extract_into_tensor(arr, t, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    """
    res = th.from_numpy(arr).to(device=t.device)[t].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)

2.论文阅读

本周阅读了一篇2025年一月分发表的文章《
CatV2TON: Taming Diffusion Transformers for Vision-Based Virtual Try-On with Temporal Concatenation》论文链接. 想象一下,无需亲自试穿,仅凭借一张照片或者一段视频,就能精准预览任何心仪服装的上身效果,中山大学和新加坡国立大学提出的基于视觉的虚拟试穿(CatV2TON)技术将其变为现实。

2.1 介绍

在这里插入图片描述CatV2TON是一种简单有效的基于视觉的虚拟试穿 (V2TON) 方法,它使用单个扩散变压器模型支持图像和视频试穿任务。通过在时间上连接服装和人员输入并在混合图像和视频数据集上进行训练,CatV2TON在静态和动态设置中实现了强大的试穿性能。
为了高效地生成长视频,论文还提出了一种基于重叠剪辑的推理策略,该策略使用顺序帧引导和自适应剪辑规范化 (AdaCN) 来保持时间一致性并减少资源需求。还介绍了 ViViD-S,这是一个经过精炼的视频试穿数据集,通过过滤背面帧并应用 3D 蒙版平滑来增强时间一致性。综合实验表明,CatV2TON 在图像和视频试穿任务中均优于现有方法,为跨不同场景的逼真虚拟试穿提供了多功能且可靠的解决方案。

2.2 方法

在这里插入图片描述
CatV2TON使用DiT作为主干,第一个DiT块被复制为Pose Encoder。人和服装条件在时间上连接为试穿条件。整个可训练部分仅由自注意力层和Pose Encoder组成,占总参数的不到1/5。
在这里插入图片描述
基于重叠片段的推理策略说明
(a)一段长视频被分成n个重叠片段,每个片段由重复的帧组成。每个片段的最后k帧用作生成下一个片段的提示帧。
(b) 自适应片段归一化(AdaCN)用于根据提示帧特征和去噪提示帧的平均值和标准差对整个片段进行归一化,确保生成的视频中各个片段之间的平滑连续性。

2.3 结果

在 ViViD 数据集上对连衣裙进行定性比较。
在这里插入图片描述
关于AdaCN的消融实验结果
在这里插入图片描述
当不使用AdaCN进行推理时,试穿结果中的服装部分将出现色差问题,并且通常会随着视频长度的增加而加剧。

2.4 结论

CatV2TON是一个简单而高效的扩散变换器框架,适用于图像和视频虚拟试穿任务。通过时间连接服装和人员输入并使用混合图像视频数据集进行训练,模型仅使用 20% 的主干参数作为可训练组件即可获得高质量的结果。为了支持长时间、时间一致的试穿视频生成,引入了一种基于重叠剪辑的推理策略和自适应剪辑规范化 (AdaCN),在保持时间连续性的同时减少了资源需求。此外论文提出了一个精选的视频试穿数据集 ViViD-S,它是通过过滤后视帧并应用 3D 蒙版平滑来增强蒙版的时间一致性而创建的。大量实验表明,CatV2TON 在定量和定性评估方面均优于基线方法,标志着基于视觉的虚拟试穿研究统一模型向前迈出了重要一步。

3.总结

本周学习了扩散模型加速采样算法DDIM,分析其非马尔科夫链特性及如何优化采样过程,提高效率。通过引入自适应超参数 σ t sigma_t σt,DDIM实现了比DDPM更灵活的生成方式,并结合respacing技巧进一步提升采样速度。阅读了一篇基于扩散模型的视觉虚拟试穿技术CatV2TON。该方法利用DiT主干,并引入时间连接策略,在图像和视频任务上取得显著优势。其提出的重叠剪辑推理策略和AdaCN有效提高了长视频的时间一致性,同时减少了计算成本。

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。