
import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
)
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding_utils import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
        self._env = env
        # output parameters as DTensor in state dict
        self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
        )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
        self._embedding_shardings: List[
            EmbeddingSharding[
                EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                permute_embeddings=True,
                qcomm_codecs_registry=self.qcomm_codecs_registry,
            )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
        ]

        self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                lookup = lookup.module
            lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
        """
        This provides consistency between this class and the EmbeddingBagCollection's
        nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
@@ -1164,6 +1175,40 @@ def _create_output_dist(self) -> None:
            self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
        self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
+            index, name = input
+            return (embedding_name_order[name], embedding_shard_offsets[index])
+
+        permute_indices = [
+            i
+            for i, _ in sorted(
+                enumerate(self._uncombined_embedding_names), key=sort_key
+            )
+        ]
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
        embedding_shard_offsets: List[int] = [
            meta.shard_offsets[1] if meta is not None else 0
            for meta in embedding_shard_metadata
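
For reference, the permute-index computation that _create_output_dist now builds for PermutePooledEmbeddings can be reproduced on toy data. A minimal standalone sketch follows; the table names and shard offsets in it are made up for illustration and are not read from any real module:

# Illustrative sketch: shards are grouped by first-seen table order, then ordered
# by their column offset within the table, mirroring sort_key in the diff above.
from typing import Dict, List, Tuple

uncombined_embedding_names: List[str] = ["t1", "t2", "t1", "t2"]
embedding_shard_offsets: List[int] = [64, 0, 0, 32]

embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)

def sort_key(entry: Tuple[int, str]) -> Tuple[int, int]:
    index, name = entry
    return (embedding_name_order[name], embedding_shard_offsets[index])

permute_indices = [
    i for i, _ in sorted(enumerate(uncombined_embedding_names), key=sort_key)
]
print(permute_indices)  # [2, 0, 1, 3]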
@@ -1399,6 +1444,105 @@ def compute_and_output_dist(

        return awaitable

+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call the shards_all_to_all collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups and load in the updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not yet support DTensor for resharding")
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Save lookup tensors to CPU to eventually avoid recreating them completely again
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_output_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            update_state_dict=current_state,
+            local_output_by_src_rank=local_output_by_src_rank,
+            local_output_tensor=local_output_tensor,
+            changed_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
    @property
    def fused_optimizer(self) -> KeyedOptimizer:
        return self._optim
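
update_shards deliberately takes only the delta of the sharding plan, i.e. the tables whose placement changed. A hypothetical, self-contained sketch of how such a delta could be computed is below; RankedSharding is a stand-in for torchrec's ParameterSharding (which carries more state, such as sharding_spec and compute_kernel), and the table names are invented:

# Hypothetical helper: keep only the tables whose placement changed between plans.
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class RankedSharding:
    sharding_type: str
    ranks: Optional[List[int]]

def changed_sharding_delta(
    old_plan: Dict[str, RankedSharding],
    new_plan: Dict[str, RankedSharding],
) -> Dict[str, RankedSharding]:
    return {
        table: new_param
        for table, new_param in new_plan.items()
        if table not in old_plan
        or old_plan[table].ranks != new_param.ranks
        or old_plan[table].sharding_type != new_param.sharding_type
    }

old = {"t1": RankedSharding("table_wise", [0]), "t2": RankedSharding("table_wise", [1])}
new = {"t1": RankedSharding("table_wise", [0]), "t2": RankedSharding("table_wise", [3])}
print(changed_sharding_delta(old, new))  # only "t2" moved, so only "t2" is in the delta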
@@ -1438,6 +1582,33 @@ def shardable_parameters(
            for name, param in module.embedding_bags.named_parameters()
        }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on changed_shard_to_params, which
+        contains the new ParameterSharding configs with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update.
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment.
+            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module.
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
    @property
    def module_type(self) -> Type[EmbeddingBagCollection]:
        return EmbeddingBagCollection
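
Taken together, a resharding call goes through the sharder's new reshard entry point. The sketch below is a hedged usage outline, not a runnable program: it assumes a process group (pg), an already-sharded EmbeddingBagCollection (sharded_ebc), a new plan mapping table names to ParameterSharding (new_plan), and a target device have been set up elsewhere, and it assumes ParameterSharding exposes a ranks field for the placement comparison.

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ShardingEnv

def reshard_ebc(sharded_ebc, new_plan, pg, device: torch.device):
    sharder = EmbeddingBagCollectionSharder()
    # Pass only the delta: tables whose rank placement differs from the current plan.
    delta = {
        name: param
        for name, param in new_plan.items()
        if sharded_ebc.module_sharding_plan[name].ranks != param.ranks
    }
    return sharder.reshard(
        sharded_module=sharded_ebc,
        changed_shard_to_params=delta,
        env=ShardingEnv.from_process_group(pg),
        device=device,
    )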