From d0cdc85f1ad6fc6278364fc5c6a08607bbbae0a3 Mon Sep 17 00:00:00 2001
From: Felicity Liao
Date: Wed, 2 Apr 2025 13:45:25 -0700
Subject: [PATCH] 1/n Dynamic Sharding API + Test for EBC, TW, ShardedTensor
 (#2852)

Summary:
Add an initial dynamic sharding API and test. This initial version supports EBC, TW, and ShardedTensor; other variants (e.g. CW, RW, DTensor, etc.) will be added in the next few diffs.

What's added here:
1. A `reshard` API on the EBC sharder, which calls the new `update_shards` API on `ShardedEmbeddingBagCollection` (a usage sketch follows the patch below).
2. Util functions for dynamic sharding - these are used by the `update_shards` API:
   1. `extend_shard_name`: extends `table_i` to its state-dict name `embedding_bags.table_i.weight`.
   2. `shards_all_to_all`: contains the all-to-all collective call that redistributes shards across ranks in a distributed environment, based on the `changed_sharding_params`.
   3. `update_state_dict_post_resharding`: updates a given `state_dict` with new shard `placements` and `local_shards`.
3. A multi-process unit test `test_dynamic_sharding_ebc_tw` that calls the `reshard` API on TW-sharded EBCs, sampling from various `world_sizes`, `num_tables`, and `data_types`.
   1. This unit test also uses a few utils to generate random inputs and rank placements. A future TODO is to merge this input generation with the generate call in D71703434.

Future work items (features not yet supported in this diff):
* CW, RW, and many other sharding types
* Optimizer saving
* DTensor implementation

Differential Revision: D69095169
---
 torchrec/distributed/embeddingbag.py          | 205 +++++++-
 .../distributed/sharding/dynamic_sharding.py  | 190 ++++++++
 .../tests/test_dynamic_sharding.py            | 437 ++++++++++++++++++
 3 files changed, 818 insertions(+), 14 deletions(-)
 create mode 100644 torchrec/distributed/sharding/dynamic_sharding.py
 create mode 100644 torchrec/distributed/tests/test_dynamic_sharding.py

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 6a0192841..33ea870ae 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -27,6 +27,9 @@ import torch from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + DenseTableBatchedEmbeddingBagsCodegen, +) from tensordict import TensorDict from torch import distributed as dist, nn, Tensor from torch.autograd.profiler import record_function @@ -50,6 +53,10 @@ ) from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding +from torchrec.distributed.sharding.dynamic_sharding import ( + shards_all_to_all, + update_state_dict_post_resharding, +) from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding @@ -635,14 +642,17 @@ def __init__( self._env = env # output parameters as DTensor in state dict self._output_dtensor: bool = env.output_dtensor - - sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( - module, - table_name_to_parameter_sharding, - "embedding_bags.", - fused_params, + self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = ( + create_sharding_infos_by_sharding( + module, + table_name_to_parameter_sharding, + "embedding_bags.", + fused_params, + ) + ) + self._sharding_types: 
List[str] = list( + self.sharding_type_to_sharding_infos.keys() ) - self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys()) self._embedding_shardings: List[ EmbeddingSharding[ EmbeddingShardingContext, @@ -658,7 +668,7 @@ def __init__( permute_embeddings=True, qcomm_codecs_registry=self.qcomm_codecs_registry, ) - for embedding_configs in sharding_type_to_sharding_infos.values() + for embedding_configs in self.sharding_type_to_sharding_infos.values() ] self._is_weighted: bool = module.is_weighted() @@ -833,7 +843,7 @@ def _pre_load_state_dict_hook( lookup = lookup.module lookup.purge() - def _initialize_torch_state(self) -> None: # noqa + def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa """ This provides consistency between this class and the EmbeddingBagCollection's nn.Module API calls (state_dict, named_modules, etc) @@ -1063,11 +1073,12 @@ def post_state_dict_hook( destination_key = f"{prefix}embedding_bags.{table_name}.weight" destination[destination_key] = sharded_kvtensor - self.register_state_dict_pre_hook(self._pre_state_dict_hook) - self._register_state_dict_hook(post_state_dict_hook) - self._register_load_state_dict_pre_hook( - self._pre_load_state_dict_hook, with_module=True - ) + if not skip_registering: + self.register_state_dict_pre_hook(self._pre_state_dict_hook) + self._register_state_dict_hook(post_state_dict_hook) + self._register_load_state_dict_pre_hook( + self._pre_load_state_dict_hook, with_module=True + ) self.reset_parameters() def reset_parameters(self) -> None: @@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None: self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims()) embedding_shard_metadata.extend(sharding.embedding_shard_metadata()) self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device) + embedding_shard_offsets: List[int] = [ meta.shard_offsets[1] if meta is not None else 0 for meta in embedding_shard_metadata @@ -1179,6 +1191,38 @@ def _create_output_dist(self) -> None: embedding_shard_offsets[i], ), ) + + self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings( + self._uncombined_embedding_dims, permute_indices, self._device + ) + + def _update_output_dist(self) -> None: + embedding_shard_metadata: List[Optional[ShardMetadata]] = [] + # TODO: Optimize to only go through embedding shardings with new ranks + self._output_dists: List[nn.Module] = [] + self._embedding_names: List[str] = [] + for sharding in self._embedding_shardings: + # TODO: if sharding type of table completely changes, need to regenerate everything + self._embedding_names.extend(sharding.embedding_names()) + self._output_dists.append(sharding.create_output_dist(device=self._device)) + embedding_shard_metadata.extend(sharding.embedding_shard_metadata()) + + embedding_shard_offsets: List[int] = [ + meta.shard_offsets[1] if meta is not None else 0 + for meta in embedding_shard_metadata + ] + embedding_name_order: Dict[str, int] = {} + for i, name in enumerate(self._uncombined_embedding_names): + embedding_name_order.setdefault(name, i) + + permute_indices = sorted( + range(len(self._uncombined_embedding_names)), + key=lambda i: ( + embedding_name_order[self._uncombined_embedding_names[i]], + embedding_shard_offsets[i], + ), + ) + self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings( self._uncombined_embedding_dims, permute_indices, self._device ) @@ -1396,6 +1440,108 @@ def compute_and_output_dist( return awaitable + def update_shards( + self, + 
changed_sharding_params: Dict[str, ParameterSharding], # NOTE: only delta + env: ShardingEnv, + device: Optional[torch.device], + ) -> None: + """ + Update shards for this module based on the changed_sharding_params. This will: + 1. Move current lookup tensors to CPU + 2. Purge lookups + 3. Call shards_all_2_all containing collective to redistribute tensors + 4. Update state_dict and other attributes to reflect new placements and shards + 5. Create new lookups, and load in updated state_dict + + Args: + changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping + table names to their new parameter sharding configs. This should only + contain shards/table names that need to be moved. + env (ShardingEnv): The sharding environment for the module. + device (Optional[torch.device]): The device to place the updated module on. + """ + + if env.output_dtensor: + raise RuntimeError("We do not yet support DTensor for resharding yet") + return + + current_state = self.state_dict() + # TODO: Save Optimizers + + saved_weights = {} + # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again + for i, lookup in enumerate(self._lookups): + for attribute, tbe_module in lookup.named_modules(): + if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen: + saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu() + # Note: lookup.purge should delete tbe_module and weights + # del tbe_module.weights + # del tbe_module + # pyre-ignore + lookup.purge() + + # Deleting all lookups + self._lookups.clear() + + local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all( + module=self, + state_dict=current_state, + device=device, # pyre-ignore + changed_sharding_params=changed_sharding_params, + env=env, + extend_shard_name=self.extend_shard_name, + ) + + current_state = update_state_dict_post_resharding( + state_dict=current_state, + shard_names_by_src_rank=local_shard_names_by_src_rank, + output_tensor=local_output_tensor, + new_sharding_params=changed_sharding_params, + curr_rank=dist.get_rank(), + extend_shard_name=self.extend_shard_name, + ) + + for name, param in changed_sharding_params.items(): + self.module_sharding_plan[name] = param + # TODO: Support detecting old sharding type when sharding type is changing + for sharding_info in self.sharding_type_to_sharding_infos[ + param.sharding_type + ]: + if sharding_info.embedding_config.name == name: + sharding_info.param_sharding = param + + self._sharding_types: List[str] = list( + self.sharding_type_to_sharding_infos.keys() + ) + # TODO: Optimize to update only the changed embedding shardings + self._embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ] = [ + create_embedding_bag_sharding( + embedding_configs, + env, + device, + permute_embeddings=True, + qcomm_codecs_registry=self.qcomm_codecs_registry, + ) + for embedding_configs in self.sharding_type_to_sharding_infos.values() + ] + + self._create_lookups() + self._update_output_dist() + + if env.process_group and dist.get_backend(env.process_group) != "fake": + self._initialize_torch_state(skip_registering=True) + + self.load_state_dict(current_state) + return + @property def fused_optimizer(self) -> KeyedOptimizer: return self._optim @@ -1403,6 +1549,10 @@ def fused_optimizer(self) -> KeyedOptimizer: def create_context(self) -> EmbeddingBagCollectionContext: return EmbeddingBagCollectionContext() + @staticmethod + def extend_shard_name(shard_name: str) 
-> str: + return f"embedding_bags.{shard_name}.weight" + class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]): """ @@ -1435,6 +1585,33 @@ def shardable_parameters( for name, param in module.embedding_bags.named_parameters() } + def reshard( + self, + sharded_module: ShardedEmbeddingBagCollection, + changed_shard_to_params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + """ + Updates the sharded module in place based on the changed_shard_to_params + which contains the new ParameterSharding with different shard placements. + + Args: + sharded_module (ShardedEmbeddingBagCollection): The module to update + changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping + table names to their new parameter sharding configs. This should only + contain shards/table names that need to be moved + env (ShardingEnv): The sharding environment + device (Optional[torch.device]): The device to place the updated module on + + Returns: + ShardedEmbeddingBagCollection: The updated sharded module + """ + + if len(changed_shard_to_params) > 0: + sharded_module.update_shards(changed_shard_to_params, env, device) + return sharded_module + @property def module_type(self) -> Type[EmbeddingBagCollection]: return EmbeddingBagCollection diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py new file mode 100644 index 000000000..4e50c4f72 --- /dev/null +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import Shard +from torchrec.distributed.types import ( + ParameterSharding, + ShardedModule, + ShardedTensor, + ShardingEnv, +) + + +def shards_all_to_all( + module: ShardedModule[Any, Any, Any, Any], # pyre-ignore + state_dict: Dict[str, ShardedTensor], + device: torch.device, + changed_sharding_params: Dict[str, ParameterSharding], + env: ShardingEnv, + extend_shard_name: Callable[[str], str] = lambda x: x, +) -> Tuple[List[str], torch.Tensor]: + """ + Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters. + Assumes ranks are ordered in ParameterSharding.ranks. + + Args: + module (ShardedEmbeddingBagCollection): The module containing sharded tensors to be redistributed. + TODO: Update to support more modules + + state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors. + + device (torch.device): The device on which the output tensors will be placed. + + changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters. + + env (ShardingEnv): The sharding environment containing world size and other distributed information. + + extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict. + + Returns: + Tuple[List[str], torch.Tensor]: A tuple containing: + - A list of shard names that were sent from a specific rank to the current rank, ordered by rank, then shard order. 
+ - The tensor containing all shards received by the current rank after the all-to-all operation. + """ + if env.output_dtensor: + raise RuntimeError("We do not yet support DTensor for resharding yet") + return + + # Module sharding plan is used to get the source ranks for each shard + assert hasattr(module, "module_sharding_plan") + + world_size = env.world_size + rank = dist.get_rank() + input_splits_per_rank = [[0] * world_size for _ in range(world_size)] + output_splits_per_rank = [[0] * world_size for _ in range(world_size)] + local_input_tensor = torch.empty([0], device=device) + local_output_tensor = torch.empty([0], device=device) + + shard_names_by_src_rank = [] + for shard_name, param in changed_sharding_params.items(): + sharded_t = state_dict[extend_shard_name(shard_name)] + assert param.ranks is not None + dst_ranks = param.ranks + state_dict[extend_shard_name(shard_name)] + # pyre-ignore + src_ranks = module.module_sharding_plan[shard_name].ranks + + # TODO: Implement changing rank sizes for beyond TW sharding + assert len(dst_ranks) == len(src_ranks) + + # index needed to distinguish between multiple shards + # within the same shardedTensor for each table + for i in range(len(src_ranks)): + dst_rank = dst_ranks[i] + src_rank = src_ranks[i] + + shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes + shard_size_dim_0 = shard_size[0] + input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_0 + output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_0 + if src_rank == rank: + local_shards = sharded_t.local_shards() + assert len(local_shards) == 1 + local_input_tensor = torch.cat( + ( + local_input_tensor, + sharded_t.local_shards()[0].tensor, + ) + ) + if dst_rank == rank: + shard_names_by_src_rank.append(shard_name) + local_output_tensor = torch.cat( + (local_output_tensor, torch.empty(shard_size, device=device)) + ) + + local_input_splits = input_splits_per_rank[rank] + local_output_splits = output_splits_per_rank[rank] + + assert sum(local_output_splits) == len(local_output_tensor) + assert sum(local_input_splits) == len(local_input_tensor) + dist.all_to_all_single( + output=local_output_tensor, + input=local_input_tensor, + output_split_sizes=local_output_splits, + input_split_sizes=local_input_splits, + group=dist.group.WORLD, + ) + + return shard_names_by_src_rank, local_output_tensor + + +def update_state_dict_post_resharding( + state_dict: Dict[str, ShardedTensor], + shard_names_by_src_rank: List[str], + output_tensor: torch.Tensor, + new_sharding_params: Dict[str, ParameterSharding], + curr_rank: int, + extend_shard_name: Callable[[str], str] = lambda x: x, +) -> Dict[str, ShardedTensor]: + """ + Updates and returns the given state_dict with new placements and + local_shards based on the output tensor of the AllToAll collective. + + Args: + state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards. + + shard_names_by_src_rank (List[str]): A list of shard names that were sent from a specific rank to the + current rank, ordered by rank, then shard order. + + output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation. + + new_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters. + This should only contain shard names that were updated during the AllToAll operation. + + curr_rank (int): The current rank of the process in the distributed environment. 
+ + extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict. + + Returns: + Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards. + """ + slice_index = 0 + shard_names_by_src_rank + + shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} + + for shard_name in shard_names_by_src_rank: + shard_size = state_dict[extend_shard_name(shard_name)].size(0) + end_slice_index = slice_index + shard_size + shard_name_to_local_output_tensor[shard_name] = output_tensor[ + slice_index:end_slice_index + ] + slice_index = end_slice_index + + for shard_name, param in new_sharding_params.items(): + extended_name = extend_shard_name(shard_name) + # pyre-ignore + for i in range(len(param.ranks)): + # pyre-ignore + r = param.ranks[i] + sharded_t = state_dict[extended_name] + # Update placements + sharded_t.metadata().shards_metadata[i].placement = ( + torch.distributed._remote_device(f"rank:{r}/cuda:{r}") + ) + if r == curr_rank: + assert len(output_tensor) > 0 + # slice output tensor for correct size. + sharded_t._local_shards = [ + Shard( + tensor=shard_name_to_local_output_tensor[shard_name], + metadata=state_dict[extended_name] + .metadata() + .shards_metadata[i], + ) + ] + break + else: + sharded_t._local_shards = [] + + return state_dict diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py new file mode 100644 index 000000000..ccc46cc94 --- /dev/null +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + + +import copy + +import random +import unittest + +from typing import Any, Dict, List, Optional, Union + +import hypothesis.strategies as st + +import torch + +from hypothesis import given, settings, Verbosity +from torch import nn + +from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection + +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + get_module_to_default_sharders, + table_wise, +) + +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict + +from torchrec.distributed.types import ( + EmbeddingModuleShardingPlan, + ParameterSharding, + ShardingEnv, + ShardingType, +) +from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig + +from torchrec.test_utils import skip_if_asan_class +from torchrec.types import DataType + + +# Utils: +def table_name(i: int) -> str: + return "table_" + str(i) + + +def feature_name(i: int) -> str: + return "feature_" + str(i) + + +def generate_input_by_world_size( + world_size: int, + num_tables: int, + num_embeddings: int = 4, + max_mul: int = 3, +) -> List[KeyedJaggedTensor]: + # TODO merge with new ModelInput generator in TestUtils + kjt_input_per_rank = [] + mul = random.randint(1, max_mul) + total_size = num_tables * mul + + for _ in range(world_size): + feature_names = [feature_name(i) for i in range(num_tables)] + lengths = [] + values = [] + counting_l = 0 + for i in range(total_size): + if i == total_size - 1: + lengths.append(total_size - counting_l) + break + next_l = random.randint(0, total_size - counting_l) + values.extend( + [random.randint(0, num_embeddings - 1) for _ in range(next_l)] + ) + lengths.append(next_l) + counting_l += next_l + + # for length in lengths: + + kjt_input_per_rank.append( + KeyedJaggedTensor.from_lengths_sync( + keys=feature_names, + values=torch.LongTensor(values), + lengths=torch.LongTensor(lengths), + ) + ) + + return kjt_input_per_rank + + +def generate_embedding_bag_config( + data_type: DataType, + num_tables: int = 3, + embedding_dim: int = 16, + num_embeddings: int = 4, +) -> List[EmbeddingBagConfig]: + embedding_bag_config = [] + for i in range(num_tables): + embedding_bag_config.append( + EmbeddingBagConfig( + name=table_name(i), + feature_names=[feature_name(i)], + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + data_type=data_type, + ), + ) + return embedding_bag_config + + +def create_test_initial_state_dict( + sharded_module_type: nn.Module, + num_tables: int, + data_type: DataType, + embedding_dim: int = 16, + num_embeddings: int = 4, +) -> Dict[str, torch.Tensor]: + """ + Helpful for debugging: + + initial_state_dict = { + "embedding_bags.table_0.weight": torch.tensor( + [ + [1] * 16, + [2] * 16, + [3] * 16, + [4] * 16, + ], + ), + "embedding_bags.table_1.weight": torch.tensor( + [ + [101] * 16, + [102] * 16, + [103] * 16, + [104] * 16, + ], + dtype=data_type_to_dtype(data_type), + ), + ... 
+ } + """ + + initial_state_dict = {} + for i in range(num_tables): + # pyre-ignore + extended_name = sharded_module_type.extend_shard_name(table_name(i)) + initial_state_dict[extended_name] = torch.tensor( + [[j + (i * 100)] * embedding_dim for j in range(num_embeddings)], + dtype=data_type_to_dtype(data_type), + ) + + return initial_state_dict + + +def are_modules_identical( + module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection], + module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection], +) -> None: + # Check if both modules have the same type + assert type(module1) is type(module2) + + # Check if both modules have the same parameters + params1 = list(module1.named_parameters()) + params2 = list(module2.named_parameters()) + + assert len(params1) == len(params2) + + for param1, param2 in zip(params1, params2): + # Check parameter names + assert param1[0] == param2[0] + # Check parameter values + assert torch.allclose(param1[1], param2[1]) + + # Check if both modules have the same buffers + buffers1 = list(module1.named_buffers()) + buffers2 = list(module2.named_buffers()) + + assert len(buffers1) == len(buffers2) + + for buffer1, buffer2 in zip(buffers1, buffers2): + assert buffer1[0] == buffer2[0] # Check buffer names + assert torch.allclose(buffer1[1], buffer2[1]) # Check buffer values + + +def output_sharding_plan_delta( + old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan +) -> EmbeddingModuleShardingPlan: + assert len(old_plan) == len(new_plan) + return_plan = copy.deepcopy(new_plan) + for shard_name, old_param in old_plan.items(): + if shard_name not in return_plan: + raise ValueError(f"Shard {shard_name} not found in new plan") + new_param = return_plan[shard_name] + old_ranks = old_param.ranks + new_ranks = new_param.ranks + if old_ranks == new_ranks: + del return_plan[shard_name] + + return return_plan + + +def _test_ebc_resharding( + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + backend: str, + module_sharding_plan: EmbeddingModuleShardingPlan, + new_module_sharding_plan: EmbeddingModuleShardingPlan, + local_size: Optional[int] = None, +) -> None: + """ + Distributed call to test resharding for ebc by creating 2 models with identical config and + states: + m1 sharded with new_module_sharding_plan + m2 sharded with module_sharding_plan, then resharded with new_module_sharding_plan + + Expects m1 and resharded m2 to be the same, and predictions outputted from the same KJT + inputs to be the same. + + TODO: modify to include other modules once dynamic sharding is built out. 
+ """ + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + m1 = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + m2 = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + # Load initial State - making sure models are identical + m1.load_state_dict(initial_state_dict) + copy_state_dict( + loc=m1.state_dict(), + glob=copy.deepcopy(initial_state_dict), + ) + + m2.load_state_dict(initial_state_dict) + copy_state_dict( + loc=m2.state_dict(), + glob=copy.deepcopy(initial_state_dict), + ) + + sharder = get_module_to_default_sharders()[type(m1)] + + # pyre-ignore + env = ShardingEnv.from_process_group(ctx.pg) + + sharded_m1 = sharder.shard( + module=m1, + params=new_module_sharding_plan, + env=env, + device=ctx.device, + ) + + sharded_m2 = sharder.shard( + module=m1, + params=module_sharding_plan, + env=env, + device=ctx.device, + ) + + new_module_sharding_plan_delta = output_sharding_plan_delta( + module_sharding_plan, new_module_sharding_plan + ) + + # pyre-ignore + resharded_m2 = sharder.reshard( + sharded_module=sharded_m2, + changed_shard_to_params=new_module_sharding_plan_delta, + env=env, + device=ctx.device, + ) + + are_modules_identical(sharded_m1, resharded_m2) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + # For current test model and inputs, the prediction should be the exact same + rtol = 0 + atol = 0 + + for _ in range(world_size): + # sharded model + # each rank gets a subbatch + sharded_m1_pred_kt_no_dict = sharded_m1(kjt_input_per_rank[ctx.rank]) + resharded_m2_pred_kt_no_dict = resharded_m2(kjt_input_per_rank[ctx.rank]) + + sharded_m1_pred_kt = sharded_m1_pred_kt_no_dict.to_dict() + resharded_m2_pred_kt = resharded_m2_pred_kt_no_dict.to_dict() + sharded_m1_pred = torch.stack( + [sharded_m1_pred_kt[feature] for feature in feature_keys] + ) + + resharded_m2_pred = torch.stack( + [resharded_m2_pred_kt[feature] for feature in feature_keys] + ) + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. + + # Compare predictions of sharded vs unsharded models. 
+ torch.testing.assert_close( + sharded_m1_pred.cpu(), resharded_m2_pred.cpu(), rtol=rtol, atol=atol + ) + + sharded_m1_pred.sum().backward() + resharded_m2_pred.sum().backward() + + +@skip_if_asan_class +class MultiRankDynamicShardingTest(MultiProcessTestBase): + def _run_ebc_resharding_test( + self, + per_param_sharding: Dict[str, ParameterSharding], + new_per_param_sharding: Dict[str, ParameterSharding], + num_tables: int, + world_size: int, + data_type: DataType, + embedding_dim: int = 16, + num_embeddings: int = 4, + ) -> None: + embedding_bag_config = generate_embedding_bag_config( + data_type, num_tables, embedding_dim, num_embeddings + ) + + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + # pyre-ignore + per_param_sharding=per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + new_module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + # pyre-ignore + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + # Row-wise not supported on gloo + if ( + not torch.cuda.is_available() + and new_module_sharding_plan["table_0"].sharding_type + == ShardingType.ROW_WISE.value + ): + return + + kjt_input_per_rank = generate_input_by_world_size( + world_size, num_tables, num_embeddings + ) + + # initial_state_dict filled with deterministic dummy values + initial_state_dict = create_test_initial_state_dict( + ShardedEmbeddingBagCollection, # pyre-ignore + num_tables, + data_type, + embedding_dim, + num_embeddings, + ) + + self._run_multi_process_test( + callable=_test_ebc_resharding, + world_size=world_size, + tables=embedding_bag_config, + initial_state_dict=initial_state_dict, + kjt_input_per_rank=kjt_input_per_rank, + backend="nccl" if torch.cuda.is_available() else "gloo", + module_sharding_plan=module_sharding_plan, + new_module_sharding_plan=new_module_sharding_plan, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @given( # pyre-ignore + num_tables=st.sampled_from([2, 3, 4]), + data_type=st.sampled_from([DataType.FP32, DataType.FP16]), + world_size=st.sampled_from([2, 4]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_dynamic_sharding_ebc_tw( + self, + num_tables: int, + data_type: DataType, + world_size: int, + ) -> None: + # Tests EBC dynamic sharding implementation for TW + + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size + old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)] + new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)] + + if new_ranks == old_ranks: + return + per_param_sharding = {} + new_per_param_sharding = {} + + # Construct parameter shardings + for i in range(num_tables): + per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i]) + new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i]) + + self._run_ebc_resharding_test( + per_param_sharding, + new_per_param_sharding, + num_tables, + world_size, + data_type, + )
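
For reference, here is a minimal sketch of the intended call flow for the new `reshard` API, mirroring `test_dynamic_sharding_ebc_tw` above. It assumes `torch.distributed` is already initialized (e.g. under `torchrun`) with at least two ranks; the single table `table_0`/`feature_0` and the rank-0-to-rank-1 move are illustrative only, and the hand-built `changed` dict plays the role of the test's `output_sharding_plan_delta` helper.

```python
import torch
import torch.distributed as dist

from torchrec import EmbeddingBagCollection
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_module_to_default_sharders,
    table_wise,
)
from torchrec.distributed.types import ShardingEnv
from torchrec.modules.embedding_configs import EmbeddingBagConfig

# Assumes dist.init_process_group(...) has already run (e.g. via torchrun);
# in a real multi-GPU job each rank would also set its own CUDA device.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)
env = ShardingEnv.from_process_group(dist.group.WORLD)

# Illustrative single-table EBC.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=16,
            num_embeddings=4,
        )
    ],
    device=device,
)

sharder = get_module_to_default_sharders()[EmbeddingBagCollection]

# Old plan: table_0 sharded table-wise on rank 0.
old_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise(rank=0)},
    local_size=env.world_size,
    world_size=env.world_size,
    device_type=device_type,
)
# New plan: table_0 moved table-wise to rank 1.
new_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise(rank=1)},
    local_size=env.world_size,
    world_size=env.world_size,
    device_type=device_type,
)

sharded_ebc = sharder.shard(module=ebc, params=old_plan, env=env, device=device)

# `reshard` expects only the delta: the ParameterShardings of tables whose
# placement actually changed (the test's output_sharding_plan_delta helper
# computes this by diffing the old and new plans).
changed = {"table_0": new_plan["table_0"]}

sharded_ebc = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params=changed,
    env=env,
    device=device,
)
```

Passing only the changed tables means unchanged shards stay in place; the all-to-all inside `shards_all_to_all` only moves the tables named in `changed_shard_to_params`.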