29 | 29 | from torchtitan.models.llama4.infra.parallelize import apply_fsdp |
30 | 30 | from torchtitan.tools.logging import logger |
31 | 31 |
32 | | -from .expert_parallel import GptossExpertTensorParallel |
| 32 | +from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel |
33 | 33 |
34 | 34 |
35 | 35 | # for selective op activation checkpointing |
@@ -66,23 +66,22 @@ def parallelize_gptoss( |
66 | 66 | raise NotImplementedError("CP support for FlexAttention is still in progress.") |
67 | 67 |
68 | 68 | if parallel_dims.tp_enabled: |
69 | | - if job_config.parallelism.enable_async_tensor_parallel: |
70 | | - raise NotImplementedError( |
71 | | - "Currently, async TP is not tested for gptoss. \ |
72 | | - torch.compile is not supported yet, which is required for async TP." |
73 | | - ) |
| 69 | + if ( |
| 70 | + job_config.parallelism.enable_async_tensor_parallel |
| 71 | + and not model_compile_enabled |
| 72 | + ): |
| 73 | + raise RuntimeError("Async TP requires torch.compile") |
74 | 74 |
75 | 75 | enable_float8_linear = "float8" in job_config.model.converters |
76 | | - float8_is_rowwise = job_config.float8.recipe_name in ( |
| 76 | + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( |
77 | 77 | "rowwise", |
78 | 78 | "rowwise_with_gw_hp", |
79 | 79 | ) |
80 | 80 |
| 81 | + # For now, float8 all-gather with TP is only supported for tensorwise |
| 82 | + # float8 scaling recipes. For rowwise recipes, we use regular TP and |
| 83 | + # all-gather happens in high precision. |
81 | 84 | enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise |
82 | | - if enable_float8_tensorwise_tp: |
83 | | - raise NotImplementedError( |
84 | | - "Currently, float8 tensorwise TP is not tested for gptoss" |
85 | | - ) |
86 | 85 |
87 | 86 | apply_non_moe_tp( |
88 | 87 | model, |
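
For reference, a minimal runnable sketch of the TP gating logic as it reads after this hunk. The attribute paths on `job_config` (`parallelism.enable_async_tensor_parallel`, `model.converters`, `quantize.linear.float8.recipe_name`) and the `model_compile_enabled` flag come from the diff; the `SimpleNamespace` stub and the example values are illustrative only.

```python
# Minimal sketch of the TP gating after this change; stubbed config only.
from types import SimpleNamespace

job_config = SimpleNamespace(
    parallelism=SimpleNamespace(enable_async_tensor_parallel=True),
    model=SimpleNamespace(converters=["float8"]),
    quantize=SimpleNamespace(
        linear=SimpleNamespace(float8=SimpleNamespace(recipe_name="rowwise"))
    ),
)
model_compile_enabled = True  # async TP is only allowed together with torch.compile

# Fail fast if async TP is requested without torch.compile.
if (
    job_config.parallelism.enable_async_tensor_parallel
    and not model_compile_enabled
):
    raise RuntimeError("Async TP requires torch.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
    "rowwise",
    "rowwise_with_gw_hp",
)
# Float8 all-gather with TP is only used for tensorwise scaling recipes;
# rowwise recipes fall back to regular TP with high-precision all-gather.
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
print(enable_float8_tensorwise_tp)  # False here, since the recipe is rowwise
```
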
@@ -222,13 +221,13 @@ def apply_non_moe_tp( |
222 | 221 | layer_plan = { |
223 | 222 | "attention_norm": SequenceParallel(), |
224 | 223 | "attention": prepare_module_input( |
225 | | - input_layouts=(Shard(1), None), |
226 | | - desired_input_layouts=(Replicate(), None), |
| 224 | + input_layouts=(Shard(1), Replicate(), None), |
| 225 | + desired_input_layouts=(Replicate(), Replicate(), None), |
227 | 226 | ), |
228 | | - "attention.wq": colwise_parallel(), |
229 | | - "attention.wk": colwise_parallel(), |
230 | | - "attention.wv": colwise_parallel(), |
231 | | - "attention.attn": prepare_module_output( |
| 227 | + "attention.wq": colwise_parallel(use_local_output=False), |
| 228 | + "attention.wk": colwise_parallel(use_local_output=False), |
| 229 | + "attention.wv": colwise_parallel(use_local_output=False), |
| 230 | + "attention.inner_attention": prepare_module_output( |
232 | 231 | output_layouts=(Shard(1), Shard(1)), |
233 | 232 | desired_output_layouts=(Shard(1), Shard(1)), |
234 | 233 | use_local_output=False, |
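
A standalone sketch of the attention TP plan above, built with the stock `torch.distributed.tensor.parallel` styles. torchtitan binds `colwise_parallel`, `prepare_module_input`, and `prepare_module_output` to these classes or their float8 variants; the `apply_attention_tp` helper below is a hypothetical wrapper for illustration, not the repo's API.

```python
# Sketch of the attention TP plan using stock DTensor parallel styles.
import torch.nn as nn
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    SequenceParallel,
    parallelize_module,
)

layer_plan = {
    "attention_norm": SequenceParallel(),
    # Attention now takes a second tensor argument that stays Replicate on
    # every TP rank; the first input is redistributed Shard(1) -> Replicate.
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), Replicate(), None),
        desired_input_layouts=(Replicate(), Replicate(), None),
    ),
    # use_local_output=False keeps q/k/v as DTensors so inner_attention
    # sees the column-sharded projections rather than plain local tensors.
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.inner_attention": PrepareModuleOutput(
        output_layouts=(Shard(1), Shard(1)),
        desired_output_layouts=(Shard(1), Shard(1)),
        use_local_output=False,
    ),
}

def apply_attention_tp(block: nn.Module, tp_mesh) -> None:
    """Apply the plan to one transformer block; run under torchrun with an
    initialized TP device mesh (hypothetical helper)."""
    parallelize_module(block, tp_mesh, layer_plan)
```
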
@@ -304,11 +303,11 @@ def apply_moe_ep_tp( |
304 | 303 | if ep_mesh is None: |
305 | 304 | experts_mesh = tp_mesh |
306 | 305 | # input Replicate, output Partial |
307 | | - experts_plan = TensorParallel() |
| 306 | + experts_plan = GptossTensorParallel() |
308 | 307 | elif tp_mesh is None: |
309 | 308 | experts_mesh = ep_mesh |
310 | 309 | # input / output sharding on the batch / tokens dim |
311 | | - experts_plan = GptossExpertParallel() |
| 310 | + experts_plan = ExpertParallel() |
312 | 311 | elif etp_enabled: |
313 | 312 | experts_mesh = ep_tp_mesh |
314 | 313 | experts_plan = GptossExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) |
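
The experts mesh/plan selection above reduces to a small decision table. The sketch below mirrors only the branches visible in this hunk and returns plan class names as strings so it runs without the real parallel styles; `pick_experts_plan` is a hypothetical helper, and the final EP+TP-without-ETP branch is outside the hunk, so it is left unimplemented here.

```python
# Decision table for the experts mesh/plan, branches as shown in this hunk.
from typing import Optional, Tuple

def pick_experts_plan(
    tp_mesh: Optional[object],
    ep_mesh: Optional[object],
    ep_tp_mesh: Optional[object],
    etp_enabled: bool,
) -> Tuple[object, str]:
    if ep_mesh is None:
        # TP only: expert inputs are Replicate, outputs are Partial.
        return tp_mesh, "GptossTensorParallel"
    if tp_mesh is None:
        # EP only: inputs/outputs sharded on the batch / tokens dim.
        return ep_mesh, "ExpertParallel"
    if etp_enabled:
        # EP x TP: 2D plan over the combined ep_tp mesh; the real style takes
        # both meshes, as in GptossExpertTensorParallel(tp_mesh=..., ep_mesh=...).
        return ep_tp_mesh, "GptossExpertTensorParallel"
    # The remaining EP + TP (ETP disabled) branch falls outside this hunk.
    raise NotImplementedError("branch not shown in this diff hunk")

# Example: EP-only configuration.
mesh, plan = pick_experts_plan(
    tp_mesh=None, ep_mesh="ep", ep_tp_mesh=None, etp_enabled=False
)
print(mesh, plan)  # ep ExpertParallel
```
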