
Conversation

@nathan-az
Collaborator

@nathan-az nathan-az commented Feb 20, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

Implements hybrid sharded data parallel in the same way as torchtitan, using the ParallelDims class. This changes no public APIs, but allows explicit setting of dp_replicate and dp_shard in the distributed full finetuning recipe. Omitting these falls back to the previous behaviour of calculating dp_shard from the world size and tensor_parallel_dim.

This pattern of setting up the device mesh prepares torchtune for potential future schemes such as context and pipeline parallelism.
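For illustration, here is a minimal sketch (not the torchtune implementation) of deriving the device mesh from dp_replicate, dp_shard and tensor_parallel_dim in the spirit of torchtitan's ParallelDims; it assumes torch.distributed is already initialised:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_mesh(dp_replicate: int, dp_shard: int, tp: int, device_type: str = "cuda"):
    world_size = dist.get_world_size()
    if dp_shard == -1:
        # Previous behaviour: infer the shard degree from whatever is left over.
        dp_shard = world_size // (dp_replicate * tp)
    if dp_replicate * dp_shard * tp != world_size:
        raise ValueError(
            f"dp_replicate ({dp_replicate}) * dp_shard ({dp_shard}) * tp ({tp}) "
            f"must equal world_size ({world_size})"
        )
    # Only include dims of size > 1 so plain FSDP or TP-only runs keep a flat mesh.
    names, sizes = [], []
    for name, size in (("dp_replicate", dp_replicate), ("dp_shard", dp_shard), ("tp", tp)):
        if size > 1:
            names.append(name)
            sizes.append(size)
    return init_device_mesh(device_type, tuple(sizes), mesh_dim_names=tuple(names))

FSDP2's fully_shard can then be given the 2-D (dp_replicate, dp_shard) sub-mesh for HSDP, or the 1-D dp_shard mesh when no replication is requested.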

Gradient Accumulation Optimisation

One thing of note: runtime profiling suggests that, by default, the gradient reduce-scatter may be occurring across all devices in the DP group, not just within the replica.

image

The above runs each used 4 gradient accumulation steps. My expectation is that bwd_time should remain constant, with optim bearing the increased cost (occurring once per optimiser step, not once per backward step).

Minimising all-reduces

I have updated this PR to add an option for users to minimise the number of all-reduces. Currently this is behind a flag, but once it is validated it may be worth making it the default behaviour in the HSDP case. If someone more familiar with distributed training can validate whether the way I have done this is correct (or suggest a more optimal approach), that would be great.
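For context, the loop-level pattern the flag gates looks roughly like the sketch below (a hedged sketch, not the recipe's exact code; compute_loss and set_grad_sync are hypothetical helpers, with set_grad_sync meant to toggle FSDP2's gradient-sync hooks on the model):

from typing import Callable, Iterable

import torch


def train_with_grad_accumulation(
    model: torch.nn.Module,
    dataloader: Iterable,
    optimizer: torch.optim.Optimizer,
    gradient_accumulation_steps: int,
    minimize_all_reduces: bool,
    compute_loss: Callable,   # hypothetical: maps (model, batch) -> scalar loss
    set_grad_sync: Callable,  # hypothetical: toggles FSDP2 gradient-sync hooks
) -> None:
    for idx, batch in enumerate(dataloader):
        is_final_microbatch = (idx + 1) % gradient_accumulation_steps == 0
        if minimize_all_reduces:
            # Skip gradient collectives on intermediate microbatches; run them
            # only on the backward that precedes the optimizer step.
            set_grad_sync(model, enable=is_final_microbatch)
        loss = compute_loss(model, batch) / gradient_accumulation_steps
        loss.backward()
        if is_final_microbatch:
            optimizer.step()
            optimizer.zero_grad()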

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [x] add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

@pytorch-bot

pytorch-bot bot commented Feb 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2415

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Feb 20, 2025
@nathan-az nathan-az marked this pull request as draft February 20, 2025 10:55
@nathan-az
Collaborator Author

nathan-az commented Feb 20, 2025

After staring at the FSDP2 source for a while, one thing that caught my eye is that is_last_backward defaults to True and appears to trigger additional hooks.

We might be able to alternate this between True and False around the optimizer steps using the public set_is_last_backward. I'm not sure how straightforward it is to iterate through our layers/modules to find those that were sharded to set this, but it may be promising.

Additionally, set_requires_all_reduce looks interesting.
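For reference, a hedged sketch of how those hooks could be toggled around the accumulation window (not necessarily what this PR ends up doing; in recent PyTorch, FSDPModule is importable from torch.distributed.fsdp, in older versions from torch.distributed._composable.fsdp):

from torch.distributed.fsdp import FSDPModule


def set_grad_sync(model: FSDPModule, enable: bool) -> None:
    # is_last_backward gates FSDP2's end-of-backward bookkeeping and hooks.
    model.set_is_last_backward(enable)
    # In HSDP, this defers the cross-replica gradient all-reduce.
    model.set_requires_all_reduce(enable)
    # If the intra-shard reduce-scatter should also be deferred, FSDP2
    # additionally exposes set_requires_gradient_sync(enable).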

@nathan-az
Collaborator Author

image

Above is 4 nodes with gradient_accumulation_steps=4 vs. 1 node with gradient_accumulation_steps=16, so the effective batch size is equal. After deferring the all-reduce with the FSDP2 APIs, loss still looks healthy. We see a periodicity of 4 in bwd_time, which is expected. The scaling of TPS per device is also much better.

@nathan-az nathan-az marked this pull request as ready for review February 21, 2025 20:38
Member

@joecummings joecummings left a comment


I can confirm that eventually we want to add PP & CP, but for now can we limit the ParallelDims class to only those we support right now?

Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
…ported parallel schemes.

Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
@nathan-az
Collaborator Author

nathan-az commented Mar 13, 2025

I can confirm that eventually we want to add PP & CP, but for now can we limit the ParallelDims class to only those we support right now?

Done! The image below shows the results and throughput uplift on 2 nodes of 8x H100s for LLaMA-3.1 8B using 8 gradient accumulation steps. With larger models or more accumulation steps, the uplift should be even larger.

image

The runs are seeded, so I'm not 100% sure why the grad norms and losses aren't identical. I assume this is just due to not running with CUDA deterministic mode.
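For reference, a hedged sketch of what a cudnn_deterministic_mode-style flag typically maps to in PyTorch (the recipe's exact wiring may differ):

import os

import torch


def set_deterministic(cudnn_deterministic_mode: bool) -> None:
    if cudnn_deterministic_mode:
        # Needed for deterministic cuBLAS behaviour when deterministic
        # algorithms are enforced on CUDA.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)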

Let me know if you have any other feedback.

@nathan-az nathan-az changed the title (WIP/RFC) Hybrid Sharding Hybrid Sharding in Full Distributed FT Mar 15, 2025
@nathan-az
Collaborator Author

@joecummings I've removed the WIP title. Let me know if there's anything required here :)

@joecummings
Member

@joecummings I've removed the WIP title. Let me know if there's anything required here :)

Reviewing today!

@nathan-az
Collaborator Author

Thanks @joecummings !

I also ran some tests with deterministic mode, but still see slight divergence from the early steps.

image

Curious whether you have any intuition as to why this could be the case, as it's not obvious to me why delaying the collective should cause any change. Secondly, I'm wondering whether this should be cause for concern.

Full config below in case that's useful! This was with the full_finetune_distributed recipe from this branch, toggling only cudnn_deterministic_mode.

batch_size: 4
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: inputs/model
clip_grad_norm: 1.0
compile: true
cudnn_deterministic_mode: true
custom_sharded_layers: []
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: parquet
  conversation_column: messages
  conversation_style: openai
  split: train
  packed: true
  train_on_input: true
  data_dir: inputs/dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 20
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
minimize_all_reduces: false
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 5.0e-05
lr_scheduler:
  _component_: custom_schedulers.get_constant_schedule_with_warmup
  num_warmup_steps: 10
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 4096
  path: inputs/model/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
dp_shard: 8
dp_replicate: 2

@joecummings
Member

@nathan-az This looks good! Do you mind if I push some changes just to make the API slightly more in line with the function-based getters we use in tune?

@nathan-az
Collaborator Author

@joecummings I don't mind at all :)

@nathan-az
Collaborator Author

@joecummings I noticed that since my work, torchtitan added a nice property non_data_parallel_size.

I won't add any further commits so I don't cause you any conflicts, but I think it's worth including in torchtune.

No pressure - I can add it in a separate PR in the future (and use it to fix the TPS reporting, which is currently bugged in the TP case).
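For context, a hedged sketch of the torchtitan-style property and the kind of TPS normalisation it would enable (illustrative only, not torchtune code; local_num_tokens and step_time are assumed to come from the training loop):

from dataclasses import dataclass


@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    tp: int
    world_size: int

    @property
    def non_data_parallel_size(self) -> int:
        # Degrees that split the model rather than the data; with only DP and
        # TP supported today, this reduces to the TP degree.
        return self.tp


def global_tokens_per_second(local_num_tokens: int, step_time: float, dims: ParallelDims) -> float:
    # TP ranks within a data-parallel group see the same tokens, so scale the
    # local count by the number of DP groups rather than by the world size.
    dp_degree = dims.world_size // dims.non_data_parallel_size
    return local_num_tokens * dp_degree / step_time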

@joecummings
Member

Testing that minimize_all_reduces errors out cleanly in the non-FSDP (TP-only) case:

(joe-torchtune) [jrcummings@devvm4767.pnb0 ~/projects/joe-torchtune (hsdp)]$ CUDA_VISIBLE_DEVICES=3,4,5,6 tune run --nproc-per-node 4 full_finetune_distributed --config llama3_1/8B_full metric_logger._component_=torchtune.training.metric_logging.WandBLogger metric_logger.project=hsdp metric_logger.name=minimize-all-reduces-tp-4 tensor_parallel_dim=4 batch_size=16 seed=26 max_steps_per_epoch=10 gradient_accumulation_steps=1 minimize_all_reduces=True
Running with torchrun...
W0331 07:26:38.914000 3474698 site-packages/torch/distributed/run.py:793]
W0331 07:26:38.914000 3474698 site-packages/torch/distributed/run.py:793] *****************************************
W0331 07:26:38.914000 3474698 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0331 07:26:38.914000 3474698 site-packages/torch/distributed/run.py:793] *****************************************

model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /tmp/torchtune/llama3_1_8B/full
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama3_1_8B/full/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: 26
shuffle: true
tensor_parallel_dim: 4
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model


[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/jrcummings/projects/joe-torchtune/recipes/full_finetune_distributed.py", line 977, in <module>
[rank3]:     sys.exit(recipe_main())
[rank3]:              ^^^^^^^^^^^^^
[rank3]:   File "/home/jrcummings/projects/joe-torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank3]:     sys.exit(recipe_main(conf))
[rank3]:              ^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jrcummings/projects/joe-torchtune/recipes/full_finetune_distributed.py", line 970, in recipe_main
[rank3]:     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/jrcummings/projects/joe-torchtune/recipes/full_finetune_distributed.py", line 184, in __init__
[rank3]:     raise ValueError("``minimize_all_reduces`` is not supported without FSDP")
[rank3]: ValueError: ``minimize_all_reduces`` is not supported without FSDP
[rank0]:[W331 07:26:48.502917420 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
W0331 07:26:49.145000 3474698 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3475945 closing signal SIGTERM
W0331 07:26:49.146000 3474698 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3475946 closing signal SIGTERM
W0331 07:26:49.147000 3474698 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 3475947 closing signal SIGTERM
E0331 07:26:49.863000 3474698 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 3 (pid: 3475948) of binary: /home/jrcummings/.conda/envs/joe-torchtune/bin/python
Traceback (most recent call last):
  File "/home/jrcummings/.conda/envs/joe-torchtune/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 52, in main
    parser.run(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 46, in run
    args.func(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 212, in _run_cmd
    self._run_distributed(args, is_builtin=is_builtin)
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 101, in _run_distributed
    run(args)
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
/home/jrcummings/projects/joe-torchtune/recipes/full_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-03-31_07:26:49
  host      : devvm4767.pnb0.facebook.com
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 3475948)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

@joecummings
Member

Unfortunately, in some quick testing I saw a memory spike when using grad acc > 1, HSDP, and minimize_all_reduces. In order to merge this PR quickly, I separated the ParallelDims work from the minimize-all-reduces work. I'll put up another PR soon that we can test more thoroughly.
Screenshot 2025-03-31 at 12 59 45 PM

@joecummings
Member

Screenshot 2025-03-31 at 1 22 53 PM

Losses between HSDP 2x2 and FSDP 4 are expected not to match because data is only replicated across 2 machines in the 2x2 case versus 4 in FSDP 4, so it's not the same data going through the forward pass at each step. You can see in the screenshot below that running FSDP on 2 devices yields the same loss.

Screenshot 2025-03-31 at 1 25 48 PM

Member

@joecummings joecummings left a comment


LGTM ;)

Contributor

@pbontrager pbontrager left a comment


@nathan-az Thanks for adding this functionality!

@joecummings joecummings merged commit d3ab3b7 into meta-pytorch:main Mar 31, 2025
16 checks passed