simple diffusion:高分辨率图像的端到端扩散

simple diffusion: End-to-end diffusion for high resolution images

https://arxiv.org/abs/2301.11093

Emiel Hoogeboom *     Jonathan Heek *      Tim Salimans     

(*Equal contribution)

https://github.com/lucidrains/denoising-diffusion-pytorch 3.2k stars

https://huggingface.co/blog/annotated-diffusion ★★★★★

来自Google Brain。作者之一Tim Salimans是On Distillation of Guided Diffusion Models的作者之一,且和DDPM一作有一些共同工作

摘要:目前,在高分辨率图像的像素空间中应用扩散模型是困难的。相反,现有的方法侧重于低维空间中的扩散(潜在扩散),或者具有称为级联的多个超分辨率生成级别。缺点是这些方法增加了扩散框架的复杂性。本文旨在提高高分辨率图像的去噪扩散,同时保持模型尽可能简单。本文围绕研究问题展开:如何在高分辨率图像上训练标准的去噪扩散模型,并仍然获得与这些替代方法相当的性能?四个主要发现是:1)应针对高分辨率图像调整噪声时间表,2)仅缩放架构的特定部分就足够了,3)应在架构中的特定位置添加dropout,4)下采样是避免高分辨率特征图的有效策略。结合这些简单而有效的技术,我们在ImageNet上无需采样修改器的情况下,在扩散模型中实现了最先进的图像生成。

图1:一张穿着毛衣的青蛙、一只猫头鹰弹钢琴的dslr照片,生动的幻想艺术,以及两个机器人在背景中与纽约下棋。除了冻结的文本编码器,simple diffusion是端到端训练的,图像是在全像素空间中生成的Figure 1: A dslr photo of a frog wearing a sweater, An owl playing the piano, vivid, fantasy art, and two robots playing chess with New York in the background. Except for the frozen text encoder, simple diffusion is trained end-to-end and images are generated in full pixel space

1    引言

基于分数的扩散模型在数据生成中越来越流行。从本质上讲,这个想法很简单:预先定义一个扩散过程,该过程通过添加随机噪声逐渐破坏信息。然后,相反的方向定义了去噪过程,该过程用神经网络近似。

扩散模型已证明对图像、音频和视频生成非常有效。然而,对于更高分辨率,文献通常对低维潜在空间(潜在扩散)进行操作(Rombach等人,2022)或将生成过程划分为多个子问题,例如通过超分辨率(级联扩散)(Ho等人,2022年)或去噪专家的混合(Balaji等人,2022。缺点是这些方法引入了额外的复杂性,并且通常不支持单个端到端训练设置。

在本文中,我们的目标是在保持模型尽可能简单的同时,改进标准去噪扩散以获得更高的分辨率。我们的四个主要发现是:1)应针对较大的图像调整噪声计划,随着分辨率的增加,增加更多的噪声。2) 在16×16分辨率上扩展U-Net架构足以提高性能。更进一步的是U-ViT架构,一个带有Transformer主干的U-Net。3) 为了提高性能,应该添加dropout,但不能在最高分辨率的特征图上添加。最后4)对于更高的分辨率,可以降采样而不会降低性能。最重要的是,这些结果仅使用单个模型和端到端训练设置即可获得。在使用现有的蒸馏技术(现在只需应用于单个阶段)后,该模型可以在0.4秒内生成图像。

图2:simple diffusion生成的图像。重要的是,每个图像都是通过单个扩散模型在整个图像空间中生成的,没有任何级联(超分辨率)或专家的混合。样本取自U-Net模型,引导比例为4。Figure 2: Generated images with simple diffusion. Importantly, each image is generated in full image space by a single diffusion model without any cascades (super-resolution) or mixtures of experts. Samples are drawn from the U-Net model with guidance scale 4.

2    背景:扩散模型

扩散模型通过学习破坏过程的逆过程来生成数据。通常,随着时间的推移,扩散过程逐渐增加高斯噪声。可以方便地直接在边界q({\boldsymbol{z}}_t | {\boldsymbol{x}})中表示该过程,该边界由下式给出:

 q({\boldsymbol{z}}_t | {\boldsymbol{x}}) = \mathcal{N}({\boldsymbol{z}}_t | \alpha_t {\boldsymbol{x}}, \sigma_t^2 {\mathbf{I}})    (1)

其中,\alpha_t, \sigma_t \in (0, 1)是确定在时间步t上有多少信号被破坏的超参数,时间步可以是连续的,例如t \in [0, 1]。这里,\alpha_t减小,\sigma_t增大,均大于零。我们考虑一个方差保持过程,它将\alpha_t, \sigma_t 之间的关系固定为\alpha_t^2 = 1 - \sigma_t^2。假设扩散过程是马尔可夫过程,转移分布( transition distributions)由下式给出:

q({\boldsymbol{z}}_t | {\boldsymbol{z}}_s) = \mathcal{N}({\boldsymbol{z}}_t | \alpha_{ts} {\boldsymbol{z}}_s, \sigma_{ts}^2 {\mathbf{I}})    (2)

其中\alpha_{ts} = \alpha_t / \alpha_s\sigma_{ts}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2t > s

噪声计划Noise schedule    一种常用的噪声计划是\alpha余弦计划,其中\alpha_t = \cos(\pi t / 2),在方差保持假设下,\sigma_t = \sin(\pi t / 2)。(Kingma等人,2021)的一个重要发现是,重要的是信噪比\alpha_t / \sigma_t,即1 / \tan(\pi t / 2),或者在对数空间中\log \frac{\alpha_t}{\sigma_t} = -\log \tan(\pi t / 2)

去噪    以单个数据点{\boldsymbol{x}}为条件,去噪过程可以写成:

q({\boldsymbol{z}}_s | {\boldsymbol{z}}_t, {\boldsymbol{x}}) = \mathcal{N}({\boldsymbol{z}}_t | {\boldsymbol{\mu}}_{t \to s}, \sigma_{t \to s}^2 {\mathbf{I}})    (3)

其中{\boldsymbol{\mu}}_{t \to s} = \frac{\alpha_{ts} \sigma_s^2}{\sigma_t^2} {\boldsymbol{z}}_t + \frac{\alpha_s \sigma_{ts}^2}{\sigma_t^2} {\boldsymbol{x}}\sigma_{t \to s} = \frac{\sigma_{ts}^2 \sigma_{s}^2}{\sigma_t^2}

文献中的一个重要且令人惊讶的结果是,当{\boldsymbol{x}}由神经网络\hat{{\boldsymbol{x}}} = f_\theta({\boldsymbol{z}}_t)逼近时,可以定义习得分布 p({\boldsymbol{z}}_s | {\boldsymbol{z}}_t) = q({\boldsymbol{z}}_s | {\boldsymbol{z}}_t, {\boldsymbol{x}} = \hat{{\boldsymbol{x}}}),而对s \to t不会失去一般性(without loss of generality as s to t.)。这也能公式是因为,对于s \to t,所有数据点q({\boldsymbol{z}}_s | {\boldsymbol{z}}_t)(通常未知)的真实去噪分布将等于q({\boldsymbol{z}}_s | {\boldsymbol{z}}_t, {\boldsymbol{x}} = \mathbb{E}[{\boldsymbol{x}} | {\boldsymbol{z}}_t])(Song等人,2021)。

参数化    网络不需要直接近似\hat{{\boldsymbol{x}}} ,实验发现,其他预测产生更高的视觉质量。研究边界q({\boldsymbol{z}}_t | {\boldsymbol{x}})的重参数化,即{\boldsymbol{z}}_t = \alpha_t {\boldsymbol{x}} + \sigma_t {\boldsymbol{ \epsilon }}_t,其中{\boldsymbol{ \epsilon }}_t \sim \mathcal{N}(0, 1),例如,我们可以选择神经网络预测的epsilon参数化,这里,神经网络预测值为\hat{{\boldsymbol{\epsilon}}}_t。要获得\hat{{\boldsymbol{x}}},需要计算\hat{{\boldsymbol{x}}} = {\boldsymbol{z}}_t / \alpha_t - \sigma_t \hat{{\boldsymbol{\epsilon}}}_t / \alpha_t。epsilon参数化的问题是它在t=1附近给出了不稳定的采样。没有这个问题的另一种参数化称为v预测,在(Salimans&Ho,2022)中提出,其定义为\hat{{\boldsymbol{v}}}_t = \alpha_t \hat{{\boldsymbol{\epsilon}}}_t - \sigma_t \hat{{\boldsymbol{x}}}

注意,给定{\boldsymbol{z}}_t,我们可以获得\hat{{\boldsymbol{x}}}\hat{{\boldsymbol{\epsilon}}}_t,通过恒等式\sigma_t {\boldsymbol{z}}_t + \alpha_t \hat{{\boldsymbol{v}}}_t = (\sigma_t^2 + \alpha_t^2) \hat{{\boldsymbol{\epsilon}}}_t = \hat{{\boldsymbol{\epsilon}}}_t \alpha_t {\boldsymbol{z}}_t - \sigma_t \hat{{\boldsymbol{v}}}_t = (\alpha_t^2 + \sigma_t^2) \hat{{\boldsymbol{x}}} = \hat{{\boldsymbol{x}}}。在最初的实验中,我们发现v预测可以更可靠地训练,特别是对于更大的分辨率,因此我们在本文中使用了这种参数化。

优化    为了训练模型,我们使用来自(Ho等人,2020)的标准epsilon损失。激励这种损失选择的一种方法是,使用变分推理,可以得出模型对数似然的下限(在连续时间内),如(Kingma等人,2021)所述:

 \log p({\boldsymbol{x}}) = \log \mathbb{E}_{q} \frac{p({\boldsymbol{x}}, {\boldsymbol{z}}_0, \ldots, {\boldsymbol{z}}_1)}{q({\boldsymbol{x}}, {\boldsymbol{z}}_0, \ldots, {\boldsymbol{z}}_1)} \geq \mathbb{E}_{q} \frac{p({\boldsymbol{x}}, {\boldsymbol{z}}_0, \ldots, {\boldsymbol{z}}_1)}{q({\boldsymbol{x}}, {\boldsymbol{z}}_0, \ldots, {\boldsymbol{z}}_1)}

= \mathcal{L}_x + \mathcal{L}_T - \mathbb{E}_{t \sim \mathcal{U}(0, 1)} \Big{[} w(t) ||{\boldsymbol{\epsilon}}_t - \hat{{\boldsymbol{\epsilon}}}_t ||^2 \Big{]}

其中,对于well-defined的过程\mathcal{L}_x = -\log p({\boldsymbol{x}} | {\boldsymbol{z}}_0) \approx 0,对于离散{\boldsymbol{x}}\mathcal{L}_T = -\mathrm{KL}(q({\boldsymbol{z}}_T | {\boldsymbol{x}}) | p({\boldsymbol{z}}_T)) \approx 0,其中w(t)是一个加权函数,若方程为真,则需要w(t) = - \frac{\mathrm{d}}{\mathrm{d}t}\log \mathrm{SNR}(t),其中\mathrm{SNR}(t) = \alpha_t^2 / \sigma_t^2。在实践中,我们通常使用{\boldsymbol{\epsilon}}_t上的未加权损失(意味着w(t) = 1),这在(Ho等人,2020)中被发现具有优异的样本质量。更多有用的背景信息请参见附录A。

where for a well-defined process \mathcal{L}_x = -\log p(\vx | \vz_0) \approx 0 for discrete \vx, \mathcal{L}_T = -\mathrm{KL}(q(\vz_T | \vx) | p(\vz_T)) \approx 0, and where w(t) is a weighting function which for the equation to be true needs to be w(t) = - \frac{\mathrm{d}}{\mathrm{d}t}\log \mathrm{SNR}(t) where \mathrm{SNR}(t) = \alpha_t^2 / \sigma_t^2.

3    方法:简单扩散

在本节中,我们将介绍几种改进,使去噪扩散能够在高分辨率下正常工作。

3.1    调整噪声计划

其中一个修改是通常用于扩散模型的噪声计划。最常见的计划是\alpha余弦计划,在方差保持假设下,其等于\frac{\sigma_t}{\alpha_t} = \tan(\pi t / 2)(在本分析中忽略t=0t=1周围的边界)(Nichol&Dhariwal,2021)。该计划最初是为了提高CIFAR10的性能而提出的,其分辨率为32×32,ImageNet为64×64。

然而,对于高分辨率,没有添加足够的噪声。例如,请查看图3的最上面一行,对于标准余弦计划,在很长的时间范围内,图像的全局结构已经在很大程度上定义了。这是有问题的,因为生成去噪过程只有很小的时间窗口来决定图像的全局结构。我们认为,对于更高的分辨率,可以以可预测的方式更改此计划,以保持良好的视觉样本质量。

图3:512×512图像上的标准和移位扩散噪声,通过平均池化以64×64的分辨率进行可视化。顶行显示了传统的余弦计划,底行显示了我们提出的移位计划。Figure 3: The standard and shifted diffusion noise on an image of 512 × 512, that is visualized by average pooling to a resolution of 64 × 64. The top row shows a conventional cosine schedule, the bottom row shows our proposed shifted schedule.

为了更详细地说明这一需要,让我们研究一个128×128的问题。给定输入图像{\boldsymbol{x}},像素i的扩散分布由q(z_t^{(i)} | {\boldsymbol{x}}) = \mathcal{N}(z_t^{(i)} | \alpha_t x_i, \sigma_t)给出。通常,扩散模型使用通过下采样以对较低分辨率特征图进行操作的网络架构,在我们的情况下,使用平均池化。

假设我们平均池化{\boldsymbol{z}}_t,其中我们让索引1、2、3、4表示2×2正方形中正在合并的像素。这个新像素是z^{64 \times 64}_t = (z_t^{(1)} + z_t^{(2)} + z_t^{(3)} + z_t^{(4)}) / 4。回想一下,对于独立随机变量的方差是加法,这意味着\mathrm{Var}[X_1 + X_2] = \mathrm{Var}[X_1] + \mathrm{Var}[X_2],并且对于常数a\mathrm{Var}[aX] = a^2\mathrm{Var}[X]。设x^{64 \times 64}表示平均池化的输入图像的第一个像素,我们发现z^{64 \times 64}_t \sim \mathcal{N}(\alpha_t x^{64 \times 64}, \sigma_t / 2)。分辨率较低的像素z^{64 \times 64}_t只有一半的噪声量。我们假设,随着分辨率的增加,这是有问题的,因为在较低分辨率上花费的扩散时间要少得多,这是一个生成全局一致性的阶段。

(We hypothesize that as resolutions increase this is problematic, as much fewer diffusion time is spent on the lower resolution, a stage at which the global consistency is generated.)

我们可以进一步得出,在这种较低分辨率下,{\alpha_t}{\sigma_t}之比是2倍高(is twice as high),这意味着信噪比是2^2倍高。因此,\mathrm{SNR}^{64 \times 64}(t) = \mathrm{SNR}^{128\times 128}(t) \cdot 2^2,或者,更一般地有:

\mathrm{SNR}^{d / s \times d / s}(t) = \mathrm{SNR}^{d \times d}(t) \cdot s^2    (4)

总之,在大小为s×s的窗口上进行平均后,{\alpha_t}{\sigma_t}的比率增加至s倍(因此SNR增加至s^2倍)。因此,我们认为,可以根据一些参考分辨率(例如32×32或64×64)来定义噪声计划,这些计划最初设计并测试是成功的。在我们的方法中,首先选择参考分辨率,例如64×64(我们将从经验中看到的合理选择)。

在参考分辨率下,我们定义了噪声计划\mathrm{SNR}^{64 \times 64}(t) = 1 / \tan(\pi t / 2)^2,这反过来定义了全分辨率d×d下的期望SNR:

\mathrm{SNR}_{\mathrm{shift}\, 64}^{d \times d}(t) = \mathrm{SNR}^{64 \times 64}(t) \cdot (64 / d)^2    (5)

信噪比简单地乘以(64 / d)^2,这对于我们的d>64设置,降低了高分辨率下的信噪比。在对数空间中,这意味着2 \cdot \log (64 / d)的简单移位(见图6)。例如,对于128×128的图像和64的参考分辨率,噪声计划的公式为:

 \log \mathrm{SNR}_{\mathrm{shift} \, 64}^{128 \times 128}(t) = - 2 \log \tan (\pi t / 2) + 2 \log (64 / 128)

回想一下,在方差保持过程下,扩散参数可以计算为\alpha^2_t = \mathrm{sigmoid}(\log \mathrm{SNR}(t))\sigma^2_t = \mathrm{sigmoid}(-\log \mathrm{SNR}(t))

最后,可能值得研究并行和互补的工作(Chen,2023),该工作还分析了更高分辨率图像的调整后的噪声时间表,并描述了一些其他改进。

图6:原始和移位余弦表的对数信噪比。Figure 6: Log signal to noise ratio for the original and shifted cosine schedule.

内插计划表 Interpolating schedules    改变计划表的一个潜在缺点是,由于每像素噪声的增加,高频细节现在在扩散过程中生成得更晚。然而,我们假设,当对已经生成的全局/低频特征进行调节时,高频细节是弱相关的。因此,应该可以在几个扩散步骤中生成高频细节。可替换地,可以内插不同的移位计划表,例如,对于512的分辨率,可以通过从移位32开始并在对数空间内插移位256来包括更高的频率细节。\log \mathrm{SNR}_{\mathrm{interpolate} (32 \to 256)}(t)的计划表等于:

t \log \mathrm{SNR}_{\mathrm{shift} \, 256}^{512 \times 512}(t) + (1 - t) \log \mathrm{SNR}_{\mathrm{shift} \, 32}^{512 \times 512}(t)    (6)

其在低频、中频和高频细节上具有更相等的权重。

A potential downside of shifting the schedule is that high frequency details are now generated much later in the diffusion process due to the increased per-pixel noise. However, we postulate that high-frequency details are weakly correlated when conditioning on the global/low-frequency features that are already generated. It should therefore be possible to generate the high-frequency details in few diffusion steps. Alternatively, one can interpolate different shift schedules, for example for a resolution of 512 one could include higher frequency details by starting at shift 32 and interpolating in log-space to shift 256.

3.2    多尺度训练损失

在上一节中,我们认为,在对高分辨率图像进行训练时,应调整扩散模型的噪声计划,以便在基本分辨率下保持信噪比不变。然而,即使在以这种方式调整噪声调度时,越来越高分辨率的图像上的训练损失也主要由高频细节控制。为了纠正这一点,我们建议用多尺度版本来替换标准训练损失,多尺度版本包括使用较低分辨率增加的加权因子来评估我们在下采样分辨率下的标准训练损失。我们发现,多尺度损失使得收敛更快,尤其是在分辨率大于256×256时。我们在d×d分辨率下的原始训练损失可以写为:

L^{d \times d}_{\theta}({\boldsymbol{x}}) = \frac{1}{d^{2}}\mathbb{E}_{{\boldsymbol{ \epsilon }},t} \lVert \text{D}^{d \times d}[{\boldsymbol{ \epsilon }}] - \text{D}^{d \times d}[\hat{{\boldsymbol{ \epsilon }}}_{\theta}(\alpha_t {\boldsymbol{x}}+\sigma_t {\boldsymbol{ \epsilon }}, t)] \rVert_{2}^{2}

To correct for this we propose replacing the standard training loss by a multiscale version that consists of evaluating our standard training loss at downsampled resolutions with a weighting factor that increases for the lower resolutions.

其中\text{D}^{d \times d}表示降采样到d \times d分辨率。如果该分辨率等同于我们模型\hat{{\boldsymbol{\epsilon}}}_\theta 和数据{\boldsymbol{x}}的原始分辨率,则下采样不做任何事,可以从该方程中删除。否则,\text{D}^{d \times d}[{\hat{{\boldsymbol{\epsilon}}}}_{\theta}]可以被视为非原始分辨率d \times d数据的经调整的去噪模型。由于对图像进行下采样是一种线性操作,我们得到\text{D}^{d \times d}[\mathbb{E}({\boldsymbol{ \epsilon }}|{\boldsymbol{x}})] = \mathbb{E}(\text{D}^{d \times d}[{\boldsymbol{ \epsilon }}]|{\boldsymbol{x}}),这种构建低分辨率模型的方式确实与我们的原始模型一致。

然后,我们提出针对包括多分辨率的多尺度训练损失来训练我们的高分辨率模型。例如,对于分辨率32, 64, \ldots, d,损失将为:\tilde{L}^{d \times d}_{\theta}({\boldsymbol{x}}) =\sum\nolimits_{s \in \{32, 64, 128, \ldots, d\}}\frac{1}{s} L^{s \times s}_{\theta}({\boldsymbol{x}})

也就是说,我们针对从基本分辨率(在本例中为32×32)开始并始终包括最终分辨率d×d的分辨率的训练损失的加权和进行训练。

3.3    扩展架构

另一个问题是如何扩展架构。每次分辨率加倍时,典型的模型架构将通道减半,使得每次操作的FLOPs相同,但特征数量加倍。每次分辨率加倍时,计算强度(flops / features)也会减半。低计算强度会导致加速器的利用率低下,而大量激活会导致内存不足。因此,我们更喜欢在分辨率较低的特征图上进行缩放。我们的假设是,主要在特定分辨率(即16×16分辨率)上进行缩放,足以在我们考虑的网络大小范围内提高性能。通常,低分辨率操作具有相对较小的特征图。为了说明这一点,请考虑以下示例:

1024 \text{ (batch) } \times 16 \times 16 \times 1024 \text{ (channel)} \cdot 2 \text{ bytes} / \text{dim}

对于具有128个通道的256×256特征图,如果以16位浮点格式存储,则特征图的成本为16GB。

参数的内存占用较小:卷积内核的典型大小为3^2 \times 128^2  \text{ dimensions} \cdot 4 \text{ bytes} / \text{dims} \cdot 5 \text{ replications } = 2.8,对于1024通道则是180MB,其中\cdot 5 \text{ replications }用于梯度、优化器状态和指数移动平均值。问题是,在分辨率为16×16的情况下,特征图的大小在16^2处是可管理的,参数所需的空间也是可管理的。

总结表1中的包络(back-of-the-envelope)计算,可以看到对于相同的内存约束,在16×16时可以容纳16GB/0.7GB≈23层,而在256×256时只能容纳1层。

选择这种解决方案的其他原因是,自注意开始在扩散模型文献中的许多现有作品中使用(Ho等人,2020年;Nichol&Dhariwal,2021)。此外,正是16×16的分辨率使用于分类的视觉Transformer能够成功运行(Dosovitskiy等人,2021)。虽然这可能不是缩放架构的理想方式,但我们将根据经验证明,缩放16×16级别效果良好。

细心的ML从业者可能已经意识到,当天真地使用多个设备时,参数会被复制(在JAX和Flax中是典型的)或存储在第一个设备(PyTorch)上。这两种情况都会导致这样一种情况,即每个设备对于特征图的内存需求随着所需的1/设备而减少,但参数需求不受影响,需要大量内存(Both cases result in a situation where the memory requirements per device for the feature maps decreases with $1 / \text{devices}$ as desired, but the parameter requirement is unaffected and requires a lot of memory.)。我们主要以低分辨率缩放,其中激活相对较小,但参数矩阵较大O(\text{features}^2)。我们发现,分割权重可以使我们扩展到更大的模型,而不需要更复杂的并行化方法,如模型并行。

避免高分辨率特征图    高分辨率特征图占用内存。如果FLOP的数量保持不变,内存仍会随着分辨率线性缩放

在实践中,不可能在不牺牲加速器利用率的情况下将通道减小到一定大小以上。现代加速器在计算和内存带宽之间具有非常高的比率(Modern accelerators have a very high ratio between compute and memory bandwidth)。因此,低通道计数可能会导致操作内存受限,导致加速器大部分处于空闲状态,并且比预期的挂钟性能更差。

为了避免在最高分辨率上进行计算,我们立即下采样图像作为神经网络的第一步,上采样作为最后一步。令人惊讶的是,尽管这样神经网络在计算和内存方面更便宜,但我们从经验上发现,它们也能获得更好的性能。我们有两种方法可供选择。

一种方法是使用可逆的线性5/3小波(如JPEG2000中所使用的)将图像转换为较低分辨率的频率响应,如图7所示。这里,为了可视化,不同的特征响应在空间上被拼接。在网络中,响应在通道维度上拼接。当应用多个DWT级别时(这里有两个),则响应的分辨率不同。这是通过找到最低分辨率(图中128^2)并对高分辨率特征图的像素进行reshape来解决的,在256^2的情况下,它们被reshape为{128^2}×4,作为典型的空间到深度操作。DWT实施指南可在此处找到:http://trueharmoniccolours.co.uk/Blog/?p=14

如果上述情况看起来很复杂,那么如果愿意支付一点性能损失,也存在一个更简单的解决方案。作为第一层,可以使用步长为d的d×d卷积层,以及形状相同的转置卷积层作为最后一层。这相当于Transformer文献中所谓的patching。根据经验,我们发现这一点表现相似,尽管稍差一些。

3.4    dropout

在用于扩散的通常架构中,在所有分辨率下,全局dropout超参数用于残差块。在CDM(Ho等人,2022)中,dropout被用在较低分辨率上,以生成图像。对于条件输入的高分辨率图像,不使用dropout。然而,对数据执行各种其他形式的增强。这表明正则化很重要,即使对于高分辨率的模型也是如此。然而,正如我们将根据经验证明的那样,在所有残差块中添加dropout的朴素方法并不能给出期望的结果。

由于我们的网络设计仅在较低分辨率下缩放网络大小,因此我们假设仅添加dropout和较低分辨率就足够了(we hypothesize that it should be sufficient to only add dropout add the lower resolutions)。这避免了对高分辨率层进行正则化,而这在内存方面是昂贵的,同时仍然使用dropout正则化,而dropout正则化对在低分辨率图像上训练的模型是成功的。

3.5    U-ViT架构

将上述对架构的改变再进一步,如果架构已经在该分辨率下使用了自注意,则可以用MLP块替换卷积层。这将(Peebles&Xie,2022)引入的用于扩散的Transformer与U-Nets连接起来,用Transformer替换其主干。因此,这种相对较小的变化意味着我们现在使用这些分辨率的Transformer块。主要的好处是,自注意和MLP块的组合具有较高的加速器利用率,因此大型模型训练速度更快。有关此架构的更多详细信息,请参见附录B。本质上,这种U-Vision Transformer(U-ViT)架构可以看作是一个小型卷积U-Net,它通过多个级别将样本降到16×16分辨率。这里使用了大型Transformer。此后,再次通过卷积U-Net进行上采样。

3.6    文本到图像生成

作为概念的证明,我们还训练了一个以文本数据为条件的简单扩散模型。跟随(Saharia等人,2022)我们使用T5 XXL(Raffel等人,2020)文本编码器作为条件。有关更多详细信息,请参见附录B。我们训练了三个模型:一个是分辨率为256×256的图像,用于与文献中的模型进行直接比较,一个是512×512,另一个是384×640。对于最后一种非方形分辨率,如果图像的宽度小于其高度,则在预处理过程中旋转图像,并将“肖像模式”标志设置为真。As a result, this model can generate natively in a 5:3 aspect ratio for both landscape and portrait orientation.

4    相关工作

基于分数的扩散模型(Sohl Dickstein等人,2015;Song&Ermon,2019;Ho等人,2020)是一种预先定义随机破坏过程的生成模型。在神经网络的帮助下,通过近似逆过程来学习生成过程。

扩散模型已成功应用于图像生成(Ho等人,2020年;2022年)、语音生成(Chen等人,2020;Kong等人,2021)、视频生成(Singer等人,2022;Saharia等人,2022.)。

对于复杂数据(例如ImageNet)的高分辨率(例如512^2、256^2、128^2)的扩散模型通常不会直接学习。相反,文献中的方法通过超分辨率将生成过程划分为子问题(Ho等人,2022年),或去噪模型的混合(Feng等人,2022;Balaji等人,2022)。或者,其他方法将高分辨率数据向下投影到较低维度的潜在空间(Rombach等人,2022)。这些技术也可以结合起来,进一步细分(sub-divide)生成问题。尽管这种细分通常使优化更容易,但缺点是工程复杂性增加:与处理单个模型不同,需要训练和跟踪多个模型。在本文中,我们表明,仅对原始(现代)公式进行少量修改,就可以训练分辨率高达512×512的单个去噪扩散模型(Ho等人,2020)。

5    实验

5.1    所提修改的影响

噪声计划表    在本实验中,研究了噪声计划如何影响生成的图像质量,并根据FID50K score对训练集和验证集进行了评估。回想一下,我们的假设是余弦表没有添加足够的噪声,但可以通过使用图像分辨率和噪声分辨率之间的比率“移动(shifting)”其log SNR曲线来调整。在这些实验中,噪声分辨率从原始图像分辨率(对应于传统的余弦表)以2的比例一直变化到32倍。

如表2所示,对于分辨率为128×128和256×256的ImageNet,改变噪声计划可显著提高性能。这一差异在较高分辨率下尤为明显,在训练数据上,原始余弦计算的FID是7.65,而shifted计划的则是3.76。请注意,移向64和32的性能差异相对较小,尽管对于32 shift稍好一些。考虑到差异很小,而且shift 64计划在早期迭代中表现稍好,我们通常建议使用shift 64计划。

dropout    ImageNet数据集大约有100万张图像。如之前的工作所述,重要的是要正则网络以避免过拟合(Ho等人,2022;Dhariwal&Nichol,2021)。尽管dropout已成功应用于分辨率为64×64的网络,但对于高分辨率运行的模型,它通常被禁用。在这个实验中,我们仅在网络层的子集上启用dropout:仅在给定的“起始分辨率”超参数以下的分辨率下启用dropout。例如,如果起始分辨率为32,则dropout应用于以32×32、16×16和8×8分辨率操作的模块。

回想一下我们的假设,它应该足以使在低分辨率特征图上运行的网络模块正规化。如表3所示,这一假设成立。在128×128的图像上进行的这项实验中,添加分辨率64、32、16的dropout都表现得比较好。尽管从16×16中添加dropout的效果稍差,但我们在剩余实验中使用了这个设置,因为它在早期迭代中收敛更快。

该实验还显示了两种根本不起作用的设置,应该避免:要么添加不dropout,要么从与数据相同的分辨率开始添加dropout。这可以解释为什么高分辨率扩散的dropout迄今尚未被广泛使用:通常dropout被设置为所有分辨率下所有特征图的全局参数,但该实验表明这种正则化过于激进。

架构扩展    在本节中,我们研究了增加16×16网络模块数量的效果。在UNets中,块数超参数通常指“下行”路径上的块数。在许多实现中,“up”块使用一个额外的块。当表格读取“2+3”块时,这意味着2个向下块和3个向上块,在文献中称为2个块。

通常,增加模块数量可以提高性能,如表4所示。对此,一个有趣的例外是从8块增加到12块,eval FID略有下降。我们认为,这可能表明,随着网络的扩展,网络应该更加正则化。稍后将观察到这种效应在更大的U-ViT架构中被放大。

避免更高分辨率的特征图    最后,我们想研究下采样技术对避免高分辨率特征图的影响。对于这个实验,我们首先有一个分辨率为512的图像的标准U-Net。然后,当我们使用传统层或DWT进行下采样(或降到256或128)时。在本研究中,通过将跳过的高分辨率块分布在低分辨率块上(distributing the high resolution blocks that are skipped over the lower resolution blocks.),使块的总数保持不变。更多详情请参见附录B。

回想一下,我们的假设是,下采样不应在样本质量上代价太大,同时大大提高了模型的速度。令人惊讶的是,除了更快之外,使用下采样策略的模型实际上获得了更好的样本质量。看来,如此高分辨率的下采样使得网络能够更好地优化样本质量。最重要的是,它允许训练时无需过于庞大的特征图,而不会降低性能。

多尺度损失    在最后的实验中,我们测试了标准损失和多尺度损失之间的差异,这更加强调了图像中的低频率。对于分辨率256和512,我们报告了在启用或禁用多尺度损失的情况下训练的模型的FID分数中的样本质量。如图6所示,对于256,多尺度损失似乎没有太大影响,表现稍差。然而,对于较大的512分辨率,多尺度损失会产生影响并降低FID分数。

5.2    与文献的比较

在本节中,将简单扩散与文献中的现有方法进行比较。虽然生成漂亮的图像非常有用,但我们特别选择只与没有引导的方法(或其他采样修改,如拒绝采样)进行比较,以查看模型的拟合程度。这些采样修改可能会在视觉质量指标上产生夸大的分数(Ho&Salimans,2022)。

有趣的是,更大的U-ViT模型在训练FID和 Inception得分(IS)上表现非常好,优于文献中的所有现有方法(表7)。然而,U-Net模型在eval FID上表现更好。我们认为这是我们之前在表4中观察到的效果的外推,其中增加架构大小不一定会导致更好的评估FID。有关模型的示例,请参见图2和图10。总之,在所有其他类型的方法中,简单扩散实现了类条件(class-conditional )ImageNet生成的SOTA FID分数,而无需采样修改。我们认为这是一个非常有希望的结果:通过调整扩散计划和修改损失,简单扩散是一个单阶段模型,其分辨率高达512×512,性能高。更多结果见附录C。

文本到图像    在这个实验中,我们训练了一个文本到图像模型(Saharia等人,2022)。除了自注意和mlp块之外,该网络在T5 XXL文本嵌入上运行的Transformer中也具有交叉注意。如表8所示,simple diffusion比最近的一些文本到图像模型(如DALLE-2)要好一点,尽管它仍然落后于Imagen。重要的是,我们的模型是第一个可以仅使用端到端训练的单个扩散模型来生成这种质量的图像的模型。

6    结论

总之,我们介绍了对原始去噪扩散公式的几个简单修改,这些修改适用于高分辨率图像。在没有采样修改器(sampling modifiers)的情况下,simple diffusion在ImageNet上实现了FID分数的最先进性能,并且可以在端到端设置中轻松训练。此外,据我们所知,这是第一个可以生成具有如此高视觉质量的图像的单阶段文本到图像模型。

附录

A    扩散模型的其他背景信息

本节更详细地总结了去噪扩散的相关背景信息。首先,了解现代去噪扩散模型(Ho等人,2020)是如何使用(Kingma等人,2021)中的公式进行训练的,这可能会有所帮助。首先,我们定义信号是如何被破坏(扩散)的,这是与采样{\boldsymbol{z}}_t \sim q({\boldsymbol{z}}_t | {\boldsymbol{x}})等效的算法:

对于具体的优化设置,我们通常使用(v-prediction,epsilon loss),损失可按以下定义计算。这在算法等价于(Ho Et al.,2020;Kingma Et al.(2021)所提的\mathbb{E}_{t \sim \mathcal{U}(0, 1), {\boldsymbol{z}}_t \sim q({\boldsymbol{z}}_t | {\boldsymbol{x}})}||f({\boldsymbol{z}}_t, t) - \epsilon_t||^2

在条件(例如文本嵌入的ImageNet类别编号)的情况下,这些被添加为uvit调用的输入,但不会以其他方式影响扩散过程。在10%的时间里,条件设置被取消,这样模型就可以在无分类器的引导下使用。

标准余弦logsnr计划(考虑边界)可以定义为:

然后,可以将shifted的计划表定义为:

内插计划表如下:

需要注意的是,最小和最大logsnr超参数会随着整个计划发生shifted,因此当使用这些端点来定义架构中的嵌入时,需要注意。

采样    在本工作中,除非另有说明,否则我们使用标准ddpm采样器。以下在算法上等价于采样的生成过程(generative process of sampling){\boldsymbol{z}}_T \sim \mathcal{N}(0, \mathbf{I}),以及重复采样{\boldsymbol{z}}_{s} \sim p({\boldsymbol{z}}_s | {\boldsymbol{z}}_t)

其中noise_param被设置为0.2,但MSCOCO FID评估除外,其中它被设置为1.0。

一个重要但不经常讨论的细节是,在采样期间,在x空间中裁剪预测是有帮助的,下面给出了静态裁剪的示例,动态裁剪参见(Saharia等人,2022):

无分类器指导    在无分类器指导中(Ho&Salimans,2022),人们在训练期间偶尔会丢弃条件信号(通常约10%的时间)。这允许人们还可以训练模型p({\boldsymbol{x}}),除了通常训练的模型p({\boldsymbol{x}} | \text{cond})。然后,这些模型的epsilon预测可以与引导尺度重新组合(The epsilon predictions of these models can then be recombined with a guidance scale.)。对于η>0

\hat{{\boldsymbol{ \epsilon }}}({\boldsymbol{x}}) = (1 + \eta) \hat{{\boldsymbol{ \epsilon }}}({\boldsymbol{x}} , \text{cond}) - \eta \hat{{\boldsymbol{ \epsilon }}}({\boldsymbol{x}})    (7)

人们可以用\hat{\boldsymbol{v}}\hat{\boldsymbol{x}}代替\hat{\boldsymbol{ \epsilon }},由于线性和项抵消,结果是等价的。请注意,我们将按照文献中经常使用的方法将引导尺度(guidance scale)报告为(1 + \eta),不要被报告η本身所混淆。

蒸馏    与许多扩散模型一样,也可以蒸馏简单扩散以减少采样步骤和神经网络评估的数量(Meng等人,2022)以减少采样步数。对于蒸馏的U-ViT模型,在TPUv4上生成单个图像需要0.42秒。类似地,生成一批量的8幅图像需要2.00秒。Like many diffusion models, simple diffusion can also be distilled to reduce the number of sampling steps and neural net evaluations (Meng et al., 2022) to reduce the number of sampling steps. For a distilled U-ViT model, generating a single image takes 0.42 seconds on a TPUv4. Similarly, generating a batch of 8 images takes 2.00 seconds.

B    实验细节

本节给出了实验的具体细节。首先,U-Net实验的标准优化器设置。

B.1    U-Net设置

为了保持残差块的数量相同,通过下采样跳连的高分辨率块被添加到较低分辨率级别。在没有下采样的情况下,该架构使用(To keep the number of residual blocks the same, high resolution blocks that are skipped by down-sampling are added to the lower resolution levels. With no downsampling, the architecture uses:):

channel_multiplier =[1 , 1 , 1 , 2 , 4 , 8 , 8] , num_res_blocks =[1 , 1 , 2 , 2 , 4 , 12 , 4] ,

在2×下采样的情况下,架构使用:

channel_multiplier =[1 , 2 , 2 , 4 , 8 , 8] , num_res_blocks =[2 , 2 , 2 , 4 , 12 , 4] ,

在4倍下采样的情况下,架构使用:

channel_multiplier =[2 , 3 , 4 , 8 , 8] , num_res_blocks =[3 , 3 , 4 , 12 , 4]

B.2    U-ViT设置

U-ViT是与U-Net非常相似的架构(见图8)。两个主要区别是:1)当模块具有自注意时,它使用MLP块而不是卷积层,使它们的组合成为Transformer块。2)中间的Transformer块不使用跳接,只使用残差连接。

U-ViT ImageNet的默认优化设置为:

所有分辨率128、256和512的架构设置几乎相同。

其中,patching类型为,128为“none”,256为“dwt_1”,512为“dwt2”。还要注意,损失是根据v而不是epsilon计算的。这可能不是很重要:在小型实验中,我们只观察到两者之间的细微性能差异。还要注意,批量大小更大(2048),这会显著影响FID和IS性能。文本到图像模型被训练了700K个步骤。

B.2.1    U-VIT模块的伪代码

Transformer块由一个自注意和mlp块组成。这些定义如人们所期望的那样,为了完整性,在伪代码中给出如下:

另一个重要的块是标准ResBlock,伪代码如下:

编者:这个换行,建议还是看原文

给定这些构建块,可以定义U-ViT架构:

可以看到,它与UNet非常相似,中间部分现在是一个Transformer,它没有卷积层,而是只有残差连接的mlp块。

C    其他实验

Guidance scale    在表9中,我们显示了引导对ImageNet模型的影响。对于相对较小的引导级别(small levels of guidance),样本立即在IS上获得很多,代价是特别是评估FID。此外,图9显示了文本到图像模型的Clip与MSCOCO FID30K的得分。继其他如(Saharia等人,2022)之后,通过调节来自MSCOCO验证集的30K随机采样文本来对图像进行采样,并将整个验证集作为参考进行计算。



最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,723评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,485评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,998评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,323评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,355评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,079评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,389评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,019评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,519评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,971评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,100评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,738评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,293评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,289评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,517评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,547评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,834评论 2 345

推荐阅读更多精彩内容