
Commit e740303

rebase to main
1 parent 21a3578 commit e740303

File tree: 2 files changed, +61 -41 lines


torchtitan/experiments/gpt_oss/model/args.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ class GptOssModelArgs(BaseModelArgs):
     head_dim: int = 64
     n_heads: int = 64
     n_kv_heads: int = 8
-    sliding_window: int = 128
+    sliding_window_size: int = 128
     use_flex_attn: bool = True
     attn_mask_type: str = "causal"
     # yarn

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 60 additions & 40 deletions
@@ -7,6 +7,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from torchtitan.protocols.model import AttentionMasksType
 import torch
 from torch import nn
 from torchtitan.models.attention import build_attention
@@ -111,8 +112,8 @@ def __init__(
     ):
         super().__init__()
 
-        self.sliding_window = (
-            model_args.sliding_window if use_sliding_attention else None
+        self.sliding_window_size = (
+            model_args.sliding_window_size if use_sliding_attention else None
         )
         self.head_dim = model_args.head_dim
         self.n_heads = model_args.n_heads
@@ -142,27 +143,33 @@ def __init__(
         )
         self.sinks = nn.Parameter(torch.empty(model_args.n_heads))
 
-        self.use_flex_attn = model_args.use_flex_attn
+        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)
 
-        if not self.use_flex_attn:
-            raise ValueError("Only support FlexAttention in Gpt-oss model")
-
-        # Only apply sliding window to every other layer
-        if use_sliding_attention:
-            self.attn = build_attention(
-                use_flex_attn=True,
-                attn_mask_type="sliding_window",
-                sliding_window=self.sliding_window,
-            )
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
         else:
-            self.attn = build_attention(
-                use_flex_attn=True, attn_mask_type=model_args.attn_mask_type
-            )
+            raise ValueError("Gpt-oss model only supports FlexAttention!")
+
+    def init_weights(self, init_std: float):
+        linear_list = [
+            self.wq,
+            self.wk,
+            self.wv,
+        ]
 
+        nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std)
+        for linear in linear_list:
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+            nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)
+
+
     def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.
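Note: FlexAttentionWrapper itself is not shown in this diff, so the following is only a sketch of the role it appears to play here: an attention module constructed once in __init__ that takes a precomputed BlockMask at call time. The class below is a hypothetical stand-in built on torch's public FlexAttention API, not the torchtitan implementation.

import torch.nn as nn
from torch.nn.attention.flex_attention import BlockMask, flex_attention

class SimpleFlexAttention(nn.Module):  # hypothetical stand-in for FlexAttentionWrapper
    def forward(self, q, k, v, block_mask: BlockMask | None = None):
        # q, k, v: (bsz, n_heads, seqlen, head_dim); the BlockMask is built once
        # per batch by the model and reused by every layer.
        return flex_attention(q, k, v, block_mask=block_mask)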
@@ -191,14 +198,18 @@ def forward(
         k = keys.transpose(1, 2).contiguous()
         v = values.transpose(1, 2).contiguous()
 
-        # FlexAttention
-        output, aux_output = self.attn(
-            q,
-            k,
-            v,
-            scale=None,
-            return_lse=True,
-        )
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, BlockMask), attention_masks
+            output = self.inner_attention(q, k, v, block_mask=attention_masks)
+
+        # # FlexAttention
+        # output, aux_output = self.attn(
+        #     q,
+        #     k,
+        #     v,
+        #     scale=None,
+        #     return_lse=True,
+        # )
 
         # Apply attention sink rescaling: rescale by σ(lse - w[h])
         # This is mathematically equivalent to concatenating learnable sink weights
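Note: the rescaling comment above can be made concrete. With a per-head sink logit w and lse = logsumexp of the attention scores for a query, appending the sink to the softmax only enlarges the normalizer, so the plain attention output is multiplied by exp(lse) / (exp(lse) + exp(w)) = σ(lse - w). A small self-contained check with illustrative values (not the model's code):

import torch

torch.manual_seed(0)
scores = torch.randn(8)   # attention logits for one query, one head
v = torch.randn(8, 4)     # values
w = torch.tensor(0.7)     # hypothetical per-head sink logit

# Path 1: append the sink logit to the softmax; the sink carries no value,
# so it only inflates the denominator.
probs = torch.softmax(torch.cat([scores, w[None]]), dim=-1)
out_concat = probs[:-1] @ v

# Path 2: plain attention output rescaled by sigmoid(lse - w).
lse = torch.logsumexp(scores, dim=-1)
out_rescaled = torch.sigmoid(lse - w) * (torch.softmax(scores, dim=-1) @ v)

print(torch.allclose(out_concat, out_rescaled, atol=1e-6))  # True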
@@ -215,20 +226,6 @@ def forward(
         output = self.wo(output)  # (bsz, seqlen, dim)
         return output
 
-    def init_weights(self, init_std: float):
-        linear_list = [
-            self.wq,
-            self.wk,
-            self.wv,
-        ]
-
-        nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std)
-        for linear in linear_list:
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-            nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)
-
     # TODO: statically init the mask using train.seq_len
     def sliding_window_causal(self, seqlen, device):
         i = torch.arange(seqlen, device=device)
@@ -265,7 +262,7 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         self.layer_id = layer_id
 
-    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
+    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None):
         """
         Forward pass for the Transformer block.
 
@@ -276,7 +273,7 @@ def forward(self, x: torch.Tensor, rope_cache: torch.Tensor):
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        x = x + self.attention(self.attention_norm(x), rope_cache)
+        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
         x = x + self.moe(self.ffn_norm(x))
         return x
 
@@ -346,6 +343,29 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        # TODO: implement this function
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                B = input_batch.shape[0]
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+        return create_attention_mask(
+            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+        )
+
+
     def forward(self, tokens: torch.Tensor):
         """
         Forward pass for the Transformer model.
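Note: get_attention_masks is marked TODO and currently only covers the causal and block_causal cases, while the attention layers still carry a sliding_window_size. A sliding-window constraint can be expressed as one more mask_mod and-ed with the causal one using torch's public FlexAttention utilities. The snippet below is a hedged sketch with illustrative sizes, not part of this commit:

from torch.nn.attention.flex_attention import and_masks, create_block_mask

seq_len, window_size = 1024, 128  # illustrative values, not taken from the config

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx < window_size

# Causal AND within-window; B=None / H=None broadcast over batch and heads.
block_mask = create_block_mask(
    and_masks(causal, sliding_window),
    B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu",
)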
