# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from torchtitan.protocols.model import AttentionMasksType
import torch
from torch import nn
+from torch.nn.attention.flex_attention import and_masks, BlockMask
from torchtitan.components.tokenizer import BaseTokenizer
-from torchtitan.protocols.train_spec import ModelProtocol
from torchtitan.models.attention import (
    create_attention_mask,
    FlexAttentionWrapper,
    get_causal_mask_mod,
    get_document_mask_mod,
-    ScaledDotProductAttentionWrapper,
+    get_sliding_window_mask_mod,
)
-from torch.nn.attention.flex_attention import and_masks, BlockMask
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol

from .args import GptOssModelArgs
from .moe import GptOssMoE
@@ -115,14 +115,8 @@ class Attention(nn.Module):
    Multi-head attention (MLA) module.
    """

-    def __init__(
-        self, model_args: GptOssModelArgs, use_sliding_attention: bool = False
-    ):
+    def __init__(self, model_args: GptOssModelArgs):
        super().__init__()
-
-        self.sliding_window_size = (
-            model_args.sliding_window_size if use_sliding_attention else None
-        )
        self.head_dim = model_args.head_dim
        self.n_heads = model_args.n_heads
        self.n_kv_heads = model_args.n_kv_heads
@@ -157,7 +151,7 @@ def __init__(
            self.inner_attention = FlexAttentionWrapper()
        else:
            raise ValueError("Gpt-oss model only supports FlexAttention!")
-
+
    def init_weights(self, init_std: float):
        linear_list = [
            self.wq,
@@ -172,7 +166,6 @@ def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)

-
    def forward(
        self,
        x: torch.Tensor,
@@ -208,22 +201,15 @@ def forward(

        if self.use_flex_attn:
            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-
-            # # FlexAttention
-            # output, aux_output = self.attn(
-            #     q,
-            #     k,
-            #     v,
-            #     scale=None,
-            #     return_lse=True,
-            # )
-
-            # Apply attention sink rescaling: rescale by σ(lse - w[h])
-            # This is mathematically equivalent to concatenating learnable sink weights
-            lse = aux_output.lse
-            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1)
-            output = output * sink_scale.to(output.dtype)
+            output, aux_output = self.inner_attention(
+                xq, xk, xv, block_mask=attention_masks, scale=None, return_aux=True
+            )
+
+            # Apply attention sink rescaling: rescale by σ(lse - w[h])
+            # This is mathematically equivalent to concatenating learnable sink weights
+            lse = aux_output.lse
+            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1)
+            output = output * sink_scale.to(output.dtype)

        output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)

@@ -234,18 +220,6 @@ def forward(
        output = self.wo(output)  # (bsz, seqlen, dim)
        return output

-    # TODO: statically init the mask using train.seq_len
-    def sliding_window_causal(self, seqlen, device):
-        i = torch.arange(seqlen, device=device)
-        q_idx = i[:, None]
-        kv_idx = i[None, :]
-
-        causal_mask = q_idx >= kv_idx
-        if self.sliding_window is None:
-            return causal_mask
-        window_mask = q_idx - kv_idx <= self.sliding_window
-        return causal_mask & window_mask
-

class TransformerBlock(nn.Module):
    """
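
# Side note (not part of the diff): a standalone numerical check of the claim in the
# comments above, namely that rescaling the plain attention output by sigmoid(lse - w)
# is equivalent to attending over an extra zero-valued "sink" column whose logit is w.
# The single-head toy shapes below are illustrative assumptions only.
import torch

torch.manual_seed(0)
T, D = 5, 4                      # toy sequence length and head dim
scores = torch.randn(T, T)       # pre-softmax attention scores for one head
v = torch.randn(T, D)
w = 0.7                          # the head's learnable sink logit

# (a) explicit sink: append a column with logit w and an all-zero value row
scores_sink = torch.cat([scores, torch.full((T, 1), w)], dim=-1)
v_sink = torch.cat([v, torch.zeros(1, D)], dim=0)
out_explicit = torch.softmax(scores_sink, dim=-1) @ v_sink

# (b) rescaling: plain attention, then multiply by sigmoid(lse - w)
out_plain = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1, keepdim=True)
out_rescaled = out_plain * torch.sigmoid(lse - w)

assert torch.allclose(out_explicit, out_rescaled, atol=1e-6)
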
@@ -255,10 +229,8 @@ class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, model_args: GptOssModelArgs):

        super().__init__()
-        use_sliding_attention = layer_id % 2 == 0
-        self.attention = Attention(
-            model_args, use_sliding_attention=use_sliding_attention
-        )
+        self.use_sliding_attention = layer_id % 2 == 0
+        self.attention = Attention(model_args)
        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)

@@ -270,18 +242,31 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
        self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
        self.layer_id = layer_id

-    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
+    ):
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.
+            attention_masks (AttentionMasksType | None): Either a single BlockMask or a dict of BlockMasks keyed by mask type ("basic_mask" / "sliding_window_mask").

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
-        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
+        # Extract the appropriate mask for this layer
+        if self.use_sliding_attention:
+            layer_mask = attention_masks.get("sliding_window_mask", None)
+        else:
+            layer_mask = attention_masks.get("basic_mask", None)
+        assert layer_mask is not None
+
+        x = x + self.attention(self.attention_norm(x), rope_cache, layer_mask)
        x = x + self.moe(self.ffn_norm(x))
        return x

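
# Aside (not in the diff): what the two masks selected above look like as dense
# boolean matrices for the plain "causal" setting, using a toy sequence length of 8
# and window size of 4 (made-up numbers). Even layers see the banded pattern,
# odd layers the full causal triangle; this mirrors the removed sliding_window_causal().
import torch

seqlen, window = 8, 4
q_idx = torch.arange(seqlen)[:, None]
kv_idx = torch.arange(seqlen)[None, :]

basic = q_idx >= kv_idx                           # full causal mask
sliding = basic & (q_idx - kv_idx <= window)      # causal + sliding-window band
print(basic.int())
print(sliding.int())
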
@@ -357,24 +342,53 @@ def get_attention_masks(
        tokenizer: BaseTokenizer,
        extra_inputs: dict[str, torch.Tensor] | None = None,
    ) -> AttentionMasksType:
-        # TODO: implement this function
-        mask_mods = [get_causal_mask_mod()]
+
+        basic_mask_mods = [get_causal_mask_mod()]
+        sliding_window_mask_mods = [
+            get_causal_mask_mod(),
+            get_sliding_window_mask_mod(self.model_args.sliding_window_size),
+        ]
        match self.model_args.attn_mask_type:
            case "causal":
                B = 1
            case "block_causal":
                B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                basic_mask_mods.append(
+                    get_document_mask_mod(input_batch, tokenizer.eos_id)
+                )
+                sliding_window_mask_mods.append(
+                    get_document_mask_mod(input_batch, tokenizer.eos_id)
+                )
            case _:
                raise ValueError(
                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                )
-        return create_attention_mask(
-            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+
+        # create basic attention mask: causal or block_causal
+        basic_mask = create_attention_mask(
+            and_masks(*basic_mask_mods),
+            B,
+            None,
+            input_batch.shape[1],
+            input_batch.shape[1],
+        )
+
+        # create sliding window mask; it also has to respect the causal (or block-causal) structure
+        sliding_window_mask = create_attention_mask(
+            and_masks(*sliding_window_mask_mods),
+            B,
+            None,
+            input_batch.shape[1],
+            input_batch.shape[1],
        )

+        return {"basic_mask": basic_mask, "sliding_window_mask": sliding_window_mask}

-    def forward(self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None,):
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
+    ):
        """
        Forward pass for the Transformer model.

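
# Illustrative aside (not part of the diff): roughly what the combined sliding-window +
# causal mask amounts to in flex_attention's mask_mod terms, mirroring the removed
# sliding_window_causal() helper. The real torchtitan helpers
# (get_sliding_window_mask_mod, create_attention_mask) may differ in details;
# WINDOW and seq_len are made-up values.
from torch.nn.attention.flex_attention import and_masks, create_block_mask

WINDOW = 128

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window_mask_mod(b, h, q_idx, kv_idx):
    # only keep keys at most WINDOW positions behind the query
    return q_idx - kv_idx <= WINDOW

seq_len = 1024
block_mask = create_block_mask(
    and_masks(causal_mask_mod, sliding_window_mask_mod),
    B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu",
)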