Diffusion Model 中的 U-Net 理解

前文我们学习了一些传统 Diffusion Model 的模型,比如 DDPM,LDM,也学习了 U-Net 的细节,但是它们是如何应用 U-Net 的呢。之前学习 Diffusion 一直处在理论或者宏观的角度,接下来需要结合代码理解细节。

为了适配扩散任务,现代的 Diffusion Model 在原始 U-Net 上引入了时间步嵌入。

Diffusion U-Net 的伪代码骨架

值得注意的是,现代的 Diffusion 往往会使用正余弦编码映射作为时间步,这里用标量进行了简化。

上采样与下采样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Downsample(nn.Module):
def __init__(self, channels):
super().__init__()
# 使用 stride=2 的卷积替代 MaxPool
self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

def forward(self, x):
return self.conv(x)

class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
# 先插值放大尺寸,再通过卷积消除锯齿
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
return self.conv(self.upsample(x))

ResNetBlock:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResnetBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim):
super().__init__()
# 时间步的线性映射:将 t_emb 映射到当前层的通道数
self.time_mlp = nn.Linear(time_emb_dim, out_ch)

# 主卷积路径
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

# 如果输入输出通道不同,残差边需要一个 1x1 卷积对齐
self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

def forward(self, x, t_emb):
# 1. 第一层卷积
h = self.conv1(F.silu(x))

# 2. 注入时间信息:将 t_emb 变换后加到特征图上
# t_emb shape: [B, time_emb_dim] -> [B, out_ch] -> [B, out_ch, 1, 1] (广播机制)
time_feat = self.time_mlp(F.silu(t_emb))
h = h + time_feat[:, :, None, None]

# 3. 第二层卷积
h = self.conv2(F.silu(h))

# 4. 残差相加
return h + self.shortcut(x)

完整的 U-Net 伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class DiffusionUNet(nn.Module):
def __init__(self, in_ch=3, model_ch=128, ch_mult=[1, 2, 4, 8]):
super().__init__()

# 1. 时间步编码 (Sinusoidal + MLP)
time_dim = model_ch * 4
self.time_embed = nn.Sequential(
# 这里假设已完成 Sinusoidal 编码,输入 dim=model_ch
nn.Linear(model_ch, time_dim),
nn.SiLU(),
nn.Linear(time_dim, time_dim),
)

# 2. Downsampling 阶段
self.downs = nn.ModuleList([])
cur_ch = model_ch
self.input_conv = nn.Conv2d(in_ch, model_ch, 3, padding=1)

for mult in ch_mult:
out_ch = model_ch * mult
# 每个阶段包含:ResNet 块 + 下采样
self.downs.append(ResnetBlock(cur_ch, out_ch, time_dim))
self.downs.append(Downsample(out_ch))
cur_ch = out_ch

# 3. Middle 阶段
self.middle = ResnetBlock(cur_ch, cur_ch, time_dim)

# 4. Upsampling 阶段
self.ups = nn.ModuleList([])
for mult in reversed(ch_mult):
out_ch = model_ch * mult
# 注意:Upsample 的输入是 (当前层 + Skip层) 的通道总和
self.ups.append(ResnetBlock(cur_ch + out_ch, out_ch, time_dim))
self.ups.append(Upsample(out_ch))
cur_ch = out_ch

# 5. 输出层
self.final_conv = nn.Conv2d(model_ch, in_ch, 3, padding=1)

def forward(self, x, t):
# t 预处理 (正余弦编码在此省略,假设已转为向量)
t_emb = self.time_embed(t)

# Encoder
x = self.input_conv(x)
hs = [x] # 存储中间特征,用于 Skip Connection
for layer in self.downs:
if isinstance(layer, ResnetBlock):
x = layer(x, t_emb)
else: # Downsample
x = layer(x)
hs.append(x)

# Middle
x = self.middle(x, t_emb)

# Decoder
for layer in self.ups:
if isinstance(layer, ResnetBlock):
# 弹出 Encoder 对应的特征图进行拼接
skip_x = hs.pop()
x = torch.cat([x, skip_x], dim=1)
x = layer(x, t_emb)
else: # Upsample
x = layer(x)

return self.final_conv(x)

注意区别:

原始 U-Net

  • Encoder: ConvBlock → MaxPool
  • Decoder: UpConv → Concat → ConvBlock

Diffusion U-Net

  • Encoder: ResNetBlock(t) → Downsample
  • Decoder: Concat → ResNetBlock(t) → Upsample

原始的 U-Net 主要矛盾是为了恢复像素级别的位置准确性,计算重心在 Upsample 后的高分辨率处理。而 Diffusion U-Net 的主要矛盾是在各个尺度上预测噪声,计算重心是在 Upsample 前的低分辨率处理。(emmm,理解其实有点困难,后面再说吧。)


Diffusion Model 中的 U-Net 理解
https://d4wnnn.github.io/2026/03/14/Notion/Diffusion Model 中的 U-Net 理解/
作者
D4wn
发布于
2026年3月14日
许可协议