 
 
 @dataclass
-class GPTConfig:
-    vocab_size: int = 50_257
+class BlockGPTConfig:
+    vocab_size: int = 50_258
     bos_id: int = 50_256
+    mask_id: int = 50_257
     num_layers: int = 12
     num_heads: int = 6
     model_dim: int = 768
     max_seq_len: int = 131_072  # 2**17
     head_dim: int = 128
     intermediate_dim: int | None = None
+    diffusion_block_size: int = 16
+    t_lower: float = 0.3
+    t_upper: float = 0.8
 
 
 def norm(x):
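The config reserves one extra vocabulary slot for a dedicated mask token (vocab_size 50_257 → 50_258, mask_id = 50_257) and adds the diffusion block size plus a training-time noise range [t_lower, t_upper]. A minimal standalone sketch of how those fields are used, mirroring the noising step in forward() further down; the tiny vocab and token values here are made up for illustration and are not part of the commit:

# Toy sketch: per-document noise level t drawn from [t_lower, t_upper],
# each token independently replaced by the mask id with probability ~t.
import torch

vocab_size, mask_id = 12, 11        # stands in for 50_258 and 50_257
t_lower, t_upper = 0.3, 0.8         # training-time noise range from the config

tokens = torch.tensor([3, 7, 1, 9, 2, 5, 4, 8])
t = torch.empty(1).uniform_(t_lower, t_upper)            # one noise level for this document
rand = torch.rand_like(tokens, dtype=torch.float32)
noisy = tokens.masked_fill(rand >= (1 - t), mask_id)     # masked copy of the sequence
print(t.item(), noisy)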
@@ -63,7 +67,7 @@ def forward(self, x_BTHD: Tensor, pos_id: Tensor):
 
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.config = config
 
@@ -107,7 +111,7 @@ def forward(self, x: Tensor, v_residual: Tensor | None, pos_id: Tensor, block_ma
 
 
 class MLP(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         intermediate_dim = config.intermediate_dim or 4 * config.model_dim
         self.in_proj = CastedLinear(config.model_dim, intermediate_dim)
@@ -121,7 +125,7 @@ def forward(self, x: Tensor):
 
 
 class Block(nn.Module):
-    def __init__(self, config: GPTConfig):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.attn = CausalSelfAttention(config)
         self.mlp = MLP(config)
@@ -135,8 +139,8 @@ def forward(self, x: Tensor, v_residual: Tensor, x0: Tensor, pos_id: Tensor, blo
         return x, v_residual
 
 
-class GPT(nn.Module):
-    def __init__(self, config: GPTConfig):
+class BlockGPT(nn.Module):
+    def __init__(self, config: BlockGPTConfig):
         super().__init__()
         self.config = config
 
@@ -149,25 +153,53 @@ def __init__(self, config: GPTConfig):
         assert len(self.blocks) % 2 == 0
         self.skip_w = nn.Parameter(torch.ones(len(self.blocks) // 2))
 
-    def create_blockmask(self, input_seq: Tensor):
-        docs = (input_seq == self.config.bos_token_id).cumsum(0)
+    def create_blockmask(self, doc_id: Tensor, pos_id: Tensor):
+        """BlockMask for attn rules from https://arxiv.org/pdf/2503.09573 section 3.1"""
+        L = len(doc_id)
 
-        def document_causal_mask(b, h, q_idx, kv_idx):
-            causal_mask = q_idx >= kv_idx
-            document_mask = docs[q_idx] == docs[kv_idx]
-            return causal_mask & document_mask  # & window_mask
+        block_id = pos_id // self.config.diffusion_block_size + doc_id * L
+        block_id = torch.cumsum(block_id != block_id.roll(1, 0), 0) - 1
 
-        S = len(input_seq)
-        return create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True)
+        block_id, doc_id = block_id.repeat(2), doc_id.repeat(2)
+        noisy = torch.arange(2 * L, device=doc_id.device) < L
 
-    def forward(self, input_seq: Tensor, target_seq: Tensor):
+        def block_diffusion_mask(b, h, q, kv):
+            # mask from section 3.1 of https://arxiv.org/pdf/2503.09573
+            blk_q, blk_kv = block_id[q], block_id[kv]
+
+            bd = (blk_q == blk_kv) & (noisy[q] == noisy[kv])      # Block Diagonal
+            obc = (blk_q > blk_kv) & noisy[q] & (~noisy[kv])      # Offset Block Causal
+            bc = (blk_q >= blk_kv) & (~noisy[q]) & (~noisy[kv])   # Block Causal
+
+            same_doc = doc_id[q] == doc_id[kv]
+            return same_doc & (bd | obc | bc)
+
+        S = 2 * L
+        return create_block_mask(block_diffusion_mask, None, None, S, S)
+
+    def forward(self, input_seq: Tensor):
         assert input_seq.ndim == 1
 
-        x = x0 = norm(self.embed(input_seq)[None])
+        # construct attention rules & block mask
+        doc_id = (input_seq == self.config.bos_id).cumsum(0)
+        p = torch.arange(input_seq.size(0), device=input_seq.device)
+        pos_id = p - torch.where(input_seq == self.config.bos_id, p, -1).cummax(0).values
+        block_mask = self.create_blockmask(doc_id, pos_id)
+
+        # Apply noise to sequence
+        noise_range = (self.config.t_lower, self.config.t_upper) if self.training else (0.0, 1.0)
+        rand = torch.rand_like(input_seq, dtype=torch.float32)
+        t = torch.empty_like(rand).uniform_(*noise_range)[doc_id]
+        noisy_seq = input_seq.masked_fill(rand >= (1 - t), self.config.mask_id)
+
+        # Concat noisy + clean into seq and repeat pos_ids
+        seq = torch.cat([noisy_seq, input_seq], dim=0)
+        pos_id = pos_id.repeat(2)
+
+        # Embedding & U-net backbone forward
+        x = x0 = norm(self.embed(seq)[None])
         v_residual = None
 
-        # U-net design
-        block_mask = self.create_blockmask(input_seq)
         skip_conns, n = [], len(self.skip_w)
         for i, block in enumerate(self.blocks):
            if i >= n:
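The three terms in block_diffusion_mask encode the attention rules of the block diffusion paper cited above (section 3.1): noisy tokens attend bidirectionally within their own block (block diagonal), noisy tokens attend to the clean copy of strictly earlier blocks (offset block causal), and clean tokens attend block-causally to other clean tokens. A self-contained sketch of that predicate, assuming a single document of 8 tokens with block size 4 so the document term drops out; it only reproduces the rule on a dense grid, not the commit's FlexAttention BlockMask:

# Toy sketch: visualize which (query, key) pairs the rule allows for a
# [noisy | clean] concatenation of one 8-token document, block size 4.
import torch

L, block_size = 8, 4
pos_id = torch.arange(L)
block_id = (pos_id // block_size).repeat(2)     # block index per position
noisy = torch.arange(2 * L) < L                 # first half is the noisy copy

q = torch.arange(2 * L)[:, None]
kv = torch.arange(2 * L)[None, :]
bd = (block_id[q] == block_id[kv]) & (noisy[q] == noisy[kv])
obc = (block_id[q] > block_id[kv]) & noisy[q] & ~noisy[kv]
bc = (block_id[q] >= block_id[kv]) & ~noisy[q] & ~noisy[kv]
mask = bd | obc | bc                            # single document, so no doc_id term

print(mask.int())   # rows = queries, cols = keys; 1 = attention allowed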
@@ -176,11 +208,16 @@ def forward(self, input_seq: Tensor, target_seq: Tensor):
             if i < n:
                 skip_conns.append(x)
 
+        x = x[:, :input_seq.size(0)]  # Get logits for noisy tokens only
         x = norm(x)
         logits = self.lm_head(x).float()
 
-        # tanh softcapping
-        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
-
-        return loss
+        # Get loss for masked tokens
+        mask = (noisy_seq == self.config.mask_id)
+        targets = torch.where(mask, input_seq, torch.full_like(input_seq, -100))
+        losses = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+        if self.training:
+            weights = (1.0 / (t + 1e-4)).type_as(logits)
+            return (losses * weights * mask).sum() / mask.sum()
+        else:
+            return (losses * mask).sum() / mask.sum()
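Only masked positions contribute to the loss. During training each position's cross-entropy is additionally weighted by 1/t, the inverse of its document's noise level, while evaluation averages the unweighted masked losses with t drawn from the full (0, 1) range. A toy sketch of that arithmetic with made-up per-token losses, not part of the commit:

# Toy sketch: training vs. eval reduction over masked tokens.
import torch

t = torch.tensor([0.3, 0.3, 0.8, 0.8])           # per-token noise level (via doc_id)
mask = torch.tensor([True, False, True, True])    # tokens that were replaced by mask_id
losses = torch.tensor([2.0, 1.5, 0.7, 1.1])       # per-token CE with reduction='none'

weights = 1.0 / (t + 1e-4)                        # lightly-noised documents get upweighted
train_loss = (losses * weights * mask).sum() / mask.sum()
eval_loss = (losses * mask).sum() / mask.sum()
print(train_loss.item(), eval_loss.item())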