From 4693748f136f2d0f88146e48b48f1769868584ed Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 8 Jul 2022 15:14:37 -0700 Subject: [PATCH 1/2] Update FSDP advanced training - Include in index.rst - Include in what's new - Limit width to 80 chars --- index.rst | 11 +- .../FSDP_adavnced_tutorial.rst | 200 +++++++++++++----- 2 files changed, 161 insertions(+), 50 deletions(-) diff --git a/index.rst b/index.rst index 90264d62acd..f1448183ba0 100644 --- a/index.rst +++ b/index.rst @@ -5,6 +5,7 @@ What's new in PyTorch tutorials? * `Introduction to TorchRec `__ * `Getting Started with Fully Sharded Data Parallel (FSDP) `__ +* `Advanced model training with Fully Sharded Data Parallel (FSDP) `__ * `Grokking PyTorch Intel CPU Performance from First Principles `__ * `Customize Process Group Backends Using Cpp Extensions `__ * `Forward-mode Automatic Differentiation `__ (added functorch API capabilities) @@ -579,10 +580,17 @@ What's new in PyTorch tutorials? .. customcarditem:: :header: Getting Started with Fully Sharded Data Parallel(FSDP) :card_description: Learn how to train models with Fully Sharded Data Parallel package. - :image: _static/img/thumbnails/cropped/Getting Started with FSDP.png + :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png :link: intermediate/FSDP_tutorial.html :tags: Parallel-and-Distributed-Training +.. customcarditem:: + :header: Advanced Model Training with Fully Sharded Data Parallel (FSDP) + :card_description: Explore advanced model training with Fully Sharded Data Parallel package. + :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png + :link: intermediate/FSDP_adavnced_tutorial.html + :tags: Parallel-and-Distributed-Training + .. Mobile .. customcarditem:: @@ -857,6 +865,7 @@ Additional Resources intermediate/ddp_tutorial intermediate/dist_tuto intermediate/FSDP_tutorial + intermediate/FSDP_adavnced_tutorial intermediate/process_group_cpp_extension_tutorial intermediate/rpc_tutorial intermediate/rpc_param_server_tutorial diff --git a/intermediate_source/FSDP_adavnced_tutorial.rst b/intermediate_source/FSDP_adavnced_tutorial.rst index 1adbf972240..5197da16f10 100644 --- a/intermediate_source/FSDP_adavnced_tutorial.rst +++ b/intermediate_source/FSDP_adavnced_tutorial.rst @@ -1,19 +1,34 @@ -Advanced Fully Sharded Data Parallel(FSDP) Tutorial -===================================================== +Advanced Model Training with Fully Sharded Data Parallel (FSDP) +=============================================================== -**Author**: `Hamid Shojanazeri `__, `Less Wright `__, `Rohan Varma `__, `Yanli Zhao `__ +**Author**: `Hamid Shojanazeri `__, `Less +Wright `__, `Rohan Varma +`__, `Yanli Zhao +`__ -This tutorial introduces more advanced features of Fully Sharded Data Parallel (FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please refer to the `FSDP getting started tutorial `__. +This tutorial introduces more advanced features of Fully Sharded Data Parallel +(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please +refer to the `FSDP getting started tutorial +`__. -In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text summarization as a working example. +In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text +summarization as a working example. -The example uses Wikihow and for simplicity, we will showcase the training on a single node, P4dn instance with 8 A100 GPUs. 
We will soon have a blog post on large scale FSDP training on a multi-node cluster, please stay tuned for that on the PyTorch medium channel.
+The example uses the WikiHow dataset and, for simplicity, we showcase the
+training on a single node, a P4dn instance with 8 A100 GPUs. We will soon have
+a blog post on large-scale FSDP training on a multi-node cluster; please stay
+tuned for that on the PyTorch Medium channel.

-FSDP is a production ready package with focus on ease of use, performance, and long-term support.
-One of the main benefits of FSDP is reducing the memory footprint on each GPU. This enables training of larger models with lower total memory vs DDP, and leverages the overlap of computation and communication to train models efficiently.
-This reduced memory pressure can be leveraged to either train larger models or increase batch size, potentially helping overall training throughput.
-You can read more about PyTorch FSDP `here `__.
+FSDP is a production-ready package with a focus on ease of use, performance,
+and long-term support. One of the main benefits of FSDP is reducing the memory
+footprint on each GPU. This enables training larger models with lower total
+memory than DDP, and leverages the overlap of computation and communication to
+train models efficiently.
+This reduced memory pressure can be leveraged to either train larger models or
+increase the batch size, potentially helping overall training throughput. You
+can read more about PyTorch FSDP `here
+`__.


 FSDP Features in This Tutorial
@@ -38,29 +53,38 @@ At a high level FDSP works as follow:

*In forward pass*

-* Run `all_gather` to collect all shards from all ranks to recover the full parameter for this FSDP unit
-* Run forward computation
+* Run `all_gather` to collect all shards from all ranks to recover the full
+  parameter for this FSDP unit
+* Run forward computation
* Discard non-owned parameter shards it has just collected to free memory

*In backward pass*

-* Run `all_gather` to collect all shards from all ranks to recover the full parameter in this FSDP unit
-* Run backward computation
+* Run `all_gather` to collect all shards from all ranks to recover the full
+  parameter in this FSDP unit
+* Run backward computation
* Discard non-owned parameters to free memory.
* Run reduce_scatter to sync gradients

Fine-tuning HF T5
-----------------
-HF T5 pre-trained models are available in four different sizes, ranging from small with 60 Million parameters to XXL with 11 Billion parameters. In this tutorial, we demonstrate the fine-tuning of a T5 3B with FSDP for text summarization using WikiHow dataset.
-The main focus of this tutorial is to highlight different available features in FSDP that are helpful for training large scale model above 3B parameters. Also, we cover specific features for Transformer based models. The code for this tutorial is available in `Pytorch Examples `__.
+HF T5 pre-trained models are available in several sizes, ranging from small,
+with 60 million parameters, to XXL, with 11 billion parameters. In this
+tutorial, we demonstrate fine-tuning a T5 3B model with FSDP for text
+summarization using the WikiHow dataset. The main focus of this tutorial is to
+highlight the available features in FSDP that are helpful for training
+large-scale models above 3B parameters. We also cover specific features for
+Transformer-based models. The code for this tutorial is available in `PyTorch
+Examples
+`__.
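Before walking through the setup, here is a minimal sketch of the model and
tokenizer loading step, assuming the HuggingFace `transformers` package. The
helper name `setup_model` and the exact arguments are illustrative and may
differ slightly from the full example script:

.. code-block:: python

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    def setup_model(model_name="t5-3b"):
        # Download (or load from cache) the pre-trained T5 weights and the
        # matching tokenizer from the HuggingFace hub.
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        return model, tokenizer

The "t5-3b" name corresponds to the 3 billion parameter variant used in this
tutorial; smaller variants such as "t5-small" can be substituted to iterate
quickly before scaling up.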
*Setup* 1.1 Install PyTorch Nightlies -We will install PyTorch nightlies, as some of the features such as activation checkpointing is available in nightlies and will be added in next PyTorch release after 1.12. +We will install PyTorch nightlies, as some of the features such as activation +checkpointing is available in nightlies and will be added in next PyTorch +release after 1.12. .. code-block:: bash @@ -68,16 +92,24 @@ We will install PyTorch nightlies, as some of the features such as activation ch 1.2 Dataset Setup -Please create a `data` folder, download the WikiHow dataset from `wikihowAll.csv `__ and `wikihowSep.cs `__, and place them in the `data` folder. -We will use the wikihow dataset from `summarization_dataset `__. +Please create a `data` folder, download the WikiHow dataset from `wikihowAll.csv +`__ and +`wikihowSep.cs `__, +and place them in the `data` folder. We will use the wikihow dataset from +`summarization_dataset +`__. -Next, we add the following code snippets to a Python script “T5_training.py”. Note - The full source code for this tutorial is available in `PyTorch examples `__. +Next, we add the following code snippets to a Python script “T5_training.py”. + +.. note:: + The full source code for this tutorial is available in `PyTorch examples + `__. 1.3 Import necessary packages: .. code-block:: python - import os + import os import argparse import torch import torch.nn as nn @@ -123,8 +155,11 @@ Next, we add the following code snippets to a Python script “T5_training.py” from datetime import datetime 1.4 Distributed training setup. -Here we use two helper functions to initialize the processes for distributed training, and then to clean up after training completion. -In this tutorial, we are going to use torch elastic, using `torchrun `__ , which will set the worker `RANK` and `WORLD_SIZE` automatically. +Here we use two helper functions to initialize the processes for distributed +training, and then to clean up after training completion. In this tutorial, we +are going to use torch elastic, using `torchrun +`__ , which will set the +worker `RANK` and `WORLD_SIZE` automatically. .. code-block:: python @@ -144,7 +179,8 @@ In this tutorial, we are going to use torch elastic, using `torchrun current date and time of run = {date_of_run}") return date_of_run - + def format_metrics_to_gb(item): """quick function to format numbers to gigabyte and round to 4 digit precision""" metric_num = item / g_gigabyte @@ -303,7 +339,7 @@ We also, add couple of helper functions here for date and formatting memory metr mp_policy = bfSixteen else: mp_policy = None # defaults to fp32 - + # model is on CPU before input to FSDP model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy, @@ -354,7 +390,7 @@ We also, add couple of helper functions here for date and formatting memory metr format_metrics_to_gb(torch.cuda.memory_reserved()) ) print(f"completed save and stats zone...") - + if args.save_model and curr_val_loss < best_val_loss: # save @@ -432,14 +468,26 @@ To run the the training using torchrun: torchrun --nnodes 1 --nproc_per_node 4 T5_training.py .. _transformer_wrapping_policy: + Transformer Wrapping Policy --------------------------- -As discussed in the `previous tutorial `__, auto_wrap_policy is one of the FSDP features that make it easy to automatically shard a given model and put the model, optimizer and gradient shards into distinct FSDP units. 
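For comparison, the generic size-based policy used in the getting started
tutorial groups layers into FSDP units purely by parameter count. A minimal
sketch is shown below; the `min_num_params` threshold is an arbitrary value
chosen for illustration:

.. code-block:: python

    import functools

    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    # Wrap any submodule holding at least 20k parameters into its own FSDP
    # unit. This policy knows nothing about shared embeddings or transformer
    # block boundaries, which is why a transformer-aware policy is preferable
    # for models like T5.
    size_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=20000
    )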
-For some architectures such as Transformer encoder-decoders, some parts of the model such as embedding table is being shared with both encoder and decoder. -In this case, we need to place the embedding table in the outer FSDP unit so that it could be accessed from both encoder and decoder. In addition, by registering the layer class for a transformer, the sharding plan can be made much more communication efficient. In PyTorch 1.12, FSDP added this support and now we have a wrapping policy for transfomers. +As discussed in the `previous tutorial +`__, +auto_wrap_policy is one of the FSDP features that make it easy to automatically +shard a given model and put the model, optimizer and gradient shards into +distinct FSDP units. + +For some architectures such as Transformer encoder-decoders, some parts of the +model such as embedding table is being shared with both encoder and decoder. In +this case, we need to place the embedding table in the outer FSDP unit so that +it could be accessed from both encoder and decoder. In addition, by registering +the layer class for a transformer, the sharding plan can be made much more +communication efficient. In PyTorch 1.12, FSDP added this support and now we +have a wrapping policy for transfomers. -It can be created as follows, where the T5Block represents the T5 transformer layer class (holding MHSA and FFN). +It can be created as follows, where the T5Block represents the T5 transformer +layer class (holding MHSA and FFN). .. code-block:: python @@ -456,12 +504,17 @@ It can be created as follows, where the T5Block represents the T5 transformer la model = FSDP(model, fsdp_auto_wrap_policy=t5_auto_wrap_policy) -To see the wrapped model, you can easily print the model and visually inspect the sharding and FSDP units as well. +To see the wrapped model, you can easily print the model and visually inspect +the sharding and FSDP units as well. Mixed Precision --------------- -FSDP supports flexible mixed precision training allowing for arbitrary reduced precision types (such as fp16 or bfloat16). Currently BFloat16 is only available on Ampere GPUs, so you need to confirm native support before you use it. On V100s for example, BFloat16 can still be run but due to it running non-natively, it can result in significant slowdowns. +FSDP supports flexible mixed precision training allowing for arbitrary reduced +precision types (such as fp16 or bfloat16). Currently BFloat16 is only available +on Ampere GPUs, so you need to confirm native support before you use it. On +V100s for example, BFloat16 can still be run but due to it running non-natively, +it can result in significant slowdowns. To check if BFloat16 is natively supported, you can use the following : @@ -475,7 +528,9 @@ To check if BFloat16 is natively supported, you can use the following : and nccl.version() >= (2, 10) ) -One of the advantages of mixed percision in FSDP is providing granular control over different precision levels for parameters, gradients, and buffers as follows: +One of the advantages of mixed percision in FSDP is providing granular control +over different precision levels for parameters, gradients, and buffers as +follows: .. code-block:: python @@ -503,9 +558,15 @@ One of the advantages of mixed percision in FSDP is providing granular control o buffer_dtype=torch.float32, ) -Note that if a certain type (parameter, reduce, buffer) is not specified, they will not be casted at all. +Note that if a certain type (parameter, reduce, buffer) is not specified, they +will not be casted at all. 
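For example, a policy that sets only the parameter and gradient-reduction
dtypes leaves buffers uncast, so they stay in full precision. This is a sketch
for illustration and is not one of the policies used later in this tutorial:

.. code-block:: python

    params_and_grads_bf16 = MixedPrecision(
        # Parameters and gradient reduction run in BFloat16.
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        # buffer_dtype is left unspecified, so buffers are not cast.
    )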
-This flexibility allows users fine grained control, such as only setting gradient communication to happen in reduced precision, and all parameters / buffer computation to be done in full precision. This is potentially useful in cases where intra-node communication is the main bottleneck and parameters / buffers must be in full precision to avoid accuracy issues. This can be done with the following policy:
+This flexibility gives users fine-grained control, such as having only the
+gradient communication happen in reduced precision while all parameter and
+buffer computation is done in full precision. This is potentially useful in
+cases where intra-node communication is the main bottleneck and parameters /
+buffers must be in full precision to avoid accuracy issues. This can be done
+with the following policy:

.. code-block:: python

    grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper:

.. code-block:: python

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen)

-In our experiments, we have observed up to 4x speed up by using BFloat16 for training and memory reduction of approximately 30% in some experiments that can be used for batch size increases.
+In our experiments, we have observed up to a 4x speedup by using BFloat16 for
+training, and in some experiments a memory reduction of approximately 30% that
+can be used to increase the batch size.


Initializing FSDP Model on Device
----------------------------------

-In 1.12, FSDP supports a `device_id` argument meant to initialize input CPU module on the device given by `device_id`. This is useful when the entire model does not fit on a single GPU, but fits in a host's CPU memory. When `device_id` is specified, FSDP will move the model to the specified device on a per-FSDP unit basis, avoiding GPU OOM issues while initializing several times faster than CPU-based initialization:
+In 1.12, FSDP supports a `device_id` argument meant to initialize the input CPU
+module on the device given by `device_id`. This is useful when the entire model
+does not fit on a single GPU, but fits in a host's CPU memory. When `device_id`
+is specified, FSDP moves the model to the specified device on a per-FSDP-unit
+basis, avoiding GPU OOM issues while initializing several times faster than
+CPU-based initialization:

.. code-block:: python

    torch.cuda.set_device(local_rank)

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())


Sharding Strategy
-----------------

-FSDP sharding strategy by default is set to fully shard the model parameters, gradients and optimizer states get sharded across all ranks. (also termed Zero3 sharding). In case you are interested to have Zero2 sharding strategy, where only optimizer states and gradients are sharded, FSDP support this feature by passing the Sharding strategy by using "ShardingStrategy.SHARD_GRAD_OP", instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows:
+By default, the FSDP sharding strategy fully shards the model: parameters,
+gradients, and optimizer states are all sharded across all ranks (also termed
+Zero3 sharding). If you are interested in the Zero2 sharding strategy, where
+only optimizer states and gradients are sharded, FSDP supports this by passing
+`sharding_strategy=ShardingStrategy.SHARD_GRAD_OP` instead of
+`ShardingStrategy.FULL_SHARD` to the FSDP initialization as follows:

..
code-block:: python @@ -553,13 +626,22 @@ FSDP sharding strategy by default is set to fully shard the model parameters, gr device_id=torch.cuda.current_device(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # FULL_SHARD) -This will reduce the communication overhead in FSDP, in this case, it holds full parameters after forward and through the backwards pass. +This will reduce the communication overhead in FSDP, in this case, it holds full +parameters after forward and through the backwards pass. -This saves an all_gather during backwards so there is less communication at the cost of a higher memory footprint. Note that full model params are freed at the end of backwards and all_gather will happen on the next forward pass. +This saves an all_gather during backwards so there is less communication at the +cost of a higher memory footprint. Note that full model params are freed at the +end of backwards and all_gather will happen on the next forward pass. Backward Prefetch ----------------- -The backward prefetch setting controls the timing of when the next FSDP unit's parameters should be requested. By setting it to `BACKWARD_PRE`, the next FSDP's unit params can begin to be requested and arrive sooner before the computation of the current unit starts. This overlaps the `all_gather` communication and gradient computation which can increase the training speed in exchange for slightly higher memory consumption. It can be utilized in the FSDP wrapper in 2.4 as follows: +The backward prefetch setting controls the timing of when the next FSDP unit's +parameters should be requested. By setting it to `BACKWARD_PRE`, the next +FSDP's unit params can begin to be requested and arrive sooner before the +computation of the current unit starts. This overlaps the `all_gather` +communication and gradient computation which can increase the training speed in +exchange for slightly higher memory consumption. It can be utilized in the FSDP +wrapper in 2.4 as follows: .. code-block:: python @@ -571,15 +653,27 @@ The backward prefetch setting controls the timing of when the next FSDP unit's p device_id=torch.cuda.current_device(), backward_prefetch = BackwardPrefetch.BACKWARD_PRE) -`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`. `BACKWARD_POST` means that the next FSDP unit's params will not be requested until the current FSDP unit processing is complete, thus minimizing memory overhead. In some cases, using `BACKWARD_PRE` can increase model training speed up to 2-10%, with even higher speed improvements noted for larger models. +`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`. +`BACKWARD_POST` means that the next FSDP unit's params will not be requested +until the current FSDP unit processing is complete, thus minimizing memory +overhead. In some cases, using `BACKWARD_PRE` can increase model training speed +up to 2-10%, with even higher speed improvements noted for larger models. Model Checkpoint Saving, by streaming to the Rank0 CPU ------------------------------------------------------ -To save model checkpoints using FULL_STATE_DICT saving which saves model in the same fashion as a local model, PyTorch 1.12 offers a few utilities to support the saving of larger models. +To save model checkpoints using FULL_STATE_DICT saving which saves model in the +same fashion as a local model, PyTorch 1.12 offers a few utilities to support +the saving of larger models. 
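The utilities used below live in the FSDP package. A minimal import sketch,
assuming PyTorch 1.12 (these classes may already be among the imports added in
section 1.3):

.. code-block:: python

    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,  # configures FULL_STATE_DICT gathering
        StateDictType,        # selects which state_dict implementation to use
    )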
-First, a FullStateDictConfig can be specified, allowing the state_dict to be populated on rank 0 only and offloaded to the CPU. +First, a FullStateDictConfig can be specified, allowing the state_dict to be +populated on rank 0 only and offloaded to the CPU. -When using this configuration, FSDP will allgather model parameters, offloading them to the CPU one by one, only on rank 0. When the state_dict is finally saved, it will only be populated on rank 0 and contain CPU tensors. This avoids potential OOM for models that are larger than a single GPU memory and allows users to checkpoint models whose size is roughly the available CPU RAM on the user's machine. +When using this configuration, FSDP will allgather model parameters, offloading +them to the CPU one by one, only on rank 0. When the state_dict is finally +saved, it will only be populated on rank 0 and contain CPU tensors. This avoids +potential OOM for models that are larger than a single GPU memory and allows +users to checkpoint models whose size is roughly the available CPU RAM on the +user's machine. This feature can be run as follows: @@ -596,7 +690,15 @@ This feature can be run as follows: Summary ------- -In this tutorial, we have introduced many new features for FSDP available in Pytorch 1.12 and used HF T5 as the running example. -Using the proper wrapping policy especially for transformer models, along with mixed precision and backward prefetch should speed up your training runs. Also, features such as initializing the model on device, and checkpoint saving via streaming to CPU should help to avoid OOM error in dealing with large models. -We are actively working to add new features to FSDP for the next release. If you have feedback, feature requests, questions or are encountering issues using FSDP, please feel free to contact us by opening an issue at `PyTorch Github repository `__. +In this tutorial, we have introduced many new features for FSDP available in +Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping +policy especially for transformer models, along with mixed precision and +backward prefetch should speed up your training runs. Also, features such as +initializing the model on device, and checkpoint saving via streaming to CPU +should help to avoid OOM error in dealing with large models. + +We are actively working to add new features to FSDP for the next release. If +you have feedback, feature requests, questions or are encountering issues +using FSDP, please feel free to contact us by opening an issue in the +`PyTorch Github repository `__. From 8aec5c9e6d8e94ce0e7730bc9faaff66c0a5bb00 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 11 Jul 2022 08:14:21 -0700 Subject: [PATCH 2/2] Update --- intermediate_source/FSDP_adavnced_tutorial.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/intermediate_source/FSDP_adavnced_tutorial.rst b/intermediate_source/FSDP_adavnced_tutorial.rst index 5197da16f10..cce90e8787e 100644 --- a/intermediate_source/FSDP_adavnced_tutorial.rst +++ b/intermediate_source/FSDP_adavnced_tutorial.rst @@ -611,7 +611,7 @@ Sharding Strategy ----------------- FSDP sharding strategy by default is set to fully shard the model parameters, gradients and optimizer states get sharded across all ranks. (also termed Zero3 -sharding). In case you are interested to have Zero2 sharding strategy, where +sharding). 
If you are interested in the Zero2 sharding strategy, where only optimizer
states and gradients are sharded, FSDP supports this by passing
`sharding_strategy=ShardingStrategy.SHARD_GRAD_OP` instead of
`ShardingStrategy.FULL_SHARD` to the FSDP initialization as follows:
@@ -624,7 +624,7 @@ instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows:
         model = FSDP(model,
            auto_wrap_policy=t5_auto_wrap_policy,
            mixed_precision=bfSixteen,
            device_id=torch.cuda.current_device(),
-           sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # FULL_SHARD)
+           sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

This reduces the communication overhead in FSDP; in this case, the full
parameters are kept after the forward pass and through the backward pass.
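To make this trade-off easy to experiment with, the sharding strategy can be
selected from a command line flag. The following is a sketch only; the
`use_zero2` flag is hypothetical and is not part of the tutorial's argument
parser:

.. code-block:: python

    # Hypothetical flag: args.use_zero2 is not defined by the example script.
    sharding = (
        ShardingStrategy.SHARD_GRAD_OP  # Zero2: shard gradients and optimizer states
        if args.use_zero2
        else ShardingStrategy.FULL_SHARD  # Zero3: also shard parameters
    )

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=sharding)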