Diffusion Model 中的 U-Net 理解

前文我们学习了一些传统 Diffusion Model 的模型,比如 DDPM,LDM,也学习了 U-Net 的细节,但是它们是如何应用 U-Net 的呢。之前学习 Diffusion 一直处在理论或者宏观的角度,接下来需要结合代码理解细节。

为了适配扩散任务,现代的 Diffusion Model 在原始 U-Net 上引入了时间步嵌入。

Diffusion U-Net 的伪代码骨架

值得注意的是,现代的 Diffusion 往往会使用正余弦编码映射作为时间步,这里用标量进行了简化。

上采样与下采样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Downsample(nn.Module):
def __init__(self, channels):
super().__init__()
# 使用 stride=2 的卷积替代 MaxPool
self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

def forward(self, x):
return self.conv(x)

class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
# 先插值放大尺寸,再通过卷积消除锯齿
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
return self.conv(self.upsample(x))

ResNetBlock:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResnetBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim):
super().__init__()
# 时间步的线性映射:将 t_emb 映射到当前层的通道数
self.time_mlp = nn.Linear(time_emb_dim, out_ch)

# 主卷积路径
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

# 如果输入输出通道不同,残差边需要一个 1x1 卷积对齐
self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

def forward(self, x, t_emb):
# 1. 第一层卷积
h = self.conv1(F.silu(x))

# 2. 注入时间信息:将 t_emb 变换后加到特征图上
# t_emb shape: [B, time_emb_dim] -> [B, out_ch] -> [B, out_ch, 1, 1] (广播机制)
time_feat = self.time_mlp(F.silu(t_emb))
h = h + time_feat[:, :, None, None]

# 3. 第二层卷积
h = self.conv2(F.silu(h))

# 4. 残差相加
return h + self.shortcut(x)

完整的 U-Net 伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class DiffusionUNet(nn.Module):
def __init__(self, in_ch=3, model_ch=128, ch_mult=[1, 2, 4, 8]):
super().__init__()

# 1. 时间步编码 (Sinusoidal + MLP)
time_dim = model_ch * 4
self.time_embed = nn.Sequential(
# 这里假设已完成 Sinusoidal 编码,输入 dim=model_ch
nn.Linear(model_ch, time_dim),
nn.SiLU(),
nn.Linear(time_dim, time_dim),
)

# 2. Downsampling 阶段
self.downs = nn.ModuleList([])
cur_ch = model_ch
self.input_conv = nn.Conv2d(in_ch, model_ch, 3, padding=1)

for mult in ch_mult:
out_ch = model_ch * mult
# 每个阶段包含:ResNet 块 + 下采样
self.downs.append(ResnetBlock(cur_ch, out_ch, time_dim))
self.downs.append(Downsample(out_ch))
cur_ch = out_ch

# 3. Middle 阶段
self.middle = ResnetBlock(cur_ch, cur_ch, time_dim)

# 4. Upsampling 阶段
self.ups = nn.ModuleList([])
for mult in reversed(ch_mult):
out_ch = model_ch * mult
# 注意:Upsample 的输入是 (当前层 + Skip层) 的通道总和
self.ups.append(ResnetBlock(cur_ch + out_ch, out_ch, time_dim))
self.ups.append(Upsample(out_ch))
cur_ch = out_ch

# 5. 输出层
self.final_conv = nn.Conv2d(model_ch, in_ch, 3, padding=1)

def forward(self, x, t):
# t 预处理 (正余弦编码在此省略,假设已转为向量)
t_emb = self.time_embed(t)

# Encoder
x = self.input_conv(x)
hs = [x] # 存储中间特征,用于 Skip Connection
for layer in self.downs:
if isinstance(layer, ResnetBlock):
x = layer(x, t_emb)
else: # Downsample
x = layer(x)
hs.append(x)

# Middle
x = self.middle(x, t_emb)

# Decoder
for layer in self.ups:
if isinstance(layer, ResnetBlock):
# 弹出 Encoder 对应的特征图进行拼接
skip_x = hs.pop()
x = torch.cat([x, skip_x], dim=1)
x = layer(x, t_emb)
else: # Upsample
x = layer(x)

return self.final_conv(x)

注意区别:

原始 U-Net

  • Encoder: ConvBlock → MaxPool
  • Decoder: UpConv → Concat → ConvBlock

Diffusion U-Net

  • Encoder: ResNetBlock(t) → Downsample
  • Decoder: Concat → ResNetBlock(t) → Upsample

如何理解:

  • 对于原始 U-Net,常用于语义分割,比如医学图像分割,非常关心像素级位置精度。所以先上采样把特征图放大,然后和高分辨率特征拼接,然后卷积。
  • 对于 Diffusion U-Net,核心任务是预测噪声,所以现在小画布上判断整体结构和噪声趋势,再逐步放大。

如何加入时间步特征?

通常做法是把 变成正余弦编码,然后经过 MLP 映射成 Time Embedding,然后:

  • feature map: [B, C, H, W]
  • time embedding: [B, D]

[B, D] 不能直接和 [B, C, H, W] 相加,所以会先经过线性层:

text
1
2
3
4
5
time embedding: [B, D]
↓ Linear
[B, C]
↓ reshape
[B, C, 1, 1]

然后广播加到特征图上:

text
1
2
[B, C, H, W] + [B, C, 1, 1]
→ [B, C, H, W]

其实也就是平摊成 Token 后,每个 Token 加一个相同的 Embedding。

如何结合文本特征?

文本通常不是直接输入 U-Net,而是先经过文本编码器,比如 CLIP text encoder:

text
1
2
3
"a cat sitting on a chair"
↓ text encoder
text tokens: [B, N, D]

然后 U-Net 在很多层里通过 cross-attention 融合文本信息。

具体的先把空间位置摊平成 token:

text
1
2
[B, C, H, W]
→ [B, H×W, C]

这表示图像里每个空间位置都是一个 image token。

然后做 cross-attention:

text
1
2
3
Q 来自 image feature
K 来自 text feature
V 来自 text feature

Diffusion Model 中的 U-Net 理解
https://d4wnnn.github.io/2026/03/14/Notion/Diffusion Model 中的 U-Net 理解/
作者
D4wn
发布于
2026年3月14日
许可协议