From 71d718b38e892f88b3e04e70cc17a75ecad29828 Mon Sep 17 00:00:00 2001
From: Luyu Gao
Date: Mon, 12 Jun 2023 18:23:10 +0000
Subject: [PATCH] Add a context manager for activation sharding.

---
 fairscale/nn/model_parallel/mappings.py | 44 ++++++++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/fairscale/nn/model_parallel/mappings.py b/fairscale/nn/model_parallel/mappings.py
index 78d0961c5..655130faf 100644
--- a/fairscale/nn/model_parallel/mappings.py
+++ b/fairscale/nn/model_parallel/mappings.py
@@ -22,8 +22,9 @@
 from typing import Any
 
 import torch
+from torch.autograd.graph import saved_tensors_hooks
 
-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim
 
 
@@ -154,3 +155,44 @@ def scatter_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
 
 def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
     return _GatherFromModelParallelRegion.apply(input_)
+
+
+def _pack_over_mp(tensor):
+    mp_world_size = get_model_parallel_world_size()
+    if mp_world_size == 1:
+        return tensor  # no-op for mp=1: save the tensor unchanged
+    full_tensor_shape = list(tensor.shape)
+    shard = tensor.view(-1).chunk(mp_world_size, dim=0)[get_model_parallel_rank()]  # assumes numel % mp_world_size == 0
+    shard = shard.detach().clone().contiguous()  # clone so the full tensor's storage can be released
+    del tensor
+    return shard, full_tensor_shape
+
+
+def _unpack_over_mp(packed):
+    mp_world_size = get_model_parallel_world_size()
+    if mp_world_size == 1:
+        return packed  # no-op for mp=1: _pack_over_mp saved the tensor unchanged
+    sharded_tensor, full_tensor_shape = packed
+    full_tensor = torch.empty(
+        *full_tensor_shape,
+        dtype=sharded_tensor.dtype,
+        device=sharded_tensor.device)
+
+    torch.distributed.all_gather_into_tensor(
+        full_tensor.view(-1), sharded_tensor, group=get_model_parallel_group()
+    )
+
+    return full_tensor
+
+
+class shard_over_mp_group(saved_tensors_hooks):
+    """Context manager for activation sharding.
+
+    This context manager shards the tensors that autograd saves for
+    backward across the model parallel group during the forward pass
+    and all-gathers them again during the backward pass. It removes
+    redundancy in long-lived activation tensors replicated on every rank.
+    """
+
+    def __init__(self):
+        super().__init__(_pack_over_mp, _unpack_over_mp)
\ No newline at end of file
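
Usage sketch (not part of the diff): the new shard_over_mp_group context manager can wrap any forward computation whose saved activations are replicated across model parallel ranks. The transformer-block style forward below is only an illustration; the module and attribute names (attention, attention_norm, feed_forward, ffn_norm) are assumptions, not taken from this patch.

    from fairscale.nn.model_parallel.mappings import shard_over_mp_group

    def forward(self, x):
        # Tensors that autograd saves inside this block are stored as
        # per-rank shards and all-gathered again during backward.
        with shard_over_mp_group():
            h = x + self.attention(self.attention_norm(x))
            out = h + self.feed_forward(self.ffn_norm(h))
        return out

The hooks fire once per saved tensor, so saved-activation memory drops roughly by a factor of the model parallel world size, at the cost of one extra all-gather per saved tensor during backward.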