22 | 22 | from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp |
23 | 23 | from torchtitan.experiments.llama4.infra.parallelize import ( |
24 | 24 |     apply_fsdp,
25 | | -     apply_moe_ep_tp,
26 | 25 | ) |
27 | 26 |
28 | 27 | from torchtitan.tools.logging import logger |
@@ -231,30 +230,30 @@ def apply_non_moe_tp( |
231 | 230 |         layer_plan = {
232 | 231 |             "attention_norm": SequenceParallel(),
233 | 232 |             "attention": prepare_module_input(
234 | | -                 input_layouts=(Shard(1), Replicate()),
235 | | -                 desired_input_layouts=(Replicate(), Replicate()),
| 233 | +                 input_layouts=(Shard(1), None),
| 234 | +                 desired_input_layouts=(Replicate(), None),
236 | 235 |             ),
237 | | -             "attention.wq": colwise_parallel(use_local_output=False),
238 | | -             "attention.wk": colwise_parallel(use_local_output=False),
239 | | -             "attention.wv": colwise_parallel(use_local_output=False),
| 236 | +             "attention.wq": colwise_parallel(),
| 237 | +             "attention.wk": colwise_parallel(),
| 238 | +             "attention.wv": colwise_parallel(),
240 | 239 |             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
241 | 240 |             "ffn_norm": SequenceParallel(),
242 | 241 |         }
243 | 242 |
| 243 | +         # shard attention.sinks across heads
| 244 | +         # TODO(jianiw): Fix the sink implementation for nn.Parameter
| 245 | +         attn = transformer_block.attention
| 246 | +         attn.register_parameter(
| 247 | +             "sinks",
| 248 | +             nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
| 249 | +         )
| 250 | +
244 | 251 |         parallelize_module(
245 | 252 |             module=transformer_block,
246 | 253 |             device_mesh=tp_mesh,
247 | 254 |             parallelize_plan=layer_plan,
248 | 255 |         )
249 | 256 |
250 | | -         # shard attention.sinks across heads
251 | | -         # TODO(jianiw): Fix the sink implementation
252 | | -         # attn = transformer_block.attention
253 | | -         # attn.register_parameter(
254 | | -         #     "sinks",
255 | | -         #     nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
256 | | -         # )
257 | | -
258 | 257 |     if enable_async_tp:
259 | 258 |         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
260 | 259 |
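
For reference, the new block in this diff replaces the plain sinks parameter with a DTensor sharded over the tensor-parallel mesh, so each TP rank keeps only the sink logits for its local heads. Below is a minimal sketch of that pattern in isolation, assuming a recent PyTorch (public torch.distributed.tensor API) and a torchrun launch; ToyAttention and shard_sinks are illustrative names, not torchtitan code.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


class ToyAttention(nn.Module):
    # stand-in attention block with one learnable "sink" logit per head
    def __init__(self, n_heads: int) -> None:
        super().__init__()
        self.sinks = nn.Parameter(torch.zeros(n_heads))


def shard_sinks(attn: nn.Module, tp_mesh) -> None:
    # distribute_tensor yields a DTensor split on dim 0 (the head dim);
    # register_parameter swaps it in so the module now holds a sharded param.
    attn.register_parameter(
        "sinks",
        nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
    )


if __name__ == "__main__":
    # launched via: torchrun --nproc_per_node=<tp_degree> this_script.py
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tp_mesh = init_device_mesh(
        "cuda", (int(os.environ["WORLD_SIZE"]),), mesh_dim_names=("tp",)
    )

    attn = ToyAttention(n_heads=8).cuda()  # n_heads divisible by TP degree
    shard_sinks(attn, tp_mesh)
    # each rank now stores n_heads // tp_degree sink entries locally
    print(attn.sinks.placements, attn.sinks.to_local().shape)

The switch from the previously commented-out Replicate() placement to Shard(0) appears to line the sink entries up with the head-sharded wq/wk/wv projections, so each rank's local heads index into local sink values.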