From 30ea6d3fea4c7ae6a66e823825ec2ebfe60d4341 Mon Sep 17 00:00:00 2001
From: Felicity Liao
Date: Tue, 12 Aug 2025 13:03:00 -0700
Subject: [PATCH] Refactor create new sharding plan

Differential Revision: D80113954
---
 .../distributed/test_utils/test_sharding.py  | 160 +++++++++++-------
 1 file changed, 99 insertions(+), 61 deletions(-)

diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index ff72f9fa2..dd9ad821d 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -492,69 +492,16 @@ def dynamic_sharding_test(
     assert ctx.pg is not None
 
-    num_tables = len(tables)
-
-    ranks_per_tables = [1 for _ in range(num_tables)]
-    new_ranks = generate_rank_placements(
-        world_size, num_tables, ranks_per_tables, random_seed
-    )
-
-    ranks_per_tables_for_CW = []
-    for table in tables:
-
-        # CW sharding
-        valid_candidates = [
-            i for i in range(1, world_size + 1) if table.embedding_dim % i == 0
-        ]
-        ranks_per_tables_for_CW.append(random.choice(valid_candidates))
-
-    new_ranks_cw = generate_rank_placements(
-        world_size, num_tables, ranks_per_tables_for_CW, random_seed
-    )
-
-    new_per_param_sharding = {}
-
-    assert len(sharders) == 1
-    # pyre-ignore
-    kernel_type = sharders[0]._kernel_type
-    # Construct parameter shardings
-    for i in range(num_tables):
-        table_name = tables[i].name
-        table_constraint = constraints[table_name]  # pyre-ignore
-        assert hasattr(table_constraint, "sharding_types")
-        assert (
-            len(table_constraint.sharding_types) == 1
-        ), "Dynamic Sharding currently only supports 1 sharding type per table"
-        sharding_type = ShardingType(table_constraint.sharding_types[0])
-        sharding_type_constructor = get_sharding_constructor_from_type(
-            sharding_type
-        )
-
-        if sharding_type == ShardingType.TABLE_WISE:
-            new_per_param_sharding[table_name] = sharding_type_constructor(
-                rank=new_ranks[i][0], compute_kernel=kernel_type
-            )
-        elif sharding_type == ShardingType.COLUMN_WISE:
-            new_per_param_sharding[table_name] = sharding_type_constructor(
-                ranks=new_ranks_cw[i], compute_kernel=kernel_type
-            )
-        else:
-            raise NotImplementedError(
-                f"Dynamic Sharding currently does not support {sharding_type}"
-            )
-
-    new_module_sharding_plan = construct_module_sharding_plan(
-        local_m2.sparse.ebc,
-        sharder=sharders[0],
-        per_param_sharding=new_per_param_sharding,
-        local_size=world_size,
+    plan_1 = create_alternative_sharding_plan(
+        tables=tables,
         world_size=world_size,
-        device_type="cuda" if torch.cuda.is_available() else "cpu",
+        random_seed=random_seed,
+        sharders=sharders,
+        constraints=constraints,
+        local_model=local_m2,
+        original_plan=plan,
     )
-    plan_1 = copy.deepcopy(plan)
-    plan_1.plan["sparse.ebc"] = new_module_sharding_plan
-
     local_m1_dmp = DistributedModelParallel(
         local_m1,
         env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore
@@ -591,7 +538,7 @@
     )
 
     new_module_sharding_plan_delta = output_sharding_plan_delta(
-        plan.plan["sparse.ebc"], new_module_sharding_plan  # pyre-ignore
+        plan.plan["sparse.ebc"], plan_1.plan["sparse.ebc"]  # pyre-ignore
     )
 
     dense_m1_optim = KeyedOptimizerWrapper(
@@ -1304,6 +1251,97 @@ def generate_rank_placements(
     return placements
 
 
+def create_alternative_sharding_plan(
+    tables: List[EmbeddingTableConfig],
+    world_size: int,
+    random_seed: int,
+    sharders: List[ModuleSharder[nn.Module]],
+    constraints: Optional[Dict[str, ParameterConstraints]],
+    local_model: nn.Module,
+    original_plan: ShardingPlan,
+) -> ShardingPlan:
""" + Creates an alternative sharding plan for dynamic sharding tests. + + Args: + tables: List of embedding table configurations + world_size: Number of processes in the distributed group + random_seed: Random seed for reproducible rank placement generation + sharders: List of module sharders + constraints: Parameter constraints for sharding + local_model: Local model to create sharding plan for + original_plan: Original sharding plan to copy and modify + + Returns: + Modified sharding plan with alternative parameter sharding + """ + if constraints is None: + raise ValueError("constraints parameter is required for dynamic sharding") + + num_tables = len(tables) + + ranks_per_tables = [1 for _ in range(num_tables)] + new_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables, random_seed + ) + + ranks_per_tables_for_CW = [] + for table in tables: + + # CW sharding + valid_candidates = [ + i for i in range(1, world_size + 1) if table.embedding_dim % i == 0 + ] + ranks_per_tables_for_CW.append(random.choice(valid_candidates)) + + new_ranks_cw = generate_rank_placements( + world_size, num_tables, ranks_per_tables_for_CW, random_seed + ) + + new_per_param_sharding = {} + + assert len(sharders) == 1 + kernel_type = sharders[0]._kernel_type # pyre-ignore + # Construct parameter shardings + for i in range(num_tables): + table_name = tables[i].name + table_constraint = constraints[table_name] + assert hasattr(table_constraint, "sharding_types") + assert table_constraint.sharding_types is not None + assert ( + len(table_constraint.sharding_types) == 1 + ), "Dynamic Sharding currently only supports 1 sharding type per table" + sharding_type = ShardingType(table_constraint.sharding_types[0]) # pyre-ignore + sharding_type_constructor = get_sharding_constructor_from_type(sharding_type) + + if sharding_type == ShardingType.TABLE_WISE: + new_per_param_sharding[table_name] = sharding_type_constructor( + rank=new_ranks[i][0], compute_kernel=kernel_type + ) + elif sharding_type == ShardingType.COLUMN_WISE: + new_per_param_sharding[table_name] = sharding_type_constructor( + ranks=new_ranks_cw[i], compute_kernel=kernel_type + ) + else: + raise NotImplementedError( + f"Dynamic Sharding currently does not support {sharding_type}" + ) + + new_module_sharding_plan = construct_module_sharding_plan( + local_model.sparse.ebc, # pyre-ignore + sharder=sharders[0], + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + + plan_1 = copy.deepcopy(original_plan) + plan_1.plan["sparse.ebc"] = new_module_sharding_plan + + return plan_1 + + def compare_opt_local_t( opt_1: CombinedOptimizer, opt_2: CombinedOptimizer,