DiT

Paper: Scalable Diffusion Models with Transformers

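A standard DiT is essentially a Transformer that also injects the conditioning signal (e.g. the class label) and the diffusion timestep into every block.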
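In DiT, the adaLN parameters are regressed from the sum of the timestep embedding and the class-label embedding. Below is a minimal sketch of how that `cond` vector could be built. It assumes a sinusoidal timestep encoding followed by a two-layer MLP, which is similar in spirit to the official code. `sinusoidal_embedding` and `ConditionEmbedder` are hypothetical names. The 256-dim frequency size and the absence of classifier-free-guidance label dropout are also assumptions.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_embedding(t, dim, max_period=10000):
    """Map integer timesteps t: [B] to sinusoidal features [B, dim] (dim assumed even)."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]  # [B, half]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, dim]


class ConditionEmbedder(nn.Module):
    """Hypothetical helper: cond = MLP(sinusoid(t)) + Embedding(y)."""

    def __init__(self, dim, num_classes, freq_dim=256):
        super().__init__()
        self.freq_dim = freq_dim
        self.t_mlp = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.y_embed = nn.Embedding(num_classes, dim)

    def forward(self, t, y):
        t_emb = self.t_mlp(sinusoidal_embedding(t, self.freq_dim))  # [B, D]
        y_emb = self.y_embed(y)                                     # [B, D]
        return t_emb + y_emb                                        # cond: [B, D]


embedder = ConditionEmbedder(dim=64, num_classes=10)
t = torch.tensor([0, 500, 999])  # diffusion timesteps
y = torch.tensor([1, 2, 3])      # class labels
cond = embedder(t, y)
print(cond.shape)  # torch.Size([3, 64])
```

The resulting `cond: [B, D]` is what each DiT block receives as its second input.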


Core Differences

```
# Standard Transformer
x = x + Attention(LN(x))
x = x + MLP(LN(x))


# DiT adaLN-Zero block
# The condition is projected once per block into shift/scale/gate
# for both the attention branch and the MLP branch:
cond_params = Linear(SiLU(cond))

x = x + gate_attn(cond) * Attention(
    LN(x) * (1 + scale_attn(cond)) + shift_attn(cond)
)

x = x + gate_mlp(cond) * MLP(
    LN(x) * (1 + scale_mlp(cond)) + shift_mlp(cond)
)
```
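The two update rules above can be checked numerically. In this sketch, `prenorm_update` and `adaln_zero_update` are hypothetical functions, each implementing one residual sub-step. A random `nn.Linear` stands in for Attention or MLP, purely for illustration. With all modulation parameters at zero, the adaLN-Zero update returns `x` exactly, while the standard pre-norm update does not.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def prenorm_update(x, sublayer):
    # Standard Transformer: x + sublayer(LN(x))
    return x + sublayer(F.layer_norm(x, x.shape[-1:]))


def adaln_zero_update(x, sublayer, shift, scale, gate):
    # DiT: x + gate * sublayer(LN(x) * (1 + scale) + shift)
    # shift/scale/gate: [B, 1, D], broadcast over the N tokens
    h = F.layer_norm(x, x.shape[-1:]) * (1 + scale) + shift
    return x + gate * sublayer(h)


torch.manual_seed(0)
B, N, D = 2, 4, 8
x = torch.randn(B, N, D)
sublayer = nn.Linear(D, D)  # stand-in for Attention or MLP

zeros = torch.zeros(B, 1, D)
out_dit = adaln_zero_update(x, sublayer, zeros, zeros, zeros)
out_std = prenorm_update(x, sublayer)

print(torch.equal(out_dit, x))  # True: gate = 0 removes the residual branch
print(torch.equal(out_std, x))  # False: the branch output is added unconditionally
```

This is where adaLN-Zero gets the "Zero" in its name. Because the gates start at zero, each block starts as the identity.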

In short:

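  • Standard Transformer:
    • LayerNorm, then Attention, then residual Add;
    • then LayerNorm, then MLP, then residual Add.
  • DiT:
    • LayerNorm, then condition-dependent scale and shift, then Attention, then condition-dependent gate, then residual Add;
    • then LayerNorm, then scale and shift, then MLP, then gate, then residual Add.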
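The "scale and shift" step applies an affine transform to the normalized activations. LayerNorm's own gamma/beta apply the same kind of transform. The sketch below defines a `modulate` helper, matching the name used in a comment of the DiT block code, and runs a check with constant parameters. It confirms that `modulate(LN_without_affine(x), shift, scale)`, with `scale = gamma - 1` and `shift = beta`, reproduces an ordinary `nn.LayerNorm` whose weights are `gamma`/`beta`.

```python
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    """x: [B, N, D]; shift, scale: [B, D]. Per-sample affine on normalized features."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


torch.manual_seed(0)
B, N, D = 2, 5, 8
x = torch.randn(B, N, D)

# Standard LN: one learned (gamma, beta) shared by every sample
ln_affine = nn.LayerNorm(D)
with torch.no_grad():
    ln_affine.weight.copy_(torch.randn(D))
    ln_affine.bias.copy_(torch.randn(D))

# adaLN-style: LN without affine, followed by per-sample (scale, shift)
ln_plain = nn.LayerNorm(D, elementwise_affine=False)
scale = (ln_affine.weight - 1).detach().expand(B, D)  # gamma = 1 + scale
shift = ln_affine.bias.detach().expand(B, D)          # beta = shift

same = torch.allclose(ln_affine(x), modulate(ln_plain(x), shift, scale), atol=1e-6)
print(same)  # True
```

So adaLN introduces no new kind of operation. It makes LayerNorm's affine parameters a function of `cond`, so each sample (timestep, class) gets its own gamma/beta.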

Pseudocode

Standard Pre-Norm Transformer Block

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4):
        super().__init__()

        # Standard LayerNorm: gamma/beta are learnable but input-independent
        # (the same for every sample), shape: [dim]
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """
        x: [B, N, D]
        B = batch size
        N = number of tokens, e.g. text length or number of image patches
        D = hidden dim
        """

        # Pre-Norm Transformer:
        # LN first, then Attention, then residual add
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + h

        # Then LN, then MLP, then residual add
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h

        return x
```
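The comment in the standard block says that LayerNorm carries learnable gamma/beta of shape `[dim]`. The DiT block turns these off with `elementwise_affine=False`. A quick check on `nn.LayerNorm` itself, with an arbitrary `dim = 16`, shows the difference:

```python
import torch.nn as nn

dim = 16
ln_std = nn.LayerNorm(dim)                            # has weight (gamma) and bias (beta)
ln_dit = nn.LayerNorm(dim, elementwise_affine=False)  # no learnable affine parameters

print(tuple(ln_std.weight.shape), tuple(ln_std.bias.shape))  # (16,) (16,)
print(ln_dit.weight, ln_dit.bias)                            # None None
print(sum(p.numel() for p in ln_std.parameters()))           # 32
print(sum(p.numel() for p in ln_dit.parameters()))           # 0
```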

DiT's adaLN-Zero Block

```python
import torch
import torch.nn as nn


class DiTBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4):
        super().__init__()

        # Note: the usual choice here is elementwise_affine=False,
        # i.e. LayerNorm itself no longer has fixed gamma/beta,
        # because scale/shift are generated dynamically from the condition vector
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

        # Difference 1:
        # the condition produces 6 groups of parameters:
        #   shift_msa, scale_msa, gate_msa
        #   shift_mlp, scale_mlp, gate_mlp
        #
        # each group has shape [B, D]
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )

        # adaLN-Zero initialization:
        # zero the last Linear so every shift/scale/gate starts at 0;
        # with gate = 0 the block is exactly the identity at init (x + 0 * h = x)
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        """
        x: [B, N, D]
            noisy latent patch tokens

        cond: [B, D]
            conditioning embedding (timestep, class label, text, etc.)
        """

        # Generate the dynamic modulation parameters from the condition vector
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(cond).chunk(6, dim=-1)

        # Reshape to [B, 1, D] so they broadcast against x: [B, N, D]
        shift_msa = shift_msa.unsqueeze(1)
        scale_msa = scale_msa.unsqueeze(1)
        gate_msa = gate_msa.unsqueeze(1)

        shift_mlp = shift_mlp.unsqueeze(1)
        scale_mlp = scale_mlp.unsqueeze(1)
        gate_mlp = gate_mlp.unsqueeze(1)

        # Difference 2:
        # a standard Transformer uses self.norm1(x);
        # DiT uses modulate(self.norm1(x), shift, scale), i.e.
        #   h = LN(x) * (1 + scale) + shift
        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa

        h, _ = self.attn(h, h, h)

        # Difference 3:
        # the residual branch is multiplied by a gate before the add;
        # the gate is also generated from the condition
        x = x + gate_msa * h

        h = self.norm2(x)
        h = h * (1 + scale_mlp) + shift_mlp

        h = self.mlp(h)

        # The MLP branch also has a condition-dependent gate
        x = x + gate_mlp * h

        return x
```
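The identity-at-initialization claim for `DiTBlock` can be sanity-checked without the full block. This sketch rebuilds only an `adaLN_modulation` head with the same structure and zero initialization. It uses a random tensor as a stand-in for the attention output, which is an assumption for illustration. It also checks that gradients still reach the zeroed layer, so the gates can move away from 0 during training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D = 2, 4, 8

# Same structure as DiTBlock.adaLN_modulation: SiLU -> Linear(D, 6D), zero-initialized
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(D, 6 * D))
nn.init.zeros_(adaLN_modulation[-1].weight)
nn.init.zeros_(adaLN_modulation[-1].bias)

cond = torch.randn(B, D)
params = adaLN_modulation(cond).chunk(6, dim=-1)  # six tensors, each [B, D]

# Order follows the block: shift_msa, scale_msa, gate_msa, shift_mlp, ...
shift_msa, scale_msa, gate_msa = (p.unsqueeze(1) for p in params[:3])  # [B, 1, D]

x = torch.randn(B, N, D)
h = torch.randn(B, N, D)  # stand-in for the attention output
x_out = x + gate_msa * h
print(torch.equal(x_out, x))  # True: the residual update is zero at init

# Gradients still flow into the zero-initialized Linear through the gate
loss = (gate_msa * h).sum()
loss.backward()
print(bool(adaLN_modulation[-1].weight.grad.abs().sum() > 0))  # True
```

Zero initialization therefore does not freeze the block. It only makes the first step start from the identity, and the optimizer then learns nonzero gates.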

DiT
https://d4wnnn.github.io/2026/05/23/Notion/DiT/
Author: D4wn
Published: May 23, 2026