[RFC] FSDP via SPMD #6379

Closed
alanwaketan opened this issue Jan 25, 2024 · 11 comments

@alanwaketan
Collaborator

alanwaketan commented Jan 25, 2024

FSDP via SPMD (FSDP v2)

Introduction

FSDP (fully sharded data parallel) is a well-known distributed training algorithm in the PyTorch world. SPMD is PyTorch/XLA’s API that lets users annotate a single-device PyTorch model and then lets XLA’s GSPMD feature turn it into a distributed model. This design doc focuses on how to use SPMD to express FSDP, and how to make this new implementation performant and easy to use. Since PyTorch/XLA already has a native implementation of FSDP, this new implementation is also referred to as FSDPv2.

Background

Many past SPMD training experiments have demonstrated that FSDP, i.e., 1D sharding, performs better than 2D sharding as long as the model can fit into the training fleet.

To express FSDP using the vanilla SPMD API, one currently needs to accomplish the following 5 steps. Examples are taken from our HF Llama 2 fork.

1. Define mesh

import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Place DCN on an independent axis in the mesh. Model parameters should be
# replicated along the DCN axis, and inputs and activations should have
# the batch dimension sharded along the combined DCN and data axes.
num_devices = xr.global_runtime_device_count()
model_axis = max(model_args.spmd_2d_sharding, 1)  # spmd_2d_sharding is set to 1 here.
dcn_axis = model_args.spmd_dcn_parallelism
data_axis = num_devices // model_axis // dcn_axis
ici_mesh_shape = (1, data_axis, model_axis)
dcn_mesh_shape = (dcn_axis, 1, 1)
spmd_mesh = xs.HybridMesh(
    ici_mesh_shape=ici_mesh_shape,
    dcn_mesh_shape=dcn_mesh_shape,
    axis_names=('dcn', 'data', 'model'))
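
To make the axis arithmetic in this step concrete, here is a minimal sketch in plain Python that does not depend on torch_xla. The helper name hybrid_mesh_shapes is hypothetical and not part of any library; it only reproduces the shape computation, and it adds a divisibility check that the snippet from the fork does not perform.

```python
def hybrid_mesh_shapes(num_devices, model_axis=1, dcn_axis=1):
    """Return (ici_mesh_shape, dcn_mesh_shape) for axes ('dcn', 'data', 'model').

    Hypothetical helper for illustration only.
    """
    model_axis = max(model_axis, 1)
    if dcn_axis < 1 or num_devices % (model_axis * dcn_axis) != 0:
        raise ValueError(
            f"{num_devices} devices cannot be split into "
            f"dcn={dcn_axis} x model={model_axis} groups")
    # Whatever is left after the model and DCN axes goes to the data axis.
    data_axis = num_devices // model_axis // dcn_axis
    ici_mesh_shape = (1, data_axis, model_axis)
    dcn_mesh_shape = (dcn_axis, 1, 1)
    return ici_mesh_shape, dcn_mesh_shape


# Example: 8 devices across 2 slices, pure FSDP (model axis of size 1).
ici, dcn = hybrid_mesh_shapes(num_devices=8, model_axis=1, dcn_axis=2)
print(ici, dcn)
```

With a model axis of size 1, every device inside a slice ends up on the data axis, which is what makes this mesh express 1D (FSDP) sharding.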

2. Shard data loader

import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.parallel_loader as pl

sharding_spec = xs.ShardingSpec(self.args.spmd_mesh, (('dcn', 'data'), None))
# TODO(jonbolin): Once integrated with Accelerate, we can use the Accelerate-prepared
# MpDeviceLoader instead of manually adding sharding and adding a dataset attribute.
loader = pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
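
The input sharding spec (('dcn', 'data'), None) splits the batch dimension across the combined dcn and data axes and leaves the second dimension replicated. As a rough illustration of what that means for per-device shapes, here is a small sketch in plain Python. The helper shard_shape is hypothetical, and it assumes every sharded dimension divides evenly; it says nothing about how XLA handles uneven shards.

```python
def shard_shape(global_shape, partition_spec, axis_sizes):
    """Per-device shape of a tensor sharded according to partition_spec.

    Each partition_spec entry is None (replicated), a mesh axis name, or a
    tuple of mesh axis names. Hypothetical helper for illustration only.
    """
    local_shape = []
    for dim, spec in zip(global_shape, partition_spec):
        if spec is None:
            local_shape.append(dim)
            continue
        names = spec if isinstance(spec, tuple) else (spec,)
        factor = 1
        for name in names:
            factor *= axis_sizes[name]
        if dim % factor != 0:
            raise ValueError(f"dimension {dim} is not divisible by {factor}")
        local_shape.append(dim // factor)
    return tuple(local_shape)


axis_sizes = {'dcn': 2, 'data': 4, 'model': 1}
# A global batch of 32 sequences of length 1024.
print(shard_shape((32, 1024), (('dcn', 'data'), None), axis_sizes))
```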

3. Shard weights

for name, param in model.named_parameters():
  if model_args.spmd_fsdp_sharding:
    print('> [FSDP] Sharding tensor', name, param.shape, param.dtype)
    # We don't care about layernorm's weights, and
    # LLaMA doesn't use biases.
    if len(param.shape) == 1:
      continue
    assert len(param.shape) == 2

    # Shard the largest dimension
    if param.shape[0] > param.shape[1]:
      partition_spec = ('data', None)
    else:
      partition_spec = (None, 'data')
    xs.mark_sharding(param, spmd_mesh, partition_spec)
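
The rule in the loop above can be isolated into a pure function, which makes it easy to check against a few shapes without a device. This is a sketch only: fsdp_partition_spec is a hypothetical name, and the shapes in the example are illustrative rather than taken from a specific checkpoint.

```python
def fsdp_partition_spec(shape, axis='data'):
    """Pick the partition spec for a parameter, mirroring the loop above.

    Returns None for 1D parameters (skipped, e.g. layernorm weights),
    otherwise shards the larger of the two dimensions along `axis`.
    """
    if len(shape) == 1:
        return None
    if len(shape) != 2:
        raise ValueError(f"expected a 1D or 2D parameter, got shape {shape}")
    rows, cols = shape
    # Ties go to the second dimension, matching the strict '>' in the loop.
    return (axis, None) if rows > cols else (None, axis)


for shape in [(32000, 4096), (4096, 11008), (4096, 4096), (4096,)]:
    print(shape, fsdp_partition_spec(shape))
```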

4. Shard activations

# Apply 2D sharding:
# hidden_states (batch, length, hidden)
# mesh (data, None, model)
if self.spmd_debug:
  print('> Sharding hidden_states', hidden_states.shape, self.spmd_mesh.get_logical_mesh().shape)
xs.mark_sharding(hidden_states, self.spmd_mesh, (('dcn', 'data'), None, 'model'))
if self.spmd_debug:
  print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))

5. Apply backward optimization barrier

# LLaMA-specific: apply the barrier to each decoder layer.
for block in model.model.layers:
  xs.apply_backward_optimization_barrier(block)

Even though FSDP requires far fewer sharding annotations than 2D sharding, which needs much more sharding on activations and, for LLMs, additional sharding on the attention layers, it’s still complex. Here is a recap of how our native FSDP implementation, which we will refer to as FSDPv1 in this design, is generally used:

import functools

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
self.model = model = FSDP(model, auto_wrap_policy=auto_wrap_policy, **fsdp_kwargs)

That is far less boilerplate to write, and thus a much better user experience! Here comes the problem statement: can we recreate the same user experience while matching the performance of the vanilla SPMD FSDP annotations?

Goals

  • Easy to use UX
  • Competitive out of box performance
  • Seamless integration with HF and Lightning
  • Major way of using SPMD

Non-goals

  • Replace 2D sharding

Feature Requirements

Before talking about the design candidates, let’s detail the features and characteristics of this new system. This will help us navigate the different design candidates.

P0: Shard on weights

This is the basic concept of FSDP where weights are sharded and distributed among the training fleet.

P0: Shard on activations

This is not needed in FSDPv1. However, if we omit this step from the vanilla SPMD example in the Background section, we get much worse performance. The following performance benchmarks were taken on a v4-8 with Llama 2 2B at 1K seq_len.

Shard activations

  • Xprof: Missing
  • Hardware FLOPS utilization: 65.0%
  • Peak memory allocation: 16392.05 MiB

Don’t shard activations

  • Xprof: Missing
  • Hardware FLOPS utilization: 50.3%
  • Peak memory allocation: 16673.46 MiB

It turns out that in the second case, the compiler decides to do some wild all-to-all and all-reduce in the attention layer.

Therefore, this feature is a must to instruct the compiler to follow the FSDP algorithm. Fortunately, we only need to shard either the input or the output hidden_states of the decoder layer in the case of LLM, and we don’t need to shard every activation.
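
To put the two benchmark results above side by side, the following short calculation derives the gap between them. It uses only the numbers quoted in this section; the variable names are made up for the example.

```python
# Numbers copied from the two benchmark runs above (v4-8, Llama 2 2B, 1K seq_len).
sharded = {"hfu_pct": 65.0, "peak_mib": 16392.05}
unsharded = {"hfu_pct": 50.3, "peak_mib": 16673.46}

# Absolute gap in hardware FLOPS utilization, in percentage points.
hfu_gain_points = sharded["hfu_pct"] - unsharded["hfu_pct"]
# Relative gain over the run that does not shard activations.
hfu_relative_gain = sharded["hfu_pct"] / unsharded["hfu_pct"] - 1.0
# Extra peak memory used when activations are not sharded.
extra_peak_mib = unsharded["peak_mib"] - sharded["peak_mib"]

print(f"HFU gain: {hfu_gain_points:.1f} points ({hfu_relative_gain:.1%} relative)")
print(f"Extra peak memory without activation sharding: {extra_peak_mib:.2f} MiB")
```

The utilization gap is much larger than the memory gap, which is consistent with the observation that the cost comes from extra collectives rather than from memory pressure.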

P0: Backward optimization barrier

This is needed to prevent gigantic fusions when syncing the gradients. The only remaining question is whether it’s compatible with gradient checkpointing, since both of them override the backward pass in some fashion. Theoretically speaking, they should be compatible and the application order shouldn’t matter.

P0: Manual wrapping

Most of the features here should be packaged together and can be applied separately to the root module and the children modules. Let’s take FSDPv1 as an example. Typically the wrapper will be applied to two modules:

  1. The root module.
  2. The decoder layer.

Even though FSDPv1 by default shards all the parameters in the wrapped module, including those of its children, the rebuilding, memory-freeing, and gradient-synchronizing logic only applies to the wrapped module itself. If only the root module is wrapped, all parameters are rebuilt in full during the outermost forward, so no memory is saved. If every child module is wrapped, the overhead becomes too high. That’s why usually only the above two types of modules are wrapped.
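
A back-of-the-envelope model can illustrate why the choice of wrapped modules matters for memory. The sketch below is a deliberate simplification and not a description of the actual FSDPv1 allocator: it counts only parameter elements, assumes one wrapped layer is rebuilt at a time, and ignores gradients, optimizer states, activations, and any prefetching. The function name and the layer sizes are hypothetical.

```python
def peak_param_elements(layer_sizes, num_devices, wrap_layers):
    """Rough per-device peak of parameter elements under a toy FSDP model.

    wrap_layers=False: only the root is wrapped, so every parameter is
    rebuilt in full during the outermost forward.
    wrap_layers=True: each layer is wrapped, so a device holds its shards
    plus (at most) one fully rebuilt layer at a time.
    """
    total = sum(layer_sizes)
    if not wrap_layers:
        return float(total)
    resident_shards = total / num_devices
    # Upper bound: the largest layer is rebuilt on top of the resident shards.
    return resident_shards + max(layer_sizes)


layers = [100] * 32  # 32 equally sized decoder layers (arbitrary units)
print(peak_param_elements(layers, num_devices=4, wrap_layers=False))
print(peak_param_elements(layers, num_devices=4, wrap_layers=True))
```

Under these assumptions, wrapping only the root gives no saving over an unsharded model, while wrapping each decoder layer brings the peak close to the sharded size plus one layer.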

P1: MultiSlice support

The implementation should be flexible enough to support 1) data parallel over MultiSlice and 2) FSDP over MultiSlice.

P1: Defer Parameter initialization

This is needed when the total model size is larger than the host memory. In TPU v5e, the host memory is extremely limited and this feature becomes a must. Basically, what we need to do is to initialize the model layer by layer, and transfer the layer to the device immediately.

P1: Auto wrapping

This refers to the ability to apply the same set of rules, e.g., sharding/opt-barrier/etc, automatically to children modules from the root module.
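
As an illustration of the idea, here is a minimal sketch of recursive auto wrapping on a toy module tree. Everything in it is a stand-in: Module, DecoderLayer, and Wrapped are hypothetical plain-Python classes rather than torch.nn or FSDP classes, and auto_wrap is not an existing API. It only shows the traversal, where children are visited bottom-up and replaced when a policy matches.

```python
class Module:
    """Toy stand-in for nn.Module: just a dict of named children."""

    def __init__(self, **children):
        self.children = dict(children)


class DecoderLayer(Module):
    pass


class Wrapped(Module):
    """Toy stand-in for the FSDP wrapper; sharding and barriers would go here."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner


def auto_wrap(module, should_wrap, wrap):
    """Recursively replace matching children of `module` with wrap(child)."""
    for name, child in list(module.children.items()):
        auto_wrap(child, should_wrap, wrap)  # handle grandchildren first
        if should_wrap(child):
            module.children[name] = wrap(child)
    return module


root = Module(
    embed=Module(),
    layer0=DecoderLayer(mlp=Module()),
    layer1=DecoderLayer(mlp=Module()),
)
auto_wrap(root, lambda m: isinstance(m, DecoderLayer), Wrapped)
root = Wrapped(root)  # the root module is wrapped last
print({name: type(child).__name__ for name, child in root.inner.children.items()})
```

The policy here matches on the layer class, in the same spirit as transformer_auto_wrap_policy in FSDPv1, so only the root and the decoder layers end up wrapped.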

P1: Distributed checkpointing support

There are two use cases here: one is exception handling during the training job, and the other is consolidating checkpoints for future inference. For different design candidates, this feature requirement might have different implications. For example, an nn.Module wrapper approach will introduce additional naming prefixes in the state_dicts.
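
To illustrate the naming concern, the sketch below strips a wrapper attribute name out of state_dict keys so that they match the unwrapped model again. It is an assumption-laden example: it supposes the wrapper stores the wrapped module under an attribute called _orig_module, and strip_wrapper_prefix is a hypothetical helper, not an existing utility. Plain values stand in for tensors.

```python
WRAPPER_ATTR = "_orig_module"  # assumed attribute name of the wrapped module


def strip_wrapper_prefix(state_dict, wrapper_attr=WRAPPER_ATTR):
    """Return a copy of state_dict with wrapper components removed from keys."""
    cleaned = {}
    for key, value in state_dict.items():
        parts = [part for part in key.split(".") if part != wrapper_attr]
        new_key = ".".join(parts)
        if new_key in cleaned:
            raise KeyError(f"key collision after stripping: {new_key}")
        cleaned[new_key] = value
    return cleaned


wrapped_state = {
    "_orig_module.model.embed_tokens.weight": "w0",
    "_orig_module.model.layers.0._orig_module.mlp.up_proj.weight": "w1",
}
for key in strip_wrapper_prefix(wrapped_state):
    print(key)
```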

P1: HuggingFace and Lightning integrations

FSDPv2 should be designed so that it can easily replace the current FSDPv1 integrations in HuggingFace and Lightning, and thus become the default distributed algorithm in those two high-level frameworks for PyTorch/XLA.

P2: Mixed precision support

FSDPv1 offers manual mixed precision support where the weights are always kept in FP32 but compute can be performed in BF16. Mixed precision support is definitely needed, but whether it’s supported via torch.amp or via the FSDPv2 implementation itself is under discussion.

P2: Gradient Averaging

FSDPv1 offers a nice way of averaging the gradients by world_size to avoid overflows during all_reduce. It is unclear whether this is necessary for this design.

N.A.: Shard on attentions

This is needed in the case of 2D sharding, as we need to pick the num_attention_heads dim to shard on the model axis. Since we only shard on the batch size (bs) dim, theoretically speaking we shouldn’t need it. Experiments also validate the theory.

N.A.: Replace nn.Linear

This is needed for 2D sharding, as we don’t want PyTorch to collapse two dims of a tensor when both of them are sharded during a matmul operation. However, in FSDP, at most one dim of the tensor will be sharded, and therefore this is not needed. An xprof taken after dropping XLAPatchedLinear shows no performance degradation.

N.A.: Shard optimizer states

This is proven to be unnecessary during the 2D sharding exercise.

Design

In this section, two approaches will be discussed. Each of them will have a PoC implementation that includes all the P0 features to demonstrate its feasibility and its pros and cons. One of them will then be selected as the final design, and more P1 and P2 features will be added on top of it.

As an nn.Module: SpmdFullyShardedDataParallel

This is a traditional approach, like FSDPv1. It has the following major components:

init

It will take care of the following P0s:

  • Shard on weights
  • Apply backward optimization barrier

Below is the pseudo code:

  def __init__(self, module: nn.Module, mesh: spmd.Mesh, shard_output: Optional[Callable] = None):
    # Check the parameters
    ...

    super().__init__()

    self._orig_module = module
    self._mesh = mesh
    self._shard_output = shard_output  # will explain in the next section

    # shard the weights
    for param in module.parameters():
      spmd.mark_sharding(param, mesh, _prepare_spmd_partition_spec(param))

    # apply the backward optimization barrier
    spmd.xla_sharding.apply_backward_optimization_barrier(module)

forward

This is the most debatable part. The forward function is used to shard on activations, specifically the output of the original module. As discussed in the above section, this is required to maintain the high performance. However, the output of the forward function can be anything, and therefore it’s really hard to shard.

Here is the proposed solution. Conventionally, the output usually will be:

  1. A tensor
  2. A tuple of tensors

For the 1st case, we can safely shard it. For the 2nd case, we can shard the 0th element and warn the user to provide an output sharding function if that is not the intended element. For all other cases, we will raise a runtime error asking users to provide an output sharding function. That’s what shard_output in the above section represents.

  def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    output = self.module(*args, **kwargs)
    self._shard_output(output, self._mesh)
    return output
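
The three output cases described above could be handled by a default along the following lines. This is a sketch under stated assumptions rather than the PoC code: shard_output_default is a hypothetical name, the is_tensor and mark_sharding callables are injected so the example runs without torch or torch_xla, and FakeTensor is a toy stand-in for torch.Tensor.

```python
import warnings


def shard_output_default(output, mesh, is_tensor, mark_sharding):
    """Shard a module output following the three cases described above."""
    if is_tensor(output):
        # Case 1: a single tensor can be sharded directly.
        mark_sharding(output, mesh)
    elif isinstance(output, tuple) and len(output) > 0 and is_tensor(output[0]):
        # Case 2: shard element 0 and tell the user how to override this.
        warnings.warn(
            "Sharding only element 0 of the output tuple; provide an output "
            "sharding function if another element is intended.")
        mark_sharding(output[0], mesh)
    else:
        # Everything else needs an explicit user-provided function.
        raise RuntimeError(
            "Unsupported output type; please provide an output sharding function.")


class FakeTensor:
    """Toy stand-in for torch.Tensor so the sketch runs without torch."""


marked = []
shard_output_default(
    FakeTensor(),
    "mesh",
    is_tensor=lambda x: isinstance(x, FakeTensor),
    mark_sharding=lambda tensor, mesh: marked.append(tensor),
)
print(len(marked))
```

In a wrapper, a default like this would be used whenever the user leaves shard_output unset.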

Code boilerplate

Some code boilerplate that is copied from FSDPv1.

@property
def module(self) -> nn.Module:
  """make model.module accessible, just like DDP."""
  ...

def __getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
  """Forward missing attributes to wrapped module."""
  ...

def __getitem__(self, key: int) -> nn.Module:
  """Forward indexing calls in case the module is a nn.Sequential."""
  ...
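
For readers unfamiliar with this pattern, here is a plain-Python sketch of the forwarding behavior that the three stubs above describe. It is not the FSDPv1 code: Wrapper and Inner are toy classes, and a real nn.Module wrapper also has to cooperate with nn.Module's own __getattr__, which this example ignores.

```python
class Inner:
    """Toy wrapped module with an attribute and list-like indexing."""

    def __init__(self):
        self.config = {"hidden_size": 4096}
        self._layers = ["layer0", "layer1"]

    def __getitem__(self, key):
        return self._layers[key]


class Wrapper:
    def __init__(self, module):
        self._orig_module = module

    @property
    def module(self):
        """Make wrapper.module accessible, just like DDP."""
        return self._orig_module

    def __getattr__(self, name):
        """Forward missing attributes to the wrapped module."""
        # __getattr__ only runs when normal lookup fails. Guard against
        # infinite recursion if _orig_module itself has not been set yet.
        if name == "_orig_module":
            raise AttributeError(name)
        return getattr(self._orig_module, name)

    def __getitem__(self, key):
        """Forward indexing calls to the wrapped module."""
        return self._orig_module[key]


wrapped = Wrapper(Inner())
print(wrapped.config["hidden_size"], wrapped[1], wrapped.module is wrapped._orig_module)
```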

Manual wrapping

Here it follows the same logic as FSDPv1 where only the root module and important backbone children modules should be wrapped, like the decoder layers in LLMs. This is very important as we don’t want to register the backward optimization barrier for every linear. Otherwise, the compiler will lose a ton of optimization opportunities.

(WIP) As a function: spmd_fully_sharded_data_parallel

This will be very similar to the distribute_module, and in fact probably can be a high level wrapper on top of it. Basically, it will take care of everything within the function.

Test Plan

  • Unit tests that examine all the basic functionalities.
  • Performance tests that demonstrate the parity of FSDPv2 against the vanilla SPMD implementation.

(WIP) PoCs Evaluation

SpmdFullyShardedDataParallel

Performance

  • Xprof: Missing
  • Hardware FLOPS utilization: 63.5% (PoC) vs. 65.0% (vanilla SPMD with sharded activations).
  • Peak memory allocation: 16666.45 MiB (PoC) vs. 16392.05 MiB (vanilla SPMD with sharded activations).

Project Management

Milestone 1:

  1. Wrap up initial design
  2. Provide PoC of SpmdFullyShardedDataParallel: 1. PyTorch/XLA PR, 2. HF integration.
  3. Provide PoC of spmd_fully_sharded_data_parallel.
  4. Evaluate the two PoCs and determine the right approach.

Milestone 2:

TBD, implement the rest of the feature set.

Open Questions

  • Distributed checkpointing compatibility, naming conventions

cc @JackCaoG @yeounoh @jonb377 @wconstab @wanchaol

@alanwaketan alanwaketan self-assigned this Jan 25, 2024
@YangFei1990
Contributor

Awesome thanks @alanwaketan ! One thing I did not find from the RFC is about the gradient clipping, which in FSDPv1 it requires to collect the global gradient norm before clipping. Wondering how it will be handled with GSPMD.
Also it is mentioned that optimizer state sharding is not required in v2, since in v1 all parameters/grads/opt states are sharded, does v2 have the same behavior?

@alanwaketan
Collaborator Author

> Awesome thanks @alanwaketan ! One thing I did not find from the RFC is about the gradient clipping, which in FSDPv1 it requires to collect the global gradient norm before clipping. Wondering how it will be handled with GSPMD. Also it is mentioned that optimizer state sharding is not required in v2, since in v1 all parameters/grads/opt states are sharded, does v2 have the same behavior?

Optimizer states will be sharded automatically. We just don't need to specify the sharding of them. Please refer to https://pytorch.org/blog/high-performance-llama-2/.

In terms of gradient clipping, it's something under debate. I'm not quite sure whether we need it or not. Can you remind me what the benefits of this function are? I don't think we ever need that on TPU.

@baoleai
Contributor

baoleai commented Jan 26, 2024

We have found that in order to effectively utilize SPMD-FSDP, it is necessary to configure cache_all_gather=False in https://github.com/openxla/xla/blob/main/xla/service/spmd/spmd_partitioner.h#L72. This adjustment is crucial to prevent the consolidation of all-gather operations for weights throughout the forward and backward phases. When set to cache_all_gather=True, the all-gather operation is triggered during the forward phase and its results are retained until consumed in the backward phase, which can lead to increased peak GPU memory.

@alanwaketan
Collaborator Author

> We have found that in order to effectively utilize SPMD-FSDP, it is necessary to configure cache_all_gather=False in openxla/xla@main/xla/service/spmd/spmd_partitioner.h#L72. This adjustment is crucial to prevent the consolidation of all-gather operations for weights throughout the forward and backward phases. When set to cache_all_gather=True, the all-gather operation is triggered during the forward phase and its results are retained until consumed in the backward phase, which can lead to increased peak GPU memory.

Can gradient checkpointing solve your issue?

@baoleai
Contributor

baoleai commented Jan 26, 2024

> > We have found that in order to effectively utilize SPMD-FSDP, it is necessary to configure cache_all_gather=False in openxla/xla@main/xla/service/spmd/spmd_partitioner.h#L72. This adjustment is crucial to prevent the consolidation of all-gather operations for weights throughout the forward and backward phases. When set to cache_all_gather=True, the all-gather operation is triggered during the forward phase and its results are retained until consumed in the backward phase, which can lead to increased peak GPU memory.
>
> Can gradient checkpointing solve your issue?

No, because the all-gather is for weights, not activations, it's unrelated to gradient checkpointing. The issue becomes more pronounced with larger models like llama2-70B; the increase in GPU memory is not very noticeable with smaller models.

Additionally, in the Python implementation of FSDP, the all-gather operations are generated separately during the forward and backward passes, and are not merged together.

@alanwaketan
Collaborator Author

> > > We have found that in order to effectively utilize SPMD-FSDP, it is necessary to configure cache_all_gather=False in openxla/xla@main/xla/service/spmd/spmd_partitioner.h#L72. This adjustment is crucial to prevent the consolidation of all-gather operations for weights throughout the forward and backward phases. When set to cache_all_gather=True, the all-gather operation is triggered during the forward phase and its results are retained until consumed in the backward phase, which can lead to increased peak GPU memory.
> >
> > Can gradient checkpointing solve your issue?
>
> No, because the all-gather is for weights, not activations, it's unrelated to gradient checkpointing. The issue becomes more pronounced with larger models like llama2-70B; the increase in GPU memory is not very noticeable with smaller models.
>
> Additionally, in the Python implementation of FSDP, the all-gather operations are generated separately during the forward and backward passes, and are not merged together.

So you are trying to prevent that consolidation between the fwd + bwd during the ckpted backward? As there is no way the original fwd's all-gather will be consolidated to the ckpted backward.

@wconstab
Collaborator

Not sure if you're following the per-param fsdp rewrite- see PRs labeled [FSDP2]. Some ux changes if you want to try to stay consistent.

@alanwaketan
Collaborator Author

@wconstab It would be great to stay consistent. On the other hand, I wonder if there are any design docs?

@wconstab
Collaborator

@awgu can you link

@awgu

awgu commented Jan 29, 2024

A majority of the info is captured here: pytorch/pytorch#114299

@alanwaketan
Collaborator Author

> A majority of the info is captured here: pytorch/pytorch#114299

Thanks, will take a look!
