43 changes: 31 additions & 12 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -43,7 +43,7 @@
 )
 from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
 from torch import nn
-from torchrec.distributed.comm import get_local_rank
+from torchrec.distributed.comm import get_local_rank, get_local_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
@@ -215,29 +215,33 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
     optimizer_state: torch.Tensor,
     sharding_dim: int,
+    is_grid_sharded: bool = False,
 ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
 
     table_global_shards_metadata: List[ShardMetadata] = (
         table_global_metadata.shards_metadata
     )
 
-    # column-wise sharding
-    # sort the metadata based on column offset and
-    # we construct the momentum tensor in row-wise sharded way
     if sharding_dim == 1:
+        # column-wise sharding
+        # sort the metadata based on column offset and
+        # we construct the momentum tensor in row-wise sharded way
         table_global_shards_metadata = sorted(
             table_global_shards_metadata,
             key=lambda shard: shard.shard_offsets[1],
         )
 
     table_shard_metadata_to_optimizer_shard_metadata = {}
-
+    rolling_offset = 0
     for idx, table_shard_metadata in enumerate(table_global_shards_metadata):
         offset = table_shard_metadata.shard_offsets[0]
-        # for column-wise sharding, we still create row-wise sharded metadata for optimizer
-        # manually create a row-wise offset
-
-        if sharding_dim == 1:
+        if is_grid_sharded:
+            # we use a rolling offset to calculate the current offset for shard to account for uneven row wise case for our shards
+            offset = rolling_offset
+            rolling_offset += table_shard_metadata.shard_sizes[0]
+        elif sharding_dim == 1:
+            # for column-wise sharding, we still create row-wise sharded metadata for optimizer
+            # manually create a row-wise offset
            offset = idx * table_shard_metadata.shard_sizes[0]
 
         table_shard_metadata_to_optimizer_shard_metadata[
@@ -255,14 +259,22 @@
     )
     len_rw_shards = (
         len(table_shard_metadata_to_optimizer_shard_metadata)
-        if sharding_dim == 1
+        if sharding_dim == 1 and not is_grid_sharded
         else 1
     )
+    # for grid sharding, the row dimension is replicated CW shard times
+    grid_shard_nodes = (
+        len(table_global_shards_metadata) // get_local_size()
+        if is_grid_sharded
+        else 1
+    )
     rowwise_optimizer_st_metadata = ShardedTensorMetadata(
         shards_metadata=list(
             table_shard_metadata_to_optimizer_shard_metadata.values()
         ),
-        size=torch.Size([table_global_metadata.size[0] * len_rw_shards]),
+        size=torch.Size(
+            [table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes]
+        ),
         tensor_properties=tensor_properties,
     )
 
@@ -324,7 +336,6 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(

 all_optimizer_states = emb_module.get_optimizer_state()
 optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
-
 for (
     table_config,
     optimizer_states,
@@ -408,6 +419,13 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         1 if table_config.local_cols != table_config.embedding_dim else 0
     )
 
+    is_grid_sharded: bool = (
+        True
+        if table_config.local_cols != table_config.embedding_dim
+        and table_config.local_rows != table_config.num_embeddings
+        else False
+    )
+
     if all(
         opt_state is not None for opt_state in shard_params.optimizer_states
     ):
@@ -431,6 +449,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
             table_config.global_metadata,
             shard_params.optimizer_states[0][momentum_idx - 1],
             sharding_dim,
+            is_grid_sharded,
         )
     else:
         (
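
To read the two additions above in isolation: the sketch below uses plain tuples and illustrative helper names, not the torchrec API. It shows the grid-shard check (a table split along both rows and columns) and the rolling row offset that places each shard's momentum slice; the uneven 50/28 row split mirrors the "uneven row wise case" the new comment refers to.

from typing import List, Tuple


def looks_grid_sharded(
    local_rows: int, num_embeddings: int, local_cols: int, embedding_dim: int
) -> bool:
    # sharded along both the row and the column dimension -> grid sharding
    return local_cols != embedding_dim and local_rows != num_embeddings


def rowwise_momentum_offsets(
    table_shards: List[Tuple[int, int]],  # (row_offset, num_rows) per table shard
    sharding_dim: int,
    is_grid_sharded: bool = False,
) -> List[int]:
    offsets = []
    rolling_offset = 0
    for idx, (row_offset, num_rows) in enumerate(table_shards):
        offset = row_offset  # default: reuse the table shard's own row offset
        if is_grid_sharded:
            # rolling offset tolerates uneven row splits across grid shards
            offset = rolling_offset
            rolling_offset += num_rows
        elif sharding_dim == 1:
            # column-wise: every shard covers the same rows, so stack them
            offset = idx * num_rows
        offsets.append(offset)
    return offsets


# Four shards of a 78-row table on 2 nodes x 2 local ranks (grid sharding):
# each node owns a column slice, rows 0-49 / 50-77 split across its two ranks.
shards = [(0, 50), (50, 28), (0, 50), (50, 28)]  # sorted by column offset
print(rowwise_momentum_offsets(shards, sharding_dim=1, is_grid_sharded=True))
# -> [0, 50, 78, 128]; the momentum tensor then has 78 * 2 = 156 rows,
# matching size[0] * len_rw_shards * grid_shard_nodes with grid_shard_nodes = 2.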
8 changes: 8 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -47,6 +47,7 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
 from torchrec.distributed.sharding.twcw_sharding import TwCwPooledEmbeddingSharding
@@ -193,6 +194,13 @@ def create_embedding_bag_sharding(
             permute_embeddings=permute_embeddings,
             qcomm_codecs_registry=qcomm_codecs_registry,
         )
+    elif sharding_type == ShardingType.GRID_SHARD.value:
+        return GridPooledEmbeddingSharding(
+            sharding_infos,
+            env,
+            device,
+            qcomm_codecs_registry=qcomm_codecs_registry,
+        )
     else:
         raise ValueError(f"Sharding type not supported {sharding_type}")
 
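
For orientation, a rough sketch of the shard layout GridPooledEmbeddingSharding appears to target, assuming one column slice per node that is then row-wise split across the node's local ranks (this orientation is inferred from the optimizer-state sizing above, where grid_shard_nodes is the shard count divided by get_local_size()); all names below are illustrative, not torchrec code.

from typing import List, Tuple


def grid_shard_boxes(
    num_embeddings: int, embedding_dim: int, num_nodes: int, local_size: int
) -> List[Tuple[int, int, int, int]]:
    """Return (row_offset, col_offset, rows, cols) for each shard of one table."""
    cols_per_node = -(-embedding_dim // num_nodes)    # ceil: column slice per node
    rows_per_rank = -(-num_embeddings // local_size)  # ceil: row slice per local rank
    boxes = []
    for node in range(num_nodes):
        col_offset = node * cols_per_node
        cols = min(cols_per_node, embedding_dim - col_offset)
        for rank in range(local_size):
            row_offset = rank * rows_per_rank
            rows = min(rows_per_rank, num_embeddings - row_offset)
            boxes.append((row_offset, col_offset, rows, cols))
    return boxes


# A 1_000_000 x 512 table on 2 nodes with 8 ranks each -> 16 shards,
# each holding 125_000 rows and 256 columns; grid_shard_nodes = 16 // 8 = 2.
print(len(grid_shard_boxes(1_000_000, 512, num_nodes=2, local_size=8)))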
3 changes: 2 additions & 1 deletion torchrec/distributed/planner/tests/test_proposers.py
@@ -351,7 +351,8 @@ def test_grid_search_three_table(self) -> None:
         So the total number of pruned options will be:
             (num_sharding_types - 1) * 3 + 1 = 16
         """
-        num_pruned_options = (len(ShardingType) - 1) * 3 + 1
+        # NOTE - remove -2 from sharding type length once grid sharding in planner is added
+        num_pruned_options = (len(ShardingType) - 2) * 3 + 1
         self.grid_search_proposer.load(search_space)
         for (
             sharding_options
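
A quick check of the adjusted count, assuming ShardingType now has seven members (the six pre-existing ones plus GRID_SHARD) while the planner does not yet enumerate grid sharding:

num_sharding_types = 7                      # len(ShardingType) with GRID_SHARD added
num_pruned_options = (num_sharding_types - 2) * 3 + 1
assert num_pruned_options == 16             # same value the docstring derives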