@@ -28,5 +28,8 @@
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding_utils import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]

         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()

     def reset_parameters(self) -> None:
@@ -1164,6 +1175,40 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
+            index, name = input
+            return (embedding_name_order[name], embedding_shard_offsets[index])
+
+        permute_indices = [
+            i
+            for i, _ in sorted(
+                enumerate(self._uncombined_embedding_names), key=sort_key
+            )
+        ]
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
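The sort_key added above orders each uncombined feature first by the table it belongs to (in order of first appearance) and then by its column-wise shard offset, so that PermutePooledEmbeddings can stitch column-wise shards back into contiguous per-table outputs. A minimal standalone sketch of that ordering, using hypothetical table names and offsets rather than anything from this commit:

from typing import Dict, List, Tuple

# Hypothetical inputs: two column-wise shards per table.
uncombined_embedding_names: List[str] = ["t0", "t1", "t0", "t1"]
embedding_shard_offsets: List[int] = [0, 0, 64, 128]

# First-appearance order of each table name.
embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)

def sort_key(entry: Tuple[int, str]) -> Tuple[int, int]:
    index, name = entry
    return (embedding_name_order[name], embedding_shard_offsets[index])

permute_indices = [
    i for i, _ in sorted(enumerate(uncombined_embedding_names), key=sort_key)
]
print(permute_indices)  # [0, 2, 1, 3]: t0's shards (offsets 0, 64), then t1's (0, 128)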
@@ -1396,13 +1441,117 @@ def compute_and_output_dist(

         return awaitable

+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call the shards_all_to_all collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups, and load in updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("DTensor is not yet supported for resharding")
+            return
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Save lookup tensors to CPU to eventually avoid recreating them from scratch
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_output_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            update_state_dict=current_state,
+            local_output_by_src_rank=local_output_by_src_rank,
+            local_output_tensor=local_output_tensor,
+            changed_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name_callback=self.extend_shard_name,
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim

     def create_context(self) -> EmbeddingBagCollectionContext:
         return EmbeddingBagCollectionContext()

+    @staticmethod
+    def extend_shard_name(shard_name: str) -> str:
+        return f"embedding_bags.{shard_name}.weight"
+

 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1435,6 +1584,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params,
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update.
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment.
+            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module.
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
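A hedged usage sketch of the new reshard entry point. It assumes an already sharded EmbeddingBagCollection (sharded_ebc), a ShardingEnv (env), and a delta plan (changed_tables) produced elsewhere; those three names are hypothetical, and per the update_shards docstring the delta should contain only the tables whose placement actually changes.

# Sketch only: sharded_ebc, env, and changed_tables are assumed to already exist.
# changed_tables maps table name -> new ParameterSharding (the plan delta).
sharder = EmbeddingBagCollectionSharder()
sharded_ebc = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params=changed_tables,
    env=env,
    device=torch.device("cuda"),
)
# After this call, lookups and output dists have been recreated and the state_dict
# on each rank reflects the new shard placements.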