Paper:High-Resolution Image Synthesis with Latent Diffusion
Models
传统的 Diffusion
模型都是在像素空间上运行的,训练和推理速度很慢。同时这样会让模型花费更多精力去优化细节,忽略核心的语义生成。本论文提出了
LDM,是Stable 的奠基之作。
Stable Diffusion 是怎么做的?
首先预训练一个自编码器 ,将图像 映射到低维空间 ,而解码器 将其还原成图像。
然后模型在这个低维空间
进行扩散生成。
另外 Stable Diffusion 还引入了
Cross-Attention,将各种模态(比如文本,草图)经过专用的编码器 转换后,作为 和 ,与潜空间特征 进行交互。
训练函数目标如下:
其实和 DDPM
相比,就是把扩散空间从像素空间改变到了潜空间,而且加了多模态条件控制。
整体架构图如下:
image.png
为了方便理解,让 Gemini 生成了一段 Pytorch 的伪代码。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
| import torch import torch.nn as nn
class LatentDiffusionModel(nn.Module): def __init__(self): super().__init__() self.autoencoder = Autoencoder(KL_or_VQ_reg=True) self.unet = UNet(in_channels=latent_dim, out_channels=latent_dim) self.cond_stage_model = TransformerEncoder()
def forward(self, x, y): """ x: 输入图像 y: 条件输入 (如 text prompt) """ z = self.autoencoder.encode(x) t = torch.randint(0, T, (z.shape[0],)) noise = torch.randn_like(z) z_t = self.q_sample(z, t, noise) context = self.cond_stage_model(y) predicted_noise = self.unet(z_t, t, context) loss = F.mse_loss(predicted_noise, noise) return loss
@torch.no_grad() def sample(self, y): z_t = torch.randn((1, latent_dim, h, w)) context = self.cond_stage_model(y) for t in reversed(range(T)): z_t = self.denoise_step(z_t, t, context) return self.autoencoder.decode(z_t)
|
生成效果如何?
image.png