Hybrid Sharding in Full Distributed FT #2415
Conversation
After staring at the FSDP2 source for a while, one thing that caught my eye is that […]. We might be able to alternate this between […]. Additionally, […]
joecummings
left a comment
I can confirm that eventually we want to add PP & CP, but for now can we limit the ParallelDims class to only those we support right now?
Done! The image below shows results and throughput uplift on 2 nodes of 8x H100s for LLaMA-3.1 8B using 8 gradient accumulation steps. With larger models or higher accumulation steps, the uplift should be even larger. The runs are seeded, so I'm not 100% sure why grad norms and losses aren't identical; I assume this is just due to not running with CUDA deterministic mode. Let me know if you have any other feedback.
@joecummings I've removed the WIP title. Let me know if there's anything required here :)

Reviewing today!
Thanks @joecummings ! I also ran some tests with deterministic mode, but still see slight divergence from the early steps.
Curious if you have any intuition as to why this could be the case, as it's not obvious to me why delaying the collective should cause any change. Secondly, I'm wondering if this should be cause for concern. Full config below in case that's useful! This was with the […]

```yaml
batch_size: 4
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: '00004'
model_type: LLAMA3
output_dir: ${output_dir}
recipe_checkpoint: null
checkpoint_dir: inputs/model
clip_grad_norm: 1.0
compile: true
cudnn_deterministic_mode: true
custom_sharded_layers: []
dataset:
_component_: torchtune.datasets.chat_dataset
source: parquet
conversation_column: messages
conversation_style: openai
split: train
packed: true
train_on_input: true
data_dir: inputs/dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 20
metric_logger:
_component_: torchtune.training.metric_logging.MLFlowLogger
minimize_all_reduces: false
optimizer:
_component_: torchao.prototype.low_bit_optim.AdamW8bit
lr: 5.0e-05
lr_scheduler:
_component_: custom_schedulers.get_constant_schedule_with_warmup
num_warmup_steps: 10
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
max_seq_len: 4096
path: inputs/model/original/tokenizer.model
_component_: torchtune.models.llama3.llama3_tokenizer
model:
_component_: torchtune.models.llama3_1.llama3_1_8b
dp_shard: 8
dp_replicate: 2
```
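As a rough sanity check, the two data-parallel dims above multiply out to the 16 GPUs in the run (2 replicas × 8 shards). The sketch below shows how such values could be turned into a 2D device mesh with `init_device_mesh` directly; the mesh dim names are illustrative and not necessarily what the recipe uses.

```python
# Illustrative only: how dp_replicate=2 / dp_shard=8 could map onto a 2D HSDP
# device mesh on 2 nodes x 8 GPUs. Dim names are assumptions, not torchtune's.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dp_replicate, dp_shard = 2, 8
world_size = dist.get_world_size() if dist.is_initialized() else 16
assert dp_replicate * dp_shard == world_size

mesh = init_device_mesh(
    "cuda",
    (dp_replicate, dp_shard),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)
# Passing a 2D mesh like this to fully_shard() yields HSDP semantics:
# parameters are sharded within each node and replicated across nodes.
```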
@nathan-az This looks good! Do you mind if I push some changes just to make the API slightly more in line with the function-based getters we use in tune?

@joecummings I don't mind at all :)
@joecummings I noticed that since my work, […]. I won't add any further commits so I don't cause you any conflicts, but I think it's worth including in […]. No pressure - I can add it in a separate PR in future (and use it to fix the TPS reporting, which is currently bugged in the TP case).
Testing block of non-FSDP minimizing all-reduces: […]
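The sketch below is a generic illustration of what skipping per-backward all-reduces looks like outside FSDP, using DDP's `no_sync()` context manager; the helper and argument names are hypothetical and this is not the block from the PR.

```python
# Generic non-FSDP sketch: suppress DDP's per-backward gradient all-reduce on
# intermediate micro-batches and sync only once per optimizer step.
# Hypothetical helper, not taken from the PR.
import contextlib

from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_accumulate(ddp_model: DDP, optimizer, batches, num_accum_steps: int = 8):
    for i, batch in enumerate(batches):
        is_final_microbatch = (i + 1) % num_accum_steps == 0
        # no_sync() skips the all-reduce hooked into backward(); the final
        # micro-batch runs outside it so gradients are reduced exactly once.
        ctx = contextlib.nullcontext() if is_final_microbatch else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).mean() / num_accum_steps
            loss.backward()
        if is_final_microbatch:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
```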
joecummings
left a comment
LGTM ;)
@nathan-az Thanks for adding this functionality!






Context
What is the purpose of this PR? Is it to […]
Changelog
Essentially implements hybrid sharded data parallel as is done in torchtitan, using the `ParallelDims` class. This changes no public APIs but allows explicit setting of `dp_replicate` and `dp_shard` in the distributed full finetuning recipe. Not including these defaults to the previous behaviour, calculating `dp_shard` from the world size and `tensor_parallel_dim`. This pattern of setting up the device mesh preps `torchtune` for potential future schemes such as context/pipeline parallel.
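A rough sketch of the dimension bookkeeping described above (the helper and argument names here are mine, not necessarily torchtune's):

```python
# Hypothetical helper illustrating the fallback described above: if dp_shard is
# not set explicitly, derive it from the world size and the other parallel dims.
def resolve_dp_shard(
    world_size: int,
    dp_replicate: int = 1,
    tensor_parallel_dim: int = 1,
    dp_shard: int = -1,
) -> int:
    if dp_shard == -1:  # "unset": previous behaviour
        dp_shard = world_size // (dp_replicate * tensor_parallel_dim)
    assert dp_replicate * dp_shard * tensor_parallel_dim == world_size, (
        "parallel dims must multiply to the world size"
    )
    return dp_shard


# 16 ranks, nothing set explicitly -> plain FSDP sharding over all 16 ranks
assert resolve_dp_shard(16) == 16
# 16 ranks with dp_replicate=2 -> HSDP: 2 replicas, each sharded 8 ways
assert resolve_dp_shard(16, dp_replicate=2) == 8
```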
Gradient Accumulation Optimisation
One thing of note - it appears from runtime profiling that by default, gradient reduce-scatter may be occurring across all devices in the DP group, not just within the replica.
The above runs each had 4 gradient accumulation steps. My expectation is that `bwd_time` should remain constant, but `optim` should bear the increased cost (occurring once per optimiser step, not per backward step).
Minimising all-reduces
I have updated this PR to add the option for users to minimise the number of all-reduces. Currently this is behind a flag, but once it is validated it may be good to make it the default behaviour in the HSDP case. If someone more familiar with distributed training can validate whether the way I have done this is correct (or even suggest a more optimal way), that would be great.
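For illustration, here is a minimal sketch of the "sync only on the final micro-batch" pattern, assuming FSDP2's `set_requires_gradient_sync` method on the sharded module; this is a generic sketch of the technique, not necessarily how this PR wires it up.

```python
# Generic sketch: defer gradient-reduction collectives to the last micro-batch
# of each accumulation window. Assumes `model` has been wrapped with FSDP2's
# fully_shard(), which exposes set_requires_gradient_sync(); the helper and
# its argument names are hypothetical.
import torch


def step_with_deferred_sync(model, optimizer, batches, num_accum_steps: int = 8):
    for i, batch in enumerate(batches):
        is_final_microbatch = (i + 1) % num_accum_steps == 0
        # Skip reduce-scatter/all-reduce on intermediate micro-batches; grads
        # are only synchronized once per optimizer step.
        model.set_requires_gradient_sync(is_final_microbatch)
        loss = model(batch).mean() / num_accum_steps
        loss.backward()
        if is_final_microbatch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
```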
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.