diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index e3177b953..12cc92d3e 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,9 +27,6 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-    DenseTableBatchedEmbeddingBagsCodegen,
-)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -61,6 +58,7 @@
     get_largest_dims_from_sharding_plan_updates,
     shards_all_to_all,
     update_module_sharding_plan,
+    update_optimizer_state_post_resharding,
     update_state_dict_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1535,7 +1533,7 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        # TODO: Save Optimizers
+        has_optimizer = len(self._optim._optims) > 0
 
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
@@ -1550,6 +1548,7 @@
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
+        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
 
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1560,16 +1559,7 @@
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
-        )
-
-        current_state = update_state_dict_post_resharding(
-            state_dict=current_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor,
-            new_sharding_params=changed_sharding_params,
-            curr_rank=dist.get_rank(),
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
+            optimizer_state=old_optimizer_state,
         )
 
         for name, param in changed_sharding_params.items():
@@ -1615,8 +1605,6 @@
         if env.process_group and dist.get_backend(env.process_group) != "fake":
             self._initialize_torch_state(skip_registering=True)
 
-        self.load_state_dict(current_state)
-
         # update optimizer
         optims = []
         for lookup in self._lookups:
@@ -1635,6 +1623,35 @@
 
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
+        if has_optimizer:
+            split_index = len(local_output_tensor) // 2
+            local_weight_tensors = local_output_tensor[:split_index]
+            local_optimizer_tensors = local_output_tensor[split_index:]
+            # Modifies new_opt_state in place and returns it
+            optimizer_state = update_optimizer_state_post_resharding(
+                old_opt_state=old_optimizer_state,  # pyre-ignore
+                new_opt_state=copy.deepcopy(self._optim.state_dict()),
+                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+                output_tensor=local_optimizer_tensors,
+                max_dim_0=max_dim_0,
+            )
+
+            self._optim.load_state_dict(optimizer_state)
+        else:
+            local_weight_tensors = local_output_tensor
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+            output_tensor=local_weight_tensors,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
+        )
+
+        self.load_state_dict(current_state)
+
         update_module_sharding_plan(self, changed_sharding_params)
         return
 
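Aside on the hunk above: update_shards() now assumes the flat buffer returned by shards_all_to_all() packs the padded weight rows for every received shard first and the padded optimizer (momentum) rows second, which is why it splits local_output_tensor at len(local_output_tensor) // 2. A minimal sketch of that unpacking convention, using plain tensors and made-up shard sizes instead of the real ShardedTensor plumbing:

import torch

# Hypothetical sizes, for illustration only.
max_dim_0, max_dim_1 = 4, 8                            # padded rows/cols per shard
shards = [("table_0", (2, 8)), ("table_1", (3, 4))]    # (name, true shard size)

# Received buffer: one max_dim_0-row block per shard for weights,
# followed by one block per shard for optimizer state.
received = torch.randn(2 * len(shards) * max_dim_0, max_dim_1)
split_index = len(received) // 2
weight_half, optim_half = received[:split_index], received[split_index:]

def unpack(flat: torch.Tensor) -> dict:
    # Mirrors the slicing in update_state_dict_post_resharding /
    # update_optimizer_state_post_resharding: each shard owns max_dim_0
    # padded rows; the true (rows, cols) strip the padding back off.
    out, offset = {}, 0
    for name, (rows, cols) in shards:
        block = flat[offset : offset + max_dim_0]
        out[name] = block[:rows, :cols]
        offset += max_dim_0
    return out

weights = unpack(weight_half)
momenta = unpack(optim_half)
print(weights["table_1"].shape)  # torch.Size([3, 4])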
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index c048a591b..6c7189c6f 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -687,7 +687,10 @@ def reshard(
             self.device,
         )
 
-        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        # Need to use .module to maintain FQN consistency
+        self._optim: CombinedOptimizer = self._init_optim(
+            self._dmp_wrapped_module.module  # pyre-ignore
+        )
 
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module
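Aside on the hunk above: optimizer state-dict keys are derived from parameter FQNs, so rebuilding the optimizer against the wrapper instead of the wrapped module would prefix every key and break round-tripping of the saved state. A hedged sketch of the underlying nn.Module naming behavior this relies on (the wrapper class here is illustrative, not TorchRec's actual DMP implementation):

from torch import nn

class Wrapper(nn.Module):
    # Stand-in for a DMP-style wrapper; the real class differs.
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self._dmp_wrapped_module = module

    @property
    def module(self) -> nn.Module:
        return self._dmp_wrapped_module

inner = nn.Linear(4, 4)
wrapped = Wrapper(inner)

print([n for n, _ in wrapped.named_parameters()])
# ['_dmp_wrapped_module.weight', '_dmp_wrapped_module.bias']
print([n for n, _ in wrapped.module.named_parameters()])
# ['weight', 'bias']  <- FQNs consistent with the sharding plan and state dict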
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index dac6204d7..c517d7df7 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import copy
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -84,6 +84,7 @@
     max_dim_0: int,
     max_dim_1: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
+    optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None,
 ) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
@@ -121,14 +122,18 @@
     # Module sharding plan is used to get the source ranks for each shard
     assert hasattr(module, "module_sharding_plan")
 
+    has_optimizer = optimizer_state is not None
+
     world_size = env.world_size
     rank = dist.get_rank()
    input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_tensor_tensor_count = 0
+    output_optimizer_tensor_count = 0
     shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
     local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
+    local_table_to_opt_by_dst_rank = [[] for _ in range(world_size)]
 
     for shard_name, param in changed_sharding_params.items():
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
@@ -142,24 +147,47 @@
         # index needed to distinguish between multiple shards
         # within the same shardedTensor for each table
         for i in range(len(src_ranks)):
+
+            # 1 to 1 mapping from src to dst
             dst_rank = dst_ranks[i]
             src_rank = src_ranks[i]
 
             shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
             input_splits_per_rank[src_rank][dst_rank] += max_dim_0
             output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+            if has_optimizer:
+                input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+                output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+
+            # If sending from current rank
             if src_rank == rank:
+                if has_optimizer:
+                    # pyre-ignore
+                    local_optimizer = optimizer_state["state"][
+                        extend_shard_name(shard_name)
+                    ][tmp_momentum_extender(shard_name)].local_shards()
+                    assert len(local_optimizer) == 1
+                    padded_local_optimizer = pad_tensor_to_max_dims(
+                        local_optimizer[0].tensor, max_dim_0, max_dim_1
+                    )
+                    local_table_to_opt_by_dst_rank[dst_rank].append(
+                        padded_local_optimizer
+                    )
                 local_shards = sharded_t.local_shards()
                 assert len(local_shards) == 1
                 cur_t = pad_tensor_to_max_dims(
-                    sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
+                    local_shards[0].tensor, max_dim_0, max_dim_1
                 )
                 local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
+
+            # If receiving from current rank
             if dst_rank == rank:
                 shard_names_to_lengths_by_src_rank[src_rank].append(
                     (shard_name, shard_size)
                 )
                 output_tensor_tensor_count += max_dim_0
+                if has_optimizer:
+                    output_optimizer_tensor_count += max_dim_0
 
     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
@@ -175,9 +203,23 @@
                 dim=0,
             )
 
+    for sub_l in local_table_to_opt_by_dst_rank:
+        for shard_info in sub_l:
+            local_input_tensor = torch.cat(
+                (
+                    local_input_tensor,
+                    shard_info,
+                ),
+                dim=0,
+            )
+
     max_embedding_size = max_dim_1
     local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, max_embedding_size], device=device
+        [
+            output_tensor_tensor_count + output_optimizer_tensor_count,
+            max_embedding_size,
+        ],
+        device=device,
     )
 
     assert sum(local_output_splits) == len(local_output_tensor)
@@ -277,6 +319,50 @@
     return state_dict
 
 
+def update_optimizer_state_post_resharding(
+    old_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    new_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
+    output_tensor: torch.Tensor,
+    max_dim_0: int,
+) -> Dict[str, Dict[str, Dict[str, ShardedTensor]]]:
+    new_opt_state_state = new_opt_state["state"]
+    old_opt_state_state = old_opt_state["state"]
+
+    # Remove padding and store tensors by shard name
+    slice_index = 0
+    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
+    for shard_name, shard_size in ordered_shard_names_and_lengths:
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
+        slice_index = end_slice_index
+
+    for extended_shard_name, item in new_opt_state_state.items():
+        if extended_shard_name in old_opt_state_state:
+            new_opt_state_state[extended_shard_name] = old_opt_state_state[
+                extended_shard_name
+            ]
+        else:
+            shard_name = extract_shard_name(extended_shard_name)
+            momentum_name = tmp_momentum_extender(shard_name)
+            sharded_t = item[momentum_name]
+            assert len(sharded_t._local_shards) == 1
+            # TODO: support multiple shards in CW sharding
+            sharded_t._local_shards = [
+                Shard(
+                    tensor=shard_name_to_local_output_tensor[shard_name],
+                    metadata=shard.metadata,
+                )
+                for shard in sharded_t._local_shards
+            ]
+
+    return new_opt_state
+
+
 def update_module_sharding_plan(
     module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
     changed_sharding_params: Dict[str, ParameterSharding],
@@ -388,3 +474,16 @@
             if v.ranks != old_plan[k].ranks
         }
     )
+
+
+"""
+Utils for accessing optimizer state
+"""
+
+
+def tmp_momentum_extender(name: str) -> str:
+    return name + ".momentum1"
+
+
+def extract_shard_name(name: str) -> str:
+    return name.split(".")[-2]
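Aside on the two helpers added at the bottom of this file: they encode a key-naming convention rather than querying the optimizer, and they only compose correctly for the single "momentum1" slot produced by the fused optimizers exercised in the tests below; other optimizers or slot names remain an open assumption. A short sketch of the intended round trip, using the key shapes that appear in test_sharding.py:

def tmp_momentum_extender(name: str) -> str:
    return name + ".momentum1"

def extract_shard_name(name: str) -> str:
    return name.split(".")[-2]

# Optimizer state-dict key for a sharded table, as seen in the tests.
extended = "sparse.ebc.embedding_bags.table_0.weight"
shard_name = extract_shard_name(extended)         # -> "table_0"
momentum_key = tmp_momentum_extender(shard_name)  # -> "table_0.momentum1"

# optimizer_state["state"][extended][momentum_key] is the ShardedTensor whose
# local shard shards_all_to_all() pads and ships alongside the weight shard.
assert (shard_name, momentum_key) == ("table_0", "table_0.momentum1")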
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index 3a8937e23..8793a2153 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -541,6 +541,8 @@ def dynamic_sharding_test(
     )
 
     local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta)
+    # Must recreate local_m1_opt, because the current local_m1_opt is a copy of the underlying fused_opt
+    local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim])
 
     local_m1_pred = gen_full_pred_after_one_step(
         local_m1_dmp, local_m1_opt, local_input_1
@@ -954,7 +956,12 @@ def gen_full_pred_after_one_step(
     opt: torch.optim.Optimizer,
     input: ModelInput,
     skip_inference: bool = False,
+    skip_training: bool = False,
 ) -> torch.Tensor:
+    if skip_training:
+        model.train(False)
+        output = model(input)
+        return output
     # Run a single training step of the global model.
     opt.zero_grad()
     model.train(True)
@@ -1120,3 +1127,32 @@ def generate_rank_placements(
         placement = sorted(random.sample(range(world_size), ranks_per_table))
         placements.append(placement)
     return placements
+
+
+def compare_opt_local_t(
+    opt_1: CombinedOptimizer,
+    opt_2: CombinedOptimizer,
+    table_id: int,
+    rtol: float = 1e-4,
+    atol: float = 1e-4,
+) -> None:
+    """
+    Helper function to compare the optimizer state of two models after one training step.
+    Useful for debugging sharding tests to see which optimizer states differ.
+    """
+    # TODO: update logic to be generic for other embedding modules
+    t1 = (
+        opt_1.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    t2 = (
+        opt_2.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    torch.testing.assert_close(t1, t2, rtol=rtol, atol=atol)
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
index db55a8fa2..c2a939140 100644
--- a/torchrec/distributed/tests/test_dynamic_sharding.py
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -8,8 +8,6 @@
 
 # pyre-strict
 
-import copy
-
 import random
 import unittest
 
@@ -21,7 +19,7 @@
 
 from hypothesis import assume, given, settings, Verbosity
 
-from torch import nn
+from torch import nn, optim
 
 from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -530,9 +528,11 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared):
         apply_optimizer_in_backward_config=st.sampled_from(
             [
                 None,
+                {
+                    "embedding_bags": (optim.Adagrad, {"lr": 0.04}),
+                },
                 {
                     "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
-                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
                 },
             ]
         ),
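Aside for readers wiring this into their own tests: the test_sharding.py change above recreates the CombinedOptimizer handle after resharding because the sharded module rebuilds its fused optimizer inside update_shards(), so any wrapper built beforehand points at stale state. A hedged sketch of that caller-side pattern (model, dense_optim, and plan_delta are illustrative names, not fixtures from this diff):

from torchrec.optim.keyed import CombinedOptimizer

def reshard_and_rebuild_optimizer(model, dense_optim, plan_delta):
    # model is assumed to be a DistributedModelParallel instance wrapping a
    # TestSparseNN-style module with an EmbeddingBagCollection at "sparse.ebc".
    model.reshard("sparse.ebc", plan_delta)
    # Re-wrap: fused_optimizer now reflects the post-reshard lookups, and the
    # sharded optimizer state was carried over inside update_shards().
    return CombinedOptimizer([model.fused_optimizer, dense_optim])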