 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from torchtitan.protocols.model import AttentionMasksType
 import torch
 from torch import nn
 from torchtitan.models.attention import build_attention
@@ -111,8 +112,8 @@ def __init__(
     ):
         super().__init__()

-        self.sliding_window = (
-            model_args.sliding_window if use_sliding_attention else None
+        self.sliding_window_size = (
+            model_args.sliding_window_size if use_sliding_attention else None
         )
         self.head_dim = model_args.head_dim
         self.n_heads = model_args.n_heads
@@ -142,27 +143,33 @@ def __init__(
         )
         self.sinks = nn.Parameter(torch.empty(model_args.n_heads))

-        self.use_flex_attn = model_args.use_flex_attn
+        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)

-        if not self.use_flex_attn:
-            raise ValueError("Only support FlexAttention in Gpt-oss model")
-
-        # Only apply sliding window to every other layer
-        if use_sliding_attention:
-            self.attn = build_attention(
-                use_flex_attn=True,
-                attn_mask_type="sliding_window",
-                sliding_window=self.sliding_window,
-            )
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
         else:
-            self.attn = build_attention(
-                use_flex_attn=True, attn_mask_type=model_args.attn_mask_type
-            )
+            raise ValueError("Gpt-oss model only supports FlexAttention!")
+
+    def init_weights(self, init_std: float):
+        linear_list = [
+            self.wq,
+            self.wk,
+            self.wv,
+        ]

+        nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std)
+        for linear in linear_list:
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+            nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)
+
+
     def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
167174 """
168175 Forward pass for the Multi-Head Latent Attention (MLA) Layer.
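For orientation, here is a minimal sketch of what the `FlexAttentionWrapper` set up above and the call in `forward` below presumably boil down to: PyTorch's `flex_attention` gated by a `BlockMask`. The shapes and the `causal` mask_mod are illustrative assumptions, not code from this commit.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes (assumptions): batch, heads, sequence length, head dim.
B, H, S, D = 2, 4, 128, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

def causal(b, h, q_idx, kv_idx):
    # Standard causal rule: a query attends only to itself and earlier positions.
    return q_idx >= kv_idx

# Build the block-sparse mask once, then reuse it for every attention call that shares it.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cpu")
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)
```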
@@ -191,14 +198,18 @@ def forward(
         k = keys.transpose(1, 2).contiguous()
         v = values.transpose(1, 2).contiguous()

-        # FlexAttention
-        output, aux_output = self.attn(
-            q,
-            k,
-            v,
-            scale=None,
-            return_lse=True,
-        )
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, BlockMask), attention_masks
+            output = self.inner_attention(q, k, v, block_mask=attention_masks)
202213
203214 # Apply attention sink rescaling: rescale by σ(lse - w[h])
204215 # This is mathematically equivalent to concatenating learnable sink weights
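The rescaling itself falls between the hunks shown here. As a sanity check on the comment's claim, this self-contained snippet (an illustration, not part of the commit) verifies that appending a learnable sink logit `w` to the attention scores before the softmax is equivalent to rescaling the ordinary attention output by `sigmoid(lse - w)`:

```python
import torch

# One query, one head: scores over 4 keys, a learnable sink logit, and 4 value vectors.
torch.manual_seed(0)
scores = torch.randn(4)   # attention logits
w = torch.tensor(0.7)     # sink logit (the sink contributes no value vector)
v = torch.randn(4, 8)     # values

# Direct computation: softmax over [scores, w], dropping the sink's probability mass.
probs = torch.softmax(torch.cat([scores, w.unsqueeze(0)]), dim=-1)[:-1]
out_direct = probs @ v

# Equivalent: plain attention output rescaled by sigmoid(lse - w).
lse = torch.logsumexp(scores, dim=-1)
out_rescaled = torch.sigmoid(lse - w) * (torch.softmax(scores, dim=-1) @ v)

assert torch.allclose(out_direct, out_rescaled, atol=1e-5)
```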
@@ -215,20 +226,6 @@ def forward(
         output = self.wo(output)  # (bsz, seqlen, dim)
         return output

-    def init_weights(self, init_std: float):
-        linear_list = [
-            self.wq,
-            self.wk,
-            self.wv,
-        ]
-
-        nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std)
-        for linear in linear_list:
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-            nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)
-
     # TODO: statically init the mask using train.seq_len
     def sliding_window_causal(self, seqlen, device):
         i = torch.arange(seqlen, device=device)
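The body of `sliding_window_causal` is cut off by the next hunk. For reference, a sliding-window causal mask written directly as a FlexAttention `mask_mod` can be as small as the sketch below, assuming a fixed window size `W` (the model would take it from `model_args.sliding_window_size`):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

W = 128  # assumed window size

def sliding_window_causal(b, h, q_idx, kv_idx):
    # Causal: no attention to future tokens; window: only the last W positions are visible.
    return (q_idx >= kv_idx) & (q_idx - kv_idx < W)

# Materialize a BlockMask for sequence length 1024 (CPU here just for illustration).
mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu"
)
```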
@@ -265,7 +262,7 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         self.layer_id = layer_id

-    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
+    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None):
         """
         Forward pass for the Transformer block.

@@ -276,7 +273,7 @@ def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        x = x + self.attention(self.attention_norm(x), rope_cache)
+        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
         x = x + self.moe(self.ffn_norm(x))
         return x

@@ -346,6 +343,29 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        # TODO: implement this function
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                B = input_batch.shape[0]
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+        return create_attention_mask(
+            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+        )
+
+
     def forward(self, tokens: torch.Tensor):
         """
         Forward pass for the Transformer model.
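The helpers used in `get_attention_masks` (`get_causal_mask_mod`, `get_document_mask_mod`, `create_attention_mask`) come from torchtitan's attention module. The sketch below reproduces the same idea for the `block_causal` case using only PyTorch's public FlexAttention API; the EOS id, sequence length, and document-id convention are assumptions for illustration:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import and_masks, create_block_mask

# Toy packed sequence: S tokens containing several documents separated by EOS.
S = 256
eos_id = 0  # assumed EOS id; the real code takes it from tokenizer.eos_id
tokens = torch.randint(1, 100, (1, S))
tokens[0, ::64] = eos_id

# doc_id[i] = number of EOS tokens strictly before position i, so each document
# (including its trailing EOS) gets its own id.
eos = (tokens[0] == eos_id).to(torch.int64)
doc_id = F.pad(eos.cumsum(0)[:-1], (1, 0))

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def same_document(b, h, q_idx, kv_idx):
    # Block-causal: queries only see keys from their own document.
    return doc_id[q_idx] == doc_id[kv_idx]

block_mask = create_block_mask(
    and_masks(causal, same_document), B=1, H=None, Q_LEN=S, KV_LEN=S, device="cpu"
)
```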