Update on "[BE][5/n] simplify pp vs. non-pp set up"
This PR refactors the PP vs. non-PP setup in `train.py`:
- moves `build_pipeline_schedule` into `pipeline_llama`, which reduces the PP interface exposed in `train.py`
- refactors the setup flow so that there are only two main PP vs. non-PP if-else branches: one in the setup phase and one in the training phase (see the structural sketch after this list)
- I think the resulting code is already clear to read or copy-paste, so it's not necessary to create separate sub-functions to hold it.
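A rough structural sketch of the resulting flow (hedged: `pp_enabled`, the toy model, and the stage split below are illustrative stand-ins, not torchtitan's actual APIs):

```python
import torch
import torch.nn as nn

pp_enabled = False  # stand-in for parallel_dims.pp_enabled
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# setup phase: the first PP vs. non-PP branch
if pp_enabled:
    # in the PR, pipeline_llama splits the model into stages and now also
    # builds the pipeline schedule internally via build_pipeline_schedule
    model_parts = [model[:2], model[2:]]
else:
    model_parts = [model]

for m in model_parts:
    # apply SPMD-style parallelisms / compile / AC to each part (no-op in this sketch)
    pass

# training phase: the second PP vs. non-PP branch
inputs, labels = torch.randn(4, 8), torch.randn(4, 8)
if pp_enabled:
    # the pipeline schedule's step() would drive forward/backward across the stages
    pass
else:
    loss = nn.functional.mse_loss(model(inputs), labels)
    loss.backward()
```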

This PR also removes the unnecessary module returns in `parallelize_llama`, since the module is modified in place. Note that torch.compile and AC require returning and reassigning the wrapped module. But since we apply compile and AC per transformer block, we achieve the in-place update for the whole model by
```
transformer_block = compile/AC(transformer_block)
model.layers.register_module(layer_id, transformer_block)
``` 
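Concretely, a minimal sketch of this in-place pattern (not the exact torchtitan code; `apply_compile_and_ac` is a hypothetical helper, and it assumes the model keeps its transformer blocks in a `model.layers` container):

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_compile_and_ac(model: torch.nn.Module) -> None:
    # list() so modules can be re-registered while iterating
    for layer_id, transformer_block in list(model.layers.named_children()):
        transformer_block = checkpoint_wrapper(transformer_block)  # per-block AC
        transformer_block = torch.compile(transformer_block)  # per-block compile
        # re-register the wrapped block so the parent model is updated in place
        # and nothing needs to be returned to the caller
        model.layers.register_module(layer_id, transformer_block)
```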

[ghstack-poisoned]
tianyu-l committed Aug 8, 2024
1 parent f58ca70 commit ff53569
Showing 1 changed file with 3 additions and 3 deletions.
train.py (3 additions, 3 deletions):

```diff
@@ -133,7 +133,7 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
 
-    # loss function to be shared by Pipeline Parallel and spmd training
+    # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
             pred.flatten(0, 1), labels.flatten(0, 1)
@@ -150,7 +150,7 @@ def loss_fn(pred, labels):
     # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
     # optimizer, and checkpointing
     for m in model_parts:
-        # apply spmd-style PT-D techniques
+        # apply SPMD-style PT-D techniques
         models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
 
     # In PP, we cannot call init_weights directly because some layers are missing.
@@ -269,7 +269,7 @@ def loss_fn(pred, labels):
         optimizers.zero_grad()
 
         if parallel_dims.pp_enabled:
-            # pipeline parallel forward / backward inside step() call
+            # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
             with train_context():
```
