Commit f7b9d84

committed
fix ep
1 parent da672b2 commit f7b9d84

2 files changed (+31, -27 lines)

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 18 additions & 19 deletions
@@ -29,7 +29,7 @@
 from torchtitan.models.llama4.infra.parallelize import apply_fsdp
 from torchtitan.tools.logging import logger
 
-from .expert_parallel import GptossExpertTensorParallel
+from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel
 
 
 # for selective op activation checkpointing
@@ -66,23 +66,22 @@ def parallelize_gptoss(
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
-        if job_config.parallelism.enable_async_tensor_parallel:
-            raise NotImplementedError(
-                "Currently, async TP is not tested for gptoss. \
-                torch.compile is not supported yet, which is required for async TP."
-            )
+        if (
+            job_config.parallelism.enable_async_tensor_parallel
+            and not model_compile_enabled
+        ):
+            raise RuntimeError("Async TP requires torch.compile")
 
         enable_float8_linear = "float8" in job_config.model.converters
-        float8_is_rowwise = job_config.float8.recipe_name in (
+        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
             "rowwise",
             "rowwise_with_gw_hp",
         )
 
+        # For now, float8 all-gather with TP is only supported for tensorwise
+        # float8 scaling recipes. For rowwise recipes, we use regular TP and
+        # all-gather happens in high precision.
         enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
-        if enable_float8_tensorwise_tp:
-            raise NotImplementedError(
-                "Currently, float8 tensorwise TP is not tested for gptoss"
-            )
 
         apply_non_moe_tp(
             model,
@@ -222,13 +221,13 @@ def apply_non_moe_tp(
     layer_plan = {
         "attention_norm": SequenceParallel(),
         "attention": prepare_module_input(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), Replicate(), None),
+            desired_input_layouts=(Replicate(), Replicate(), None),
         ),
-        "attention.wq": colwise_parallel(),
-        "attention.wk": colwise_parallel(),
-        "attention.wv": colwise_parallel(),
-        "attention.attn": prepare_module_output(
+        "attention.wq": colwise_parallel(use_local_output=False),
+        "attention.wk": colwise_parallel(use_local_output=False),
+        "attention.wv": colwise_parallel(use_local_output=False),
+        "attention.inner_attention": prepare_module_output(
             output_layouts=(Shard(1), Shard(1)),
             desired_output_layouts=(Shard(1), Shard(1)),
             use_local_output=False,
@@ -304,11 +303,11 @@ def apply_moe_ep_tp(
     if ep_mesh is None:
         experts_mesh = tp_mesh
         # input Replicate, output Partial
-        experts_plan = TensorParallel()
+        experts_plan = GptossTensorParallel()
     elif tp_mesh is None:
         experts_mesh = ep_mesh
         # input / output sharding on the batch / tokens dim
-        experts_plan = GptossExpertParallel()
+        experts_plan = ExpertParallel()
    elif etp_enabled:
         experts_mesh = ep_tp_mesh
         experts_plan = GptossExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
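For context on how a layer_plan like the one in apply_non_moe_tp above is consumed, here is a minimal hypothetical sketch using the public torch.distributed.tensor.parallel API. The lowercase colwise_parallel / prepare_module_input in the diff are presumably local aliases for these styles (or their float8 variants); ToyAttention and its submodule names are made up for illustration, and the script assumes it is launched with torchrun on at least two GPUs.

# Hypothetical sketch, not code from this commit: applying a TP plan of the
# shape used above with the public DTensor parallel API. Launch with torchrun.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    SequenceParallel,
    parallelize_module,
)


class ToyAttention(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs, mask=None):
        # x arrives sequence-sharded, freqs replicated, mask is a plain Python arg.
        return self.wq(x) + self.wk(x) + self.wv(x)


tp_mesh = init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("tp",))
block = nn.ModuleDict(
    {"attention_norm": nn.LayerNorm(128), "attention": ToyAttention(128)}
).cuda()

layer_plan = {
    "attention_norm": SequenceParallel(),
    # One layout entry per forward() argument: shard x on the sequence dim,
    # keep freqs replicated, and leave the non-tensor mask untouched (None).
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), Replicate(), None),
        desired_input_layouts=(Replicate(), Replicate(), None),
    ),
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
}
parallelize_module(block, tp_mesh, layer_plan)

The three-entry input_layouts mirrors the change in this commit: the attention module's forward now receives an extra replicated positional argument, and a None layout marks an input that PrepareModuleInput should leave alone.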

torchtitan/models/attention.py

Lines changed: 13 additions & 8 deletions
@@ -188,24 +188,29 @@ def blocked_mask_mod(
 
 
 def get_sliding_window_mask_mod(window_size: int) -> _mask_mod_signature:
-    """Creates a sliding window mask that only attends to tokens within the fix window size.
+    """Creates a sliding window mask that only attends to tokens within a fixed window size.
 
+    This implements causal sliding window attention where each token can only attend to:
+    - Itself (current token)
+    - Up to `window_size - 1` previous tokens
     Args:
-        batch: Input batch tensor with shape [b, s, h, d]
-        eos_id: End-of-sequence token ID that marks document boundaries
+        window_size: The maximum number of tokens to attend to (including current token).
+            Must be >= 1. A window_size of 1 means attend only to self.
 
     Returns:
-        A mask modifier function that implements document-level masking.
+        A mask modifier function that implements causal sliding window masking.
     """
-    # b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+
     def sliding_window_mod(
         b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
     ) -> torch.Tensor:
-        # causal within window
-        keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window_size)
-        return keep
+        # Causal mask: can only attend to current or previous tokens
+        # Window mask: can only attend within the window
+        # q_idx - kv_idx < window_size ensures we look at most window_size-1 tokens back
+        return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)
 
     sliding_window_mod.__name__ = f"sliding_window_mod_window_size_{window_size}"
+
     return sliding_window_mod
 
 
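As a sanity check on the corrected bound (the old `<= window_size` let a token look window_size positions back; the new `< window_size` gives a window of exactly window_size tokens counting the current one), here is a hedged usage sketch, not part of the commit, that exercises a mask_mod of the same shape with FlexAttention. The local get_sliding_window_mask_mod simply mirrors the patched logic, and the tensor sizes are arbitrary.

# Hedged sketch, not part of this commit: run a sliding-window mask_mod
# through FlexAttention's block-mask path. Requires a CUDA device.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def get_sliding_window_mask_mod(window_size: int):
    def sliding_window_mod(b, h, q_idx, kv_idx):
        # Causal AND within the window: self plus at most window_size - 1 previous tokens.
        return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)

    return sliding_window_mod


B, H, S, D = 1, 4, 256, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

mask_mod = get_sliding_window_mask_mod(window_size=128)
# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 4, 256, 64])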