From f706b80512628d37b61033f9a0a62ce2a2b0527a Mon Sep 17 00:00:00 2001
From: Felicity Liao <11263993+aporialiao@users.noreply.github.com>
Date: Fri, 6 Jun 2025 15:59:17 -0700
Subject: [PATCH] Enable proper optimizer state storing + Test between batches
 (#3053)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3053

# Main Changes
1. Enable the unit test with an adaptive optimizer (`Adagrad`)
   - Previously the optimizer state was exercised with `SGD`, whose state is static throughout training, so the test never actually verified that optimizer state was stored. Switching to `Adagrad` exposed that the previous implementation did not properly store optimizers.
2. Properly store optimizer state in `update_optimizer_state` (illustrated in the sketch below)
   - Append the optimizer tensors as inputs to the all2all call, then parse the output tensors to store the right tensors.
   - Optimizer tensors that do not need to be sent to a new rank are persisted and resaved.
   - After the new lookups are created, use `load_state_dict` to load the saved optimizer state into the current optimizers.
3. Helpers & other small changes
   - Helper to compare optimizer tensors in unit tests
   - Update the `DMP` reshard optimizer saving to match the same FQN
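To make the packing convention in point 2 concrete, here is a minimal, single-process sketch using plain `torch`. The shapes, helper names, and payload layout are illustrative assumptions, not the actual torchrec internals: every shard's weight tensor and its optimizer tensor are padded to a fixed `(max_dim_0, max_dim_1)` slot, optimizer tensors are appended after the weight tensors in the all2all payload, and the receiver splits the output in half and strips the padding using the known shard sizes.

```python
import torch

MAX_DIM_0, MAX_DIM_1 = 4, 8  # illustrative padding targets


def pad(t: torch.Tensor) -> torch.Tensor:
    # Pad a 2D shard into a fixed-size slot so every shard occupies the same space.
    out = torch.zeros(MAX_DIM_0, MAX_DIM_1)
    out[: t.shape[0], : t.shape[1]] = t
    return out


def unpad(t: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    # Strip the padding back off using the shard's true size.
    return t[:rows, :cols]


# Two toy shards plus their per-row optimizer state (e.g. Adagrad accumulators).
shard_sizes = [(2, 3), (3, 5)]
weights = [torch.randn(r, c) for r, c in shard_sizes]
opt_states = [torch.rand(r, c) for r, c in shard_sizes]

# "Sender" side: weight tensors first, then the matching optimizer tensors, all padded.
payload = torch.cat([pad(w) for w in weights] + [pad(o) for o in opt_states], dim=0)

# "Receiver" side (after the all2all): first half is weights, second half is optimizer state.
split_index = payload.shape[0] // 2
weight_out, opt_out = payload[:split_index], payload[split_index:]

for i, (rows, cols) in enumerate(shard_sizes):
    w = unpad(weight_out[i * MAX_DIM_0 : (i + 1) * MAX_DIM_0], rows, cols)
    o = unpad(opt_out[i * MAX_DIM_0 : (i + 1) * MAX_DIM_0], rows, cols)
    assert torch.equal(w, weights[i]) and torch.equal(o, opt_states[i])
```

Padding every shard to the same slot is what lets one flat buffer carry shards of different shapes through a single all2all; the cost is the extra padded rows on the wire.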
Reviewed By: aliafzal

Differential Revision: D75565054
---
 torchrec/distributed/embeddingbag.py          |  49 +++++---
 torchrec/distributed/model_parallel.py        |   5 +-
 .../distributed/sharding/dynamic_sharding.py  | 105 +++++++++++++++++-
 .../distributed/test_utils/test_sharding.py   |  36 ++++++
 .../tests/test_dynamic_sharding.py            |   8 +-
 5 files changed, 179 insertions(+), 24 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index e3177b953..12cc92d3e 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,9 +27,6 @@
 import torch
 
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-    DenseTableBatchedEmbeddingBagsCodegen,
-)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -61,6 +58,7 @@
     get_largest_dims_from_sharding_plan_updates,
     shards_all_to_all,
     update_module_sharding_plan,
+    update_optimizer_state_post_resharding,
     update_state_dict_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1535,7 +1533,7 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        # TODO: Save Optimizers
+        has_optimizer = len(self._optim._optims) > 0
 
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
        # TODO: Ensure lookup tensors are actually being deleted
@@ -1550,6 +1548,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
+        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
 
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1560,16 +1559,7 @@ def update_shards(
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
-        )
-
-        current_state = update_state_dict_post_resharding(
-            state_dict=current_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor,
-            new_sharding_params=changed_sharding_params,
-            curr_rank=dist.get_rank(),
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
+            optimizer_state=old_optimizer_state,
         )
 
         for name, param in changed_sharding_params.items():
@@ -1615,8 +1605,6 @@ def update_shards(
         if env.process_group and dist.get_backend(env.process_group) != "fake":
             self._initialize_torch_state(skip_registering=True)
 
-        self.load_state_dict(current_state)
-
         # update optimizer
         optims = []
         for lookup in self._lookups:
@@ -1635,6 +1623,35 @@ def update_shards(
 
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
+        if has_optimizer:
+            split_index = len(local_output_tensor) // 2
+            local_weight_tensors = local_output_tensor[:split_index]
+            local_optimizer_tensors = local_output_tensor[split_index:]
+            # Modifies new_opt_state in place and returns it
+            optimizer_state = update_optimizer_state_post_resharding(
+                old_opt_state=old_optimizer_state,  # pyre-ignore
+                new_opt_state=copy.deepcopy(self._optim.state_dict()),
+                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+                output_tensor=local_optimizer_tensors,
+                max_dim_0=max_dim_0,
+            )
+
+            self._optim.load_state_dict(optimizer_state)
+        else:
+            local_weight_tensors = local_output_tensor
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+            output_tensor=local_weight_tensors,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
+        )
+
+        self.load_state_dict(current_state)
+
         update_module_sharding_plan(self, changed_sharding_params)
 
         return
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index c048a591b..6c7189c6f 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -687,7 +687,10 @@ def reshard(
             self.device,
         )
 
-        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        # Need to use .module to maintain FQN consistency
+        self._optim: CombinedOptimizer = self._init_optim(
+            self._dmp_wrapped_module.module  # pyre-ignore
+        )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
 
         return sharded_module
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index dac6204d7..c517d7df7 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import copy
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -84,6 +84,7 @@ def shards_all_to_all(
     max_dim_0: int,
     max_dim_1: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
+    optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None,
 ) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
@@ -121,14 +122,18 @@ def shards_all_to_all(
     # Module sharding plan is used to get the source ranks for each shard
     assert hasattr(module, "module_sharding_plan")
 
+    has_optimizer = optimizer_state is not None
+
     world_size = env.world_size
     rank = dist.get_rank()
     input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
 
     output_tensor_tensor_count = 0
+    output_optimizer_tensor_count = 0
     shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
     local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
+    local_table_to_opt_by_dst_rank = [[] for _ in range(world_size)]
     for shard_name, param in changed_sharding_params.items():
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
@@ -142,24 +147,47 @@ def shards_all_to_all(
         # index needed to distinguish between multiple shards
         # within the same shardedTensor for each table
         for i in range(len(src_ranks)):
+
+            # 1 to 1 mapping from src to dst
             dst_rank = dst_ranks[i]
             src_rank = src_ranks[i]
             shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
             input_splits_per_rank[src_rank][dst_rank] += max_dim_0
             output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+            if has_optimizer:
+                input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+                output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+
+            # If sending from current rank
             if src_rank == rank:
+                if has_optimizer:
+                    # pyre-ignore
+                    local_optimizer = optimizer_state["state"][
+                        extend_shard_name(shard_name)
+                    ][tmp_momentum_extender(shard_name)].local_shards()
+                    assert len(local_optimizer) == 1
+                    padded_local_optimizer = pad_tensor_to_max_dims(
+                        local_optimizer[0].tensor, max_dim_0, max_dim_1
+                    )
+                    local_table_to_opt_by_dst_rank[dst_rank].append(
+                        padded_local_optimizer
+                    )
                 local_shards = sharded_t.local_shards()
                 assert len(local_shards) == 1
                 cur_t = pad_tensor_to_max_dims(
-                    sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
+                    local_shards[0].tensor, max_dim_0, max_dim_1
                 )
                 local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
+
+            # If receiving from current rank
            if dst_rank == rank:
                 shard_names_to_lengths_by_src_rank[src_rank].append(
                     (shard_name, shard_size)
                 )
 
                 output_tensor_tensor_count += max_dim_0
+                if has_optimizer:
+                    output_optimizer_tensor_count += max_dim_0
 
     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
@@ -175,9 +203,23 @@ def shards_all_to_all(
                 dim=0,
             )
 
+    for sub_l in local_table_to_opt_by_dst_rank:
+        for shard_info in sub_l:
+            local_input_tensor = torch.cat(
+                (
+                    local_input_tensor,
+                    shard_info,
+                ),
+                dim=0,
+            )
+
     max_embedding_size = max_dim_1
     local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, max_embedding_size], device=device
+        [
+            output_tensor_tensor_count + output_optimizer_tensor_count,
+            max_embedding_size,
+        ],
+        device=device,
     )
 
     assert sum(local_output_splits) == len(local_output_tensor)
@@ -277,6 +319,50 @@ def update_state_dict_post_resharding(
     return state_dict
 
 
+def update_optimizer_state_post_resharding(
+    old_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    new_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
+    output_tensor: torch.Tensor,
+    max_dim_0: int,
+) -> Dict[str, Dict[str, Dict[str, ShardedTensor]]]:
+    new_opt_state_state = new_opt_state["state"]
+    old_opt_state_state = old_opt_state["state"]
+
+    # Remove padding and store tensors by shard name
+    slice_index = 0
+    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
+    for shard_name, shard_size in ordered_shard_names_and_lengths:
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
+        slice_index = end_slice_index
+
+    for extended_shard_name, item in new_opt_state_state.items():
+        if extended_shard_name in old_opt_state_state:
+            new_opt_state_state[extended_shard_name] = old_opt_state_state[
+                extended_shard_name
+            ]
+        else:
+            shard_name = extract_shard_name(extended_shard_name)
+            momentum_name = tmp_momentum_extender(shard_name)
+            sharded_t = item[momentum_name]
+            assert len(sharded_t._local_shards) == 1
+            # TODO: support multiple shards in CW sharding
+            sharded_t._local_shards = [
+                Shard(
+                    tensor=shard_name_to_local_output_tensor[shard_name],
+                    metadata=shard.metadata,
+                )
+                for shard in sharded_t._local_shards
+            ]
+
+    return new_opt_state
+
+
 def update_module_sharding_plan(
     module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
     changed_sharding_params: Dict[str, ParameterSharding],
@@ -388,3 +474,16 @@ def output_sharding_plan_delta(
             if v.ranks != old_plan[k].ranks
         }
     )
+
+
+"""
+Utils for Optimizer State accessing
+"""
+
+
+def tmp_momentum_extender(name: str) -> str:
+    return name + ".momentum1"
+
+
+def extract_shard_name(name: str) -> str:
+    return name.split(".")[-2]
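A quick note on the two FQN helpers added at the end of `dynamic_sharding.py` above: `extract_shard_name` recovers the table name from an extended shard FQN, and `tmp_momentum_extender` appends the `momentum1` suffix under which the fused optimizer state is keyed. The sketch below restates the two helpers and walks one example through them; the example FQN and the stand-in dict are illustrative assumptions, not real torchrec state.

```python
def tmp_momentum_extender(name: str) -> str:
    return name + ".momentum1"


def extract_shard_name(name: str) -> str:
    return name.split(".")[-2]


extended = "embedding_bags.table_0.weight"        # assumed extended shard FQN
shard_name = extract_shard_name(extended)         # -> "table_0"
momentum_key = tmp_momentum_extender(shard_name)  # -> "table_0.momentum1"
assert (shard_name, momentum_key) == ("table_0", "table_0.momentum1")

# Stand-in for the optimizer_state["state"][extended_shard_name][momentum_key]
# lookups done in shards_all_to_all / update_optimizer_state_post_resharding.
opt_state = {"state": {extended: {momentum_key: "local momentum shards live here"}}}
assert opt_state["state"][extended][momentum_key] == "local momentum shards live here"
```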
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index 3a8937e23..8793a2153 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -541,6 +541,8 @@ def dynamic_sharding_test(
     )
 
     local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta)
+    # Must recreate local_m1_opt, because current local_m1_opt is a copy of underlying fused_opt
+    local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim])
 
     local_m1_pred = gen_full_pred_after_one_step(
         local_m1_dmp, local_m1_opt, local_input_1
@@ -954,7 +956,12 @@ def gen_full_pred_after_one_step(
     opt: torch.optim.Optimizer,
     input: ModelInput,
     skip_inference: bool = False,
+    skip_training: bool = False,
 ) -> torch.Tensor:
+    if skip_training:
+        model.train(False)
+        output = model(input)
+        return output
     # Run a single training step of the global model.
     opt.zero_grad()
     model.train(True)
@@ -1120,3 +1127,32 @@ def generate_rank_placements(
         placement = sorted(random.sample(range(world_size), ranks_per_table))
         placements.append(placement)
     return placements
+
+
+def compare_opt_local_t(
+    opt_1: CombinedOptimizer,
+    opt_2: CombinedOptimizer,
+    table_id: int,
+    rtol: float = 1e-4,
+    atol: float = 1e-4,
+) -> None:
+    """
+    Helper function to compare the optimizer state of two models after one training step.
+    Useful for debugging sharding tests to see which optimizer states differ.
+    """
+    # TODO: update logic to be generic to other embedding modules
+    t1 = (
+        opt_1.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    t2 = (
+        opt_2.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    torch.testing.assert_close(t1, t2, rtol=rtol, atol=atol)
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
index db55a8fa2..c2a939140 100644
--- a/torchrec/distributed/tests/test_dynamic_sharding.py
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -8,8 +8,6 @@
 
 # pyre-strict
 
-import copy
-
 import random
 import unittest
 
@@ -21,7 +19,7 @@
 
 from hypothesis import assume, given, settings, Verbosity
 
-from torch import nn
+from torch import nn, optim
 from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 
@@ -530,9 +528,11 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared):
         apply_optimizer_in_backward_config=st.sampled_from(
             [
                 None,
+                {
+                    "embedding_bags": (optim.Adagrad, {"lr": 0.04}),
+                },
                 {
                     "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
-                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
                 },
             ]
         ),
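For context on why the test matrix now samples `Adagrad` in addition to `SGD`: `SGD` without momentum keeps no per-parameter state, so a reshard that silently drops optimizer state cannot fail with it, while `Adagrad` accumulates per-row state on every step. TorchRec's fused optimizers expose that state under the `momentum1` key seen above; the vanilla `torch.optim.Adagrad` in this standalone sketch calls it `sum`.

```python
import torch

param = torch.nn.Parameter(torch.randn(4, 8))


def one_step(opt: torch.optim.Optimizer) -> None:
    # A single dummy training step on `param`.
    opt.zero_grad()
    param.sum().backward()
    opt.step()


sgd = torch.optim.SGD([param], lr=0.01)
one_step(sgd)
# No per-parameter state: losing optimizer state during a reshard goes unnoticed.
assert len(sgd.state_dict()["state"]) == 0

adagrad = torch.optim.Adagrad([param], lr=0.04)
one_step(adagrad)
# Adagrad keeps a running sum of squared gradients; dropping it changes every
# subsequent update, which is exactly what the new unit test catches.
assert "sum" in adagrad.state_dict()["state"][0]
```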