
Commit dcda272

Convert autoregressive GPT model into Block Diffusion model
1 parent b983087 commit dcda272

File tree

2 files changed: +69, -36 lines


model.py

Lines changed: 61 additions & 24 deletions
@@ -11,15 +11,19 @@


 @dataclass
-class GPTConfig:
-    vocab_size: int = 50_257
+class BlockGPTConfig:
+    vocab_size: int = 50_258
     bos_id: int = 50_256
+    mask_id: int = 50_257
     num_layers: int = 12
     num_heads: int = 6
     model_dim: int = 768
     max_seq_len: int = 131_072 # 2**17
     head_dim: int = 128
     intermediate_dim: int | None = None
+    diffusion_block_size: int = 16
+    t_lower: float = 0.3
+    t_upper: float = 0.8


 def norm(x):
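
Note: the sketch below is an illustration, not code from this commit. It shows how the new config fields presumably combine: ids 0..50_256 are the usual GPT-2 vocabulary (with 50_256 as BOS), the extra id 50_257 is the reserved mask token (hence vocab_size growing from 50_257 to 50_258), diffusion_block_size groups tokens into diffusion blocks, and t_lower/t_upper bound the masking rate sampled per document during training.

# Illustration only -- toy values, not part of the commit.
import torch

mask_id, block_size = 50_257, 16
t_lower, t_upper = 0.3, 0.8

tokens = torch.randint(0, 50_257, (64,))                 # a toy 64-token document
t = torch.empty(()).uniform_(t_lower, t_upper)           # masking rate for this document
noisy = tokens.masked_fill(torch.rand(64) < t, mask_id)  # each token masked with prob t
num_blocks = tokens.numel() // block_size                # -> 4 diffusion blocks of 16 tokens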
@@ -63,7 +67,7 @@ def forward(self, x_BTHD: Tensor, pos_id: Tensor):


 class CausalSelfAttention(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.config = config

@@ -107,7 +111,7 @@ def forward(self, x: Tensor, v_residual: Tensor | None, pos_id: Tensor, block_ma


 class MLP(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         intermediate_dim = config.intermediate_dim or 4 * config.model_dim
         self.in_proj = CastedLinear(config.model_dim, intermediate_dim)
@@ -121,7 +125,7 @@ def forward(self, x: Tensor):


 class Block(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.attn = CausalSelfAttention(config)
         self.mlp = MLP(config)
@@ -135,8 +139,8 @@ def forward(self, x: Tensor, v_residual: Tensor, x0: Tensor, pos_id: Tensor, blo
         return x, v_residual


-class GPT(nn.Module):
-    def __init__(self, config: GPTConfig):
+class BlockGPT(nn.Module):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.config = config

@@ -149,25 +153,53 @@ def __init__(self, config: GPTConfig):
         assert len(self.blocks) % 2 == 0
         self.skip_w = nn.Parameter(torch.ones(len(self.blocks) // 2))

-    def create_blockmask(self, input_seq: Tensor):
-        docs = (input_seq == self.config.bos_token_id).cumsum(0)
+    def create_blockmask(self, doc_id: Tensor, pos_id: Tensor):
+        """BlockMask for attn rules from https://arxiv.org/pdf/2503.09573 section 3.1"""
+        L = len(doc_id)

-        def document_causal_mask(b, h, q_idx, kv_idx):
-            causal_mask = q_idx >= kv_idx
-            document_mask = docs[q_idx] == docs[kv_idx]
-            return causal_mask & document_mask # & window_mask
+        block_id = pos_id // self.config.diffusion_block_size + doc_id * L
+        block_id = torch.cumsum(block_id != block_id.roll(1, 0), 0) - 1

-        S = len(input_seq)
-        return create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True)
+        block_id, doc_id = block_id.repeat(2), doc_id.repeat(2)
+        noisy = torch.arange(2 * L, device=doc_id.device) < L

-    def forward(self, input_seq: Tensor, target_seq: Tensor):
+        def block_diffusion_mask(b, h, q, kv):
+            # mask from section 3.1 of https://arxiv.org/pdf/2503.09573
+            blk_q, blk_kv = block_id[q], block_id[kv]
+
+            bd = (blk_q == blk_kv) & (noisy[q] == noisy[kv]) # Block Diagonal
+            obc = (blk_q > blk_kv) & noisy[q] & (~noisy[kv]) # Offset Block Causal
+            bc = (blk_q >= blk_kv) & (~noisy[q]) & (~noisy[kv]) # Block Causal
+
+            same_doc = doc_id[q] == doc_id[kv]
+            return same_doc & (bd | obc | bc)
+
+        S = 2 * L
+        return create_block_mask(block_diffusion_mask, None, None, S, S)
+
+    def forward(self, input_seq: Tensor):
         assert input_seq.ndim == 1

-        x = x0 = norm(self.embed(input_seq)[None])
+        # construct attention rules & block mask
+        doc_id = (input_seq == self.config.bos_id).cumsum(0)
+        p = torch.arange(input_seq.size(0), device=input_seq.device)
+        pos_id = p - torch.where(input_seq == self.config.bos_id, p, -1).cummax(0).values
+        block_mask = self.create_blockmask(doc_id, pos_id)
+
+        # Apply noise to sequence
+        noise_range = (self.config.t_lower, self.config.t_upper) if self.training else (0.0, 1.0)
+        rand = torch.rand_like(input_seq, dtype=torch.float32)
+        t = torch.empty_like(rand).uniform_(*noise_range)[doc_id]
+        noisy_seq = input_seq.masked_fill(rand >= (1 - t), self.config.mask_id)
+
+        # Concat noisy + clean into seq and repeat pos_ids
+        seq = torch.cat([noisy_seq, input_seq], dim=0)
+        pos_id = pos_id.repeat(2)
+
+        # Embedding & U-net backbone forward
+        x = x0 = norm(self.embed(seq)[None])
         v_residual = None

-        # U-net design
-        block_mask = self.create_blockmask(input_seq)
         skip_conns, n = [], len(self.skip_w)
         for i, block in enumerate(self.blocks):
             if i >= n:
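
Note: the mask built above follows section 3.1 of https://arxiv.org/pdf/2503.09573. The sequence is doubled into a noisy half and a clean half; tokens attend within their own block of the same half (block diagonal), noisy queries additionally see all strictly earlier clean blocks (offset block causal), clean queries see clean keys block-causally (block causal), and all attention stays within the same document. The dense toy recomputation below (made-up sizes: one document, L = 8, diffusion_block_size = 4) is an illustration only; the commit builds the equivalent sparse BlockMask with FlexAttention's create_block_mask.

# Illustration only -- dense version of block_diffusion_mask on toy sizes, not part of the commit.
import torch

L, block_size = 8, 4
block_id = (torch.arange(L) // block_size).repeat(2)  # noisy half then clean half
noisy = torch.arange(2 * L) < L

q = torch.arange(2 * L)[:, None]
kv = torch.arange(2 * L)[None, :]
bd = (block_id[q] == block_id[kv]) & (noisy[q] == noisy[kv])     # Block Diagonal
obc = (block_id[q] > block_id[kv]) & noisy[q] & (~noisy[kv])     # Offset Block Causal
bc = (block_id[q] >= block_id[kv]) & (~noisy[q]) & (~noisy[kv])  # Block Causal
mask = bd | obc | bc                                             # [2L, 2L] boolean attention mask
print(mask.int())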
@@ -176,11 +208,16 @@ def forward(self, input_seq: Tensor, target_seq: Tensor):
             if i < n:
                 skip_conns.append(x)

+        x = x[:, :input_seq.size(0)] # Get logits for noisy tokens only
         x = norm(x)
         logits = self.lm_head(x).float()

-        # tanh softcapping
-        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
-
-        return loss
+        # Get loss for masked tokens
+        mask = (noisy_seq == self.config.mask_id)
+        targets = torch.where(mask, input_seq, torch.full_like(input_seq, -100))
+        losses = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+        if self.training:
+            weights = (1.0 / (t + 1e-4)).type_as(logits)
+            return (losses * weights * mask).sum() / mask.sum()
+        else:
+            return (losses * mask).sum() / mask.sum()
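
Note: the new objective scores only the positions that were actually masked and, during training, reweights each one by roughly 1/t, the inverse of its document's masking rate (the usual masked-diffusion weighting); at eval time it is a plain mean over masked tokens. The snippet below restates that formula on toy tensors and is an illustration only, not code from the commit.

# Illustration only -- the masked, 1/t-weighted cross-entropy from the diff on toy tensors.
import torch
import torch.nn.functional as F

V, L = 11, 6
logits = torch.randn(L, V)
input_seq = torch.randint(0, V, (L,))
mask = torch.tensor([1, 0, 1, 1, 0, 0], dtype=torch.bool)  # positions that were masked
t = torch.full((L,), 0.5)                                  # per-token masking rate

targets = torch.where(mask, input_seq, torch.full_like(input_seq, -100))  # -100 = ignore_index
losses = F.cross_entropy(logits, targets, reduction='none')               # 0 at ignored positions
train_loss = (losses * (1.0 / (t + 1e-4)) * mask).sum() / mask.sum()
eval_loss = (losses * mask).sum() / mask.sum()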

train.py

Lines changed: 8 additions & 12 deletions
@@ -10,8 +10,7 @@
 import torch.distributed as dist

 from muon import Muon
-
-from model import GPT, GPTConfig
+from model import BlockGPT, BlockGPTConfig


 code = "\n".join([
@@ -73,17 +72,14 @@ def distributed_data_generator(
     tokens, pos = _load_data_shard(next(file_iter)), 0

     while True:
-        if pos + batch_size + 1 >= len(tokens):
+        if pos + batch_size >= len(tokens):
             tokens, pos = _load_data_shard(next(file_iter)), 0
         buf = tokens[pos + rank * local_bs :][: local_bs + 1]
         inputs = buf[:-1].to(
             device="cuda", dtype=torch.int32, non_blocking=True
         )
-        targets = buf[1:].to(
-            device="cuda", dtype=torch.int64, non_blocking=True
-        )
         pos += batch_size
-        yield inputs, targets
+        yield inputs


 # -----------------------------------------------------------------------------
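
Note: since BlockGPT now derives its noisy inputs and targets internally, the generator only has to yield the raw int32 token slice. The toy sketch below (single rank, made-up sizes, no CUDA transfer) is an illustration only and just mirrors the slicing that remains in the loader.

# Illustration only -- what one step of the simplified loader yields, on toy data
# (single rank, so batch_size == local_bs here; the real loader also moves inputs to CUDA).
import torch

tokens = torch.randint(0, 50_257, (1024,), dtype=torch.int32)  # stand-in for a data shard
local_bs = batch_size = 256
rank, pos = 0, 0

buf = tokens[pos + rank * local_bs :][: local_bs + 1]
inputs = buf[:-1]  # int32 token ids; no shifted int64 targets any more
pos += batch_size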
@@ -96,16 +92,16 @@ def evaluate(model, loader, steps):
     total = 0.0
     with torch.no_grad():
         for _ in range(steps):
-            x, y = next(loader)
-            total += model(x, y)
+            x = next(loader)
+            total += model(x)
     return total / steps


 def train_step(model, loader, step, optimizers, optimizer2, accum_steps):
     # forward/backward accumulation
     for _ in range(accum_steps):
-        x, y = next(loader)
-        loss = model(x, y)
+        x = next(loader)
+        loss = model(x)
         loss.backward()

     # gradient all-reduce across ranks
@@ -167,7 +163,7 @@ def print0(s: str, console: bool = False):
 print0("=" * 100)


-model = GPT(GPTConfig()).cuda()
+model = BlockGPT(BlockGPTConfig()).cuda()

 for m in model.modules():
     if isinstance(m, torch.nn.Embedding):
