
Commit 0a802c7

aporialiao authored and facebook-github-bot committed
1/n Dynamic Sharding API + Test for EBC, TW, ShardedTensor (#2852)
Summary: Add the initial dynamic sharding API and test. This version supports EBC, TW, and ShardedTensor. Other variants (e.g. CW, RW, DTensor, etc.) will be added in the next few diffs.

What's added here:

1. A `reshard` API, which implements the `update_shards` API for `ShardedEmbeddingBagCollection` (see the usage sketch below).
2. Util functions for dynamic sharding, used by the `update_shards` API:
   1. `extend_shard_name`: extends `table_i` to `embedding_bags.table_i.weight`.
   2. `shards_all_to_all`: contains the all-to-all collective call that redistributes shards across the distributed environment, based on the `changed_sharding_params`.
   3. `update_state_dict_post_resharding`: updates a given `state_dict` with new shard `placements` and `local_shards`.
3. A multi-process unit test `test_dynamic_sharding_ebc_tw` that calls the `reshard` API on TW-sharded EBCs, sampling from various `world_sizes`, `num_tables`, and `data_types`.
   1. This unit test also uses a few utils to generate random inputs and rank placements. A future TODO is to merge this input generation with the generate call in D71703434.

Future work items (features not yet supported in this diff):
* CW, RW, and many other sharding types
* Optimizer saving
* DTensor implementation

Differential Revision: D69095169
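For orientation, a minimal usage sketch of the new entry point follows. It is not part of this diff: it assumes an already TW-sharded EBC (`sharded_ebc`) and an initialized process group `pg`, and the `move_table_to_rank` helper that builds the updated `ParameterSharding` is hypothetical, standing in for whatever produces the new plan (the unit test uses randomly generated rank placements).

```python
# Hedged sketch, not part of this diff. Assumes `sharded_ebc` is an existing
# ShardedEmbeddingBagCollection and `pg` an initialized process group.
from typing import Dict

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.types import ParameterSharding, ShardingEnv

# Delta plan: only tables whose placement changes are listed.
changed_shard_to_params: Dict[str, ParameterSharding] = {
    "table_0": move_table_to_rank(  # hypothetical helper that rewrites the TW placement
        sharded_ebc.module_sharding_plan["table_0"], new_rank=1
    ),
}

sharder = EmbeddingBagCollectionSharder()
sharded_ebc = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params=changed_shard_to_params,
    env=ShardingEnv.from_process_group(pg),
    device=torch.device("cuda", torch.cuda.current_device()),
)
```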
1 parent eea1862 commit 0a802c7

File tree: 3 files changed (+795, -14 lines)


torchrec/distributed/embeddingbag.py

Lines changed: 190 additions & 14 deletions
@@ -27,6 +27,9 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding_utils import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]

         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()

     def reset_parameters(self) -> None:
@@ -1164,6 +1175,40 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
+            index, name = input
+            return (embedding_name_order[name], embedding_shard_offsets[index])
+
+        permute_indices = [
+            i
+            for i, _ in sorted(
+                enumerate(self._uncombined_embedding_names), key=sort_key
+            )
+        ]
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
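To make the ordering in the hunk above concrete, here is a small standalone reproduction of the `sort_key` logic on made-up data: shards are permuted first by the order in which their table first appears in `_uncombined_embedding_names`, then by their column shard offset. The table names and offsets below are purely illustrative.

```python
from typing import Dict, List, Tuple

# Toy stand-ins for self._uncombined_embedding_names and the per-shard column offsets.
uncombined_embedding_names: List[str] = ["table_1", "table_0", "table_1", "table_0"]
embedding_shard_offsets: List[int] = [256, 0, 0, 128]

embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)  # first-appearance order per table

def sort_key(input: Tuple[int, str]) -> Tuple[int, int]:
    index, name = input
    return (embedding_name_order[name], embedding_shard_offsets[index])

permute_indices = [
    i for i, _ in sorted(enumerate(uncombined_embedding_names), key=sort_key)
]
print(permute_indices)  # [2, 0, 1, 3]: table_1's shards (offsets 0, 256), then table_0's (0, 128)
```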
@@ -1396,13 +1441,117 @@ def compute_and_output_dist(

         return awaitable

+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call shards_all_to_all containing collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups, and load in updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not support DTensor for resharding yet")
+            return
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_output_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            update_state_dict=current_state,
+            local_output_by_src_rank=local_output_by_src_rank,
+            local_output_tensor=local_output_tensor,
+            changed_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name_callback=self.extend_shard_name,
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim

     def create_context(self) -> EmbeddingBagCollectionContext:
         return EmbeddingBagCollectionContext()

+    @staticmethod
+    def extend_shard_name(shard_name: str) -> str:
+        return f"embedding_bags.{shard_name}.weight"
+

 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1435,6 +1584,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved
+            env (ShardingEnv): The sharding environment
+            device (Optional[torch.device]): The device to place the updated module on
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
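Two semantics of `reshard` worth noting from the hunk above: the update happens in place and the same module object is returned, and an empty delta skips `update_shards` entirely. A hedged sketch, reusing the hypothetical `sharder`, `sharded_ebc`, and `pg` from the earlier example:

```python
from torchrec.distributed.types import ShardingEnv

# Empty delta: nothing moves, update_shards is not called, the module is returned as-is.
same_module = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params={},
    env=ShardingEnv.from_process_group(pg),
)
assert same_module is sharded_ebc
```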
