Commit da672b2

refactor FlexAttention
1 parent 40bd901 commit da672b2

6 files changed: +136 -85 lines changed


torchtitan/experiments/gpt_oss/infra/expert_parallel.py

Lines changed: 1 addition & 11 deletions
@@ -4,19 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable
 
-import torch
 import torch.nn as nn
-from torch.distributed.tensor import (
-    DeviceMesh,
-    distribute_module,
-    distribute_tensor,
-    DTensor,
-    Replicate,
-    Shard,
-)
-from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor import distribute_tensor, Replicate, Shard
 from torchtitan.distributed.expert_parallel import ExpertParallel, TensorParallel
 
 # implementation of Tensor Parallel for the GroupedExperts in MoE

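The only change in this file is the trimmed import list: the GPT-OSS tensor-parallel style keeps just distribute_tensor plus the Replicate and Shard placements. As a minimal sketch of what those two primitives do, assuming a 1-D expert-parallel DeviceMesh created elsewhere (the helper names below are illustrative, not part of this commit):

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

def shard_on_experts(weight: torch.Tensor, ep_mesh: DeviceMesh):
    # Split dim 0 (e.g. the num_experts dim of a grouped-experts weight)
    # across the ranks of the mesh; each rank keeps only its local shard.
    return distribute_tensor(weight, ep_mesh, [Shard(0)])

def replicate_param(param: torch.Tensor, ep_mesh: DeviceMesh):
    # Keep a full copy of the parameter on every rank of the mesh.
    return distribute_tensor(param, ep_mesh, [Replicate()])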
torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -25,11 +25,11 @@
     ExpertParallel,
     ReordererSequenceParallel,
 )
-from torchtitan.models.llama4.infra.parallelize import apply_fsdp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.models.llama4.infra.parallelize import apply_fsdp
 from torchtitan.tools.logging import logger
 
-from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel
+from .expert_parallel import GptossExpertTensorParallel
 
 
 # for selective op activation checkpointing
@@ -308,10 +308,10 @@ def apply_moe_ep_tp(
     elif tp_mesh is None:
         experts_mesh = ep_mesh
         # input / output sharding on the batch / tokens dim
-        experts_plan = ExpertParallel()
+        experts_plan = GptossExpertParallel()
     elif etp_enabled:
         experts_mesh = ep_tp_mesh
-        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+        experts_plan = GptossExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
     else:
         experts_mesh = ep_mesh
         experts_plan = ExpertParallel()

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 68 additions & 53 deletions
@@ -7,19 +7,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torchtitan.protocols.model import AttentionMasksType
 import torch
 from torch import nn
+from torch.nn.attention.flex_attention import and_masks, BlockMask
 from torchtitan.components.tokenizer import BaseTokenizer
-from torchtitan.protocols.train_spec import ModelProtocol
 from torchtitan.models.attention import (
     create_attention_mask,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
-    ScaledDotProductAttentionWrapper,
+    get_sliding_window_mask_mod,
 )
-from torch.nn.attention.flex_attention import and_masks, BlockMask
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import GptOssModelArgs
 from .moe import GptOssMoE
@@ -115,14 +115,8 @@ class Attention(nn.Module):
     Multi-head attention (MLA) module.
     """
 
-    def __init__(
-        self, model_args: GptOssModelArgs, use_sliding_attention: bool = False
-    ):
+    def __init__(self, model_args: GptOssModelArgs):
         super().__init__()
-
-        self.sliding_window_size = (
-            model_args.sliding_window_size if use_sliding_attention else None
-        )
         self.head_dim = model_args.head_dim
         self.n_heads = model_args.n_heads
         self.n_kv_heads = model_args.n_kv_heads
@@ -157,7 +151,7 @@ def __init__(
             self.inner_attention = FlexAttentionWrapper()
         else:
             raise ValueError("Gpt-oss model only supports FlexAttention!")
-
+
     def init_weights(self, init_std: float):
         linear_list = [
             self.wq,
@@ -172,7 +166,6 @@ def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
         nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std)
 
-
     def forward(
         self,
         x: torch.Tensor,
@@ -208,22 +201,15 @@ def forward(
 
         if self.use_flex_attn:
             assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-
-            # # FlexAttention
-            # output, aux_output = self.attn(
-            #     q,
-            #     k,
-            #     v,
-            #     scale=None,
-            #     return_lse=True,
-            # )
-
-            # Apply attention sink rescaling: rescale by σ(lse - w[h])
-            # This is mathematically equivalent to concatenating learnable sink weights
-            lse = aux_output.lse
-            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1)
-            output = output * sink_scale.to(output.dtype)
+            output, aux_output = self.inner_attention(
+                xq, xk, xv, block_mask=attention_masks, scale=None, return_aux=True
+            )
+
+            # Apply attention sink rescaling: rescale by σ(lse - w[h])
+            # This is mathematically equivalent to concatenating learnable sink weights
+            lse = aux_output.lse
+            sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1)
+            output = output * sink_scale.to(output.dtype)
 
         output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D)
 
@@ -234,18 +220,6 @@ def forward(
         output = self.wo(output) # (bsz, seqlen, dim)
         return output
 
-    # TODO: statically init the mask using train.seq_len
-    def sliding_window_causal(self, seqlen, device):
-        i = torch.arange(seqlen, device=device)
-        q_idx = i[:, None]
-        kv_idx = i[None, :]
-
-        causal_mask = q_idx >= kv_idx
-        if self.sliding_window is None:
-            return causal_mask
-        window_mask = q_idx - kv_idx <= self.sliding_window
-        return causal_mask & window_mask
-
 
 class TransformerBlock(nn.Module):
     """
@@ -255,10 +229,8 @@ class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, model_args: GptOssModelArgs):
 
         super().__init__()
-        use_sliding_attention = layer_id % 2 == 0
-        self.attention = Attention(
-            model_args, use_sliding_attention=use_sliding_attention
-        )
+        self.use_sliding_attention = layer_id % 2 == 0
+        self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
 
@@ -270,18 +242,31 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs):
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
         self.layer_id = layer_id
 
-    def forward(self, x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
+    ):
         """
         Forward pass for the Transformer block.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
             rope_cache (torch.Tensor): Precomputed cosine and sine frequencies.
+            attention_masks (AttentionMasksType | None): Either a single BlockMask or a dict of BlockMasks keyed by layer.
 
         Returns:
            torch.Tensor: Output tensor with the same shape as the input.
         """
-        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
+        # Extract the appropriate mask for this layer
+        if self.use_sliding_attention:
+            layer_mask = attention_masks.get("sliding_window_mask", None)
+        else:
+            layer_mask = attention_masks.get("basic_mask", None)
+        assert layer_mask is not None
+
+        x = x + self.attention(self.attention_norm(x), rope_cache, layer_mask)
         x = x + self.moe(self.ffn_norm(x))
         return x
 
@@ -357,24 +342,54 @@ def get_attention_masks(
         tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
-        # TODO: implement this function
-        mask_mods = [get_causal_mask_mod()]
+
+        basic_mask_mods = []
+        sliding_window_mask_mods = [
+            get_sliding_window_mask_mod(self.model_args.sliding_window_size)
+        ]
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
+                basic_mask_mods.append(get_causal_mask_mod())
+                sliding_window_mask_mods.append(get_causal_mask_mod())
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                basic_mask_mods.append(
+                    get_document_mask_mod(input_batch, tokenizer.eos_id)
+                )
+                sliding_window_mask_mods.append(
+                    get_document_mask_mod(input_batch, tokenizer.eos_id)
+                )
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                )
-        return create_attention_mask(
-            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+
+        # create basic attention mask: causal or block_causal
+        basic_mask = create_attention_mask(
+            and_masks(*basic_mask_mods),
+            B,
+            None,
+            input_batch.shape[1],
+            input_batch.shape[1],
+        )
+
+        # create sliding window mask, has to
+        sliding_window_mask = create_attention_mask(
+            and_masks(*sliding_window_mask_mods),
+            B,
+            None,
+            input_batch.shape[1],
+            input_batch.shape[1],
         )
 
+        return {"basic_mask": basic_mask, "sliding_window_mask": sliding_window_mask}
 
-    def forward(self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None,):
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
+    ):
         """
         Forward pass for the Transformer model.
 
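get_attention_masks now returns two BlockMasks, a basic causal/block-causal mask and a sliding-window variant, and TransformerBlock picks one per layer (even layers use the sliding window). A minimal sketch of the same composition written against the public FlexAttention API; it assumes torchtitan's create_attention_mask and get_sliding_window_mask_mod are thin wrappers over create_block_mask and mask_mods equivalent to the removed sliding_window_causal helper:

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window_mask_mod(window: int):
    def mask_mod(b, h, q_idx, kv_idx):
        return q_idx - kv_idx <= window
    return mask_mod

seq_len, window = 128, 32
basic_mask = create_block_mask(
    causal_mask_mod, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
)
sliding_window_mask = create_block_mask(
    and_masks(sliding_window_mask_mod(window), causal_mask_mod),
    B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu",
)
masks = {"basic_mask": basic_mask, "sliding_window_mask": sliding_window_mask}

# Per-layer selection, mirroring TransformerBlock.use_sliding_attention:
for layer_id in range(4):
    key = "sliding_window_mask" if layer_id % 2 == 0 else "basic_mask"
    layer_mask = masks[key]

The refactored forward also keeps the attention-sink rescaling but now reads the log-sum-exp from the wrapper's aux output instead of a dead code path. The rescaling is exact: appending a sink logit w whose value vector is zero multiplies each row of ordinary attention by Z / (Z + exp(w)) = sigmoid(lse - w), where Z is the softmax normalizer and lse = log Z. A small single-head numerical check of that equivalence (shapes and values are made up):

import torch

torch.manual_seed(0)
T, D = 5, 8
scores = torch.randn(T, T)
causal = torch.arange(T)[:, None] >= torch.arange(T)[None, :]
scores = scores.masked_fill(~causal, float("-inf"))
v = torch.randn(T, D)
sink = torch.tensor(0.3)  # per-head learnable sink logit

# Reference: concatenate the sink as an extra column whose value vector is zero.
ref = torch.softmax(torch.cat([scores, sink * torch.ones(T, 1)], dim=-1), dim=-1)[:, :T] @ v

# Refactored path: ordinary attention, then rescale each row by sigmoid(lse - sink).
out = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1, keepdim=True)
out = out * torch.sigmoid(lse - sink)

print(torch.allclose(out, ref, atol=1e-6))  # True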
torchtitan/experiments/gpt_oss/model/moe.py

Lines changed: 26 additions & 15 deletions
@@ -8,6 +8,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Callable
+
 import torch
 from torch import nn
 from torch.distributed.tensor import DTensor
@@ -34,7 +35,7 @@ def wrapper(
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
-        num_local_experts = w1.shape[0]
+        num_local_experts = mlp1_weight.shape[0]
         ep_degree = num_tokens_per_expert.shape[0] // num_local_experts
 
         input_shape, x, permuted_indices, num_tokens_per_expert = _permute(
@@ -57,6 +58,7 @@ def wrapper(
 
     return wrapper
 
+
 def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
     x_glu, x_linear = x[..., ::2], x[..., 1::2]
     # Clamp the input values
@@ -66,6 +68,7 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
     # Note we add an extra bias of 1 to the linear layer
     return out_glu * (x_linear + 1)
 
+
 def _run_experts_for_loop(
     mlp1_weight: torch.Tensor,
     mlp1_bias: torch.Tensor,
@@ -91,10 +94,7 @@ def _run_experts_for_loop(
         )
         out_experts_splits = []
         for expert_idx, x_expert in enumerate(x):
-            h = (
-                torch.matmul(x_expert, mlp1_weight[expert_idx])
-                + mlp1_bias[expert_idx]
-            )
+            h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx]
             h = swiglu(h, limit=swiglu_limit)
             h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx]
             out_experts_splits.append(h)
@@ -110,6 +110,7 @@ def _run_experts_for_loop(
 
     return out
 
+
 def _run_experts_grouped_mm(
     mlp1_weight: torch.Tensor,
     mlp1_bias: torch.Tensor,
@@ -129,14 +130,6 @@ def _run_experts_grouped_mm(
         # fall back to regular bmm between 3D tensors
         assert x.dim() == 3
 
-    if isinstance(mlp1_weight, DTensor):
-        mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = (
-            mlp1_weight.to_local(),
-            mlp1_bias.to_local(),
-            mlp2_weight.to_local(),
-            mlp2_bias.to_local(),
-        )
-
     h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets)
     if offsets is not None:
         b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0)
@@ -156,6 +149,7 @@ def _run_experts_grouped_mm(
 
     return h
 
+
 class GptOssGroupedExperts(nn.Module):
     def __init__(
         self,
@@ -201,16 +195,33 @@ def forward(
                 run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
             else:
                 run_experts_fn = _run_experts_grouped_mm
-            return run_experts_fn(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert)
+            return run_experts_fn(
+                mlp1_weight,
+                mlp1_bias,
+                mlp2_weight,
+                mlp2_bias,
+                self.swiglu_limit,
+                x,
+                num_tokens_per_expert,
+            )
         else:
-            return _run_experts_for_loop(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert)
+            return _run_experts_for_loop(
+                mlp1_weight,
+                mlp1_bias,
+                mlp2_weight,
+                mlp2_bias,
+                self.swiglu_limit,
+                x,
+                num_tokens_per_expert,
+            )
 
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=init_std)
         nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=init_std)
         nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std)
         nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std)
 
+
 class GptOssMoE(MoE):
     """GptOss MoE implementation that inherits from the base MoE class."""
 
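For reference, only the tail of the clamped SwiGLU that GptOssGroupedExperts applies between the two grouped matmuls is visible in this diff. A standalone sketch of the full activation; the exact clamp ranges are an assumption, since those lines sit outside the hunk:

import torch

def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Gate and linear halves are interleaved along the last dimension.
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    # Clamp the input values (assumed: gate clamped from above, linear to [-limit, limit]).
    x_glu = x_glu.clamp(max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    # Sigmoid-gated activation; note the extra bias of 1 on the linear half.
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + 1)

h = torch.randn(4, 16)   # hypothetical mlp1 output for 4 tokens
print(swiglu(h).shape)   # torch.Size([4, 8]), the interleaved split halves the width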