1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
| import torch import torch.nn as nn
class DiTBlock(nn.Module):
    """Transformer block with adaLN-Zero style conditioning.

    The block applies self-attention followed by an MLP, each wrapped in a
    residual connection.  Both LayerNorms are affine-free; the per-sample
    shift, scale and residual gate for each branch are instead regressed
    from the conditioning vector by ``adaLN_modulation``.

    The final linear layer of ``adaLN_modulation`` is zero-initialised, so
    all six modulation vectors (including both gates) are zero right after
    construction and the block starts out as the identity mapping.

    Args:
        dim: Token / conditioning embedding width ``D``.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires it to divide ``dim``).
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4) -> None:
        super().__init__()

        # No learnable affine here: scale and shift are supplied per sample
        # by the conditioning pathway in forward().
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)

        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

        # Produces 6 * dim values per sample: (shift, scale, gate) for the
        # attention branch followed by (shift, scale, gate) for the MLP branch.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )

        # Zero init => every gate is 0 at start, so each residual branch
        # contributes nothing until training moves these weights.
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned attention and MLP branches to ``x``.

        Args:
            x: ``[B, N, D]`` noisy latent patch tokens.
            cond: ``[B, D]`` conditioning embedding (timestep / class /
                text, etc.).

        Returns:
            Tensor of shape ``[B, N, D]``.
        """
        # [B, 6*D] -> [B, 1, 6*D] so that every chunk broadcasts over the
        # N token positions of x.
        modulation = self.adaLN_modulation(cond).unsqueeze(1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation.chunk(6, dim=-1)

        # Attention branch.
        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa
        # The attention map is never used, so do not ask for it: this skips
        # computing the averaged weights and allows the fused attention path.
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate_msa * h

        # MLP branch.
        h = self.norm2(x)
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        x = x + gate_mlp * h

        return x
|