diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 090544a48..0da8df7d8 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -43,7 +43,7 @@ ) from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags from torch import nn -from torchrec.distributed.comm import get_local_rank +from torchrec.distributed.comm import get_local_rank, get_local_size from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, ) @@ -215,29 +215,33 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata( table_global_metadata: ShardedTensorMetadata, optimizer_state: torch.Tensor, sharding_dim: int, + is_grid_sharded: bool = False, ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: - table_global_shards_metadata: List[ShardMetadata] = ( table_global_metadata.shards_metadata ) - # column-wise sharding - # sort the metadata based on column offset and - # we construct the momentum tensor in row-wise sharded way if sharding_dim == 1: + # column-wise sharding + # sort the metadata based on column offset and + # we construct the momentum tensor in row-wise sharded way table_global_shards_metadata = sorted( table_global_shards_metadata, key=lambda shard: shard.shard_offsets[1], ) table_shard_metadata_to_optimizer_shard_metadata = {} - + rolling_offset = 0 for idx, table_shard_metadata in enumerate(table_global_shards_metadata): offset = table_shard_metadata.shard_offsets[0] - # for column-wise sharding, we still create row-wise sharded metadata for optimizer - # manually create a row-wise offset - if sharding_dim == 1: + if is_grid_sharded: + # we use a rolling offset to calculate the current offset for shard to account for uneven row wise case for our shards + offset = rolling_offset + rolling_offset += table_shard_metadata.shard_sizes[0] + elif sharding_dim == 1: + # for column-wise 
sharding, we still create row-wise sharded metadata for optimizer + # manually create a row-wise offset offset = idx * table_shard_metadata.shard_sizes[0] table_shard_metadata_to_optimizer_shard_metadata[ @@ -255,14 +259,22 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata( ) len_rw_shards = ( len(table_shard_metadata_to_optimizer_shard_metadata) - if sharding_dim == 1 + if sharding_dim == 1 and not is_grid_sharded + else 1 + ) + # for grid sharding, the row dimension is replicated CW shard times + grid_shard_nodes = ( + len(table_global_shards_metadata) // get_local_size() + if is_grid_sharded else 1 ) rowwise_optimizer_st_metadata = ShardedTensorMetadata( shards_metadata=list( table_shard_metadata_to_optimizer_shard_metadata.values() ), - size=torch.Size([table_global_metadata.size[0] * len_rw_shards]), + size=torch.Size( + [table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes] + ), tensor_properties=tensor_properties, ) @@ -324,7 +336,6 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( all_optimizer_states = emb_module.get_optimizer_state() optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {} - for ( table_config, optimizer_states, @@ -408,6 +419,13 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( 1 if table_config.local_cols != table_config.embedding_dim else 0 ) + is_grid_sharded: bool = ( + True + if table_config.local_cols != table_config.embedding_dim + and table_config.local_rows != table_config.num_embeddings + else False + ) + if all( opt_state is not None for opt_state in shard_params.optimizer_states ): @@ -431,6 +449,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor: table_config.global_metadata, shard_params.optimizer_states[0][momentum_idx - 1], sharding_dim, + is_grid_sharded, ) else: ( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 31331a4e8..2e0c50067 100644 --- a/torchrec/distributed/embeddingbag.py +++ 
b/torchrec/distributed/embeddingbag.py @@ -47,6 +47,7 @@ ) from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding +from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding from torchrec.distributed.sharding.twcw_sharding import TwCwPooledEmbeddingSharding @@ -193,6 +194,13 @@ def create_embedding_bag_sharding( permute_embeddings=permute_embeddings, qcomm_codecs_registry=qcomm_codecs_registry, ) + elif sharding_type == ShardingType.GRID_SHARD.value: + return GridPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) else: raise ValueError(f"Sharding type not supported {sharding_type}") diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index c34e38efc..9d17b1ba9 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -351,7 +351,8 @@ def test_grid_search_three_table(self) -> None: So the total number of pruned options will be: (num_sharding_types - 1) * 3 + 1 = 16 """ - num_pruned_options = (len(ShardingType) - 1) * 3 + 1 + # NOTE: the extra -1 accounts for GRID_SHARD, which the planner does not enumerate yet; change -2 back to -1 once grid sharding is added to the planner + num_pruned_options = (len(ShardingType) - 2) * 3 + 1 self.grid_search_proposer.load(search_space) for ( sharding_options diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py new file mode 100644 index 000000000..2b57371d5 --- /dev/null +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist +from fbgemm_gpu.permute_pooled_embedding_modules_split import ( + PermutePooledEmbeddingsSplit, +) +from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torchrec.distributed.dist_data import ( + PooledEmbeddingsAllToAll, + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsReduceScatter, +) +from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torchrec.distributed.embedding_sharding import ( + BaseEmbeddingDist, + BaseEmbeddingLookup, + BaseSparseFeaturesDist, + EmbeddingSharding, + EmbeddingShardingContext, + EmbeddingShardingInfo, + group_tables, +) +from torchrec.distributed.embedding_types import ( + BaseGroupedFeatureProcessor, + EmbeddingComputeKernel, + GroupedEmbeddingConfig, + ShardedEmbeddingTable, +) +from torchrec.distributed.sharding.twrw_sharding import TwRwSparseFeaturesDist +from torchrec.distributed.types import ( + Awaitable, + CommOp, + QuantizedCommCodecs, + ShardedTensorMetadata, + ShardingEnv, + ShardingType, + ShardMetadata, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Multistreamable + +C = TypeVar("C", bound=Multistreamable) +F = TypeVar("F", bound=Multistreamable) +T = TypeVar("T") +W = TypeVar("W") + + +class BaseGridEmbeddingSharding(EmbeddingSharding[C, F, T, W]): + """ + Base class for grid sharding. 
+ """ + + def __init__( + self, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + need_pos: bool = False, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._env = env + self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._need_pos = need_pos + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._embedding_order: List[int] = [] + + self._combined_embedding_names: List[str] = [] + self._combined_embedding_dims: List[int] = [] + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) + self._intra_pg: Optional[dist.ProcessGroup] = intra_pg + self._cross_pg: Optional[dist.ProcessGroup] = cross_pg + self._local_size: int = ( + intra_pg.size() if intra_pg else get_local_size(self._world_size) + ) + + sharded_tables_per_rank = self._shard(sharding_infos) + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_node: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs_per_node = [ + self._grouped_embedding_configs_per_rank[rank] + for rank in range(self._world_size) + if rank % self._local_size == 0 + ] + self._has_feature_processor: bool = False + for group_config in self._grouped_embedding_configs_per_rank[ + self._rank // self._local_size + ]: + if group_config.has_feature_processor: + self._has_feature_processor = True + + self._init_combined_embeddings() + + def _init_combined_embeddings(self) -> None: + """ + similar to CW sharding, but this time each CW shard is on a node and not rank + """ + embedding_names = [] + for 
grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + + embedding_dims = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + + embedding_shard_metadata = self.embedding_shard_metadata() + + embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {} + for i, (name, metadata) in enumerate( + zip(embedding_names, embedding_shard_metadata) + ): + if name not in embedding_name_to_index_offset_tuples: + embedding_name_to_index_offset_tuples[name] = [] + # find index of each of the offset by column (CW sharding so only col dim changes) + embedding_name_to_index_offset_tuples[name].append( + (i, metadata.shard_offsets[1] if metadata is not None else 0) + ) + + # sort the index offset tuples by offset and then grab the associated index + embedding_name_to_index: Dict[str, List[int]] = {} + for name, index_offset_tuples in embedding_name_to_index_offset_tuples.items(): + embedding_name_to_index[name] = [ + idx_off_tuple[0] + for idx_off_tuple in sorted( + index_offset_tuples, + key=lambda idx_off_tuple: idx_off_tuple[1], + ) + ] + + combined_embedding_names: List[str] = [] + seen_embedding_names: Set[str] = set() + + for name in embedding_names: + if name not in seen_embedding_names: + combined_embedding_names.append(name) + seen_embedding_names.add(name) + + combined_embedding_dims: List[int] = [] + + embedding_order: List[int] = [] + for name in combined_embedding_names: + combined_embedding_dims.append( + sum([embedding_dims[idx] for idx in embedding_name_to_index[name]]) + ) + embedding_order.extend(embedding_name_to_index[name]) + + self._embedding_names: List[str] = embedding_names + self._embedding_dims: List[int] = embedding_dims + self._embedding_order: List[int] = embedding_order + + 
self._combined_embedding_names: List[str] = combined_embedding_names + self._combined_embedding_dims: List[int] = combined_embedding_dims + + def _shard( + self, + sharding_infos: List[EmbeddingShardingInfo], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._world_size + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for info in sharding_infos: + # pyre-fixme [16] + shards = info.param_sharding.sharding_spec.shards + + # construct the global sharded_tensor_metadata + global_metadata = ShardedTensorMetadata( + shards_metadata=shards, + size=torch.Size( + [ + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ] + ), + ) + + # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards + # pyre-fixme [6] + for i, rank in enumerate(info.param_sharding.ranks): + tables_per_rank[rank].append( + ShardedEmbeddingTable( + num_embeddings=info.embedding_config.num_embeddings, + embedding_dim=info.embedding_config.embedding_dim, + name=info.embedding_config.name, + embedding_names=info.embedding_config.embedding_names, + data_type=info.embedding_config.data_type, + feature_names=info.embedding_config.feature_names, + pooling=info.embedding_config.pooling, + is_weighted=info.embedding_config.is_weighted, + has_feature_processor=info.embedding_config.has_feature_processor, + # sharding by row and col + local_rows=shards[i].shard_sizes[0], + local_cols=shards[i].shard_sizes[1], + compute_kernel=EmbeddingComputeKernel( + info.param_sharding.compute_kernel + ), + local_metadata=shards[i], + global_metadata=global_metadata, + weight_init_max=info.embedding_config.weight_init_max, + weight_init_min=info.embedding_config.weight_init_min, + fused_params=info.fused_params, + ) + ) + + return tables_per_rank + + def embedding_dims(self) -> List[int]: + return self._combined_embedding_dims + + def embedding_names(self) -> List[str]: + return 
self._combined_embedding_names + + def embedding_names_per_rank(self) -> List[List[str]]: + raise NotImplementedError + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + embedding_shard_metadata.extend(config.embedding_shard_metadata()) + return embedding_shard_metadata + + def feature_names(self) -> List[str]: + feature_names = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_names.extend(config.feature_names()) + return feature_names + + def _get_feature_hash_sizes(self) -> List[int]: + feature_hash_sizes: List[int] = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_hash_sizes.extend(config.feature_hash_sizes()) + return feature_hash_sizes + + def _dim_sum_per_node(self) -> List[int]: + dim_sum_per_node = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + dim_sum = 0 + for grouped_config in grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + dim_sum_per_node.append(dim_sum) + return dim_sum_per_node + + def _emb_dim_per_node_per_feature(self) -> List[List[int]]: + emb_dim_per_node_per_feature = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + emb_dim_per_feature = [] + for grouped_config in grouped_embedding_configs: + emb_dim_per_feature += grouped_config.embedding_dims() + emb_dim_per_node_per_feature.append(emb_dim_per_feature) + return emb_dim_per_node_per_feature + + def _features_per_rank( + self, group: List[List[GroupedEmbeddingConfig]] + ) -> List[int]: + features_per_rank = [] + for grouped_embedding_configs in group: + num_features = 0 + for grouped_config in grouped_embedding_configs: + num_features += grouped_config.num_features() + features_per_rank.append(num_features) + return features_per_rank + + 
+class GridPooledEmbeddingDist( + BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor] +): + def __init__( + self, + rank: int, + cross_pg: dist.ProcessGroup, + intra_pg: dist.ProcessGroup, + dim_sum_per_node: List[int], + emb_dim_per_node_per_feature: List[List[int]], + device: Optional[torch.device] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, + ) -> None: + super().__init__() + self._rank = rank + self._intra_pg: dist.ProcessGroup = intra_pg + self._cross_pg: dist.ProcessGroup = cross_pg + self._dim_sum_per_node = dim_sum_per_node + self._emb_dim_per_node_per_feature = emb_dim_per_node_per_feature + self._device = device + self._intra_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None + ) + if qcomm_codecs_registry + else None + ) + self._cross_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get(CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None) + if qcomm_codecs_registry + else None + ) + self._intra_dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._cross_dist: Optional[ + Union[ + PooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, + ] + ] = None + self._callbacks = callbacks + + def forward( + self, + local_embs: torch.Tensor, + sharding_ctx: Optional[EmbeddingShardingContext] = None, + ) -> Awaitable[torch.Tensor]: + """ + Performs reduce-scatter pooled operation on pooled embeddings tensor followed by + AlltoAll pooled operation. + + Args: + local_embs (torch.Tensor): pooled embeddings tensor to distribute. + + Returns: + Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. 
+ """ + if self._intra_dist is None or self._cross_dist is None: + self._create_output_dist_modules(sharding_ctx) + local_rank = self._rank % self._intra_pg.size() + if sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1: + # preprocess batch_size_per_rank + ( + batch_size_per_rank_by_cross_group, + batch_size_sum_by_cross_group, + ) = self._preprocess_batch_size_per_rank( + self._intra_pg.size(), + self._cross_pg.size(), + sharding_ctx.batch_size_per_rank, + ) + # Perform ReduceScatterV within one host + rs_result = cast(PooledEmbeddingsReduceScatter, self._intra_dist)( + local_embs, input_splits=batch_size_sum_by_cross_group + ).wait() + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + rs_result, + batch_size_per_rank=batch_size_per_rank_by_cross_group[local_rank], + ) + else: + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + cast(PooledEmbeddingsReduceScatter, self._intra_dist)(local_embs).wait() + ) + + def _preprocess_batch_size_per_rank( + self, local_size: int, nodes: int, batch_size_per_rank: List[int] + ) -> Tuple[List[List[int]], List[int]]: + """ + Reorders `batch_size_per_rank` so it's aligned with reordered features after + AlltoAll. 
+ """ + batch_size_per_rank_by_cross_group: List[List[int]] = [] + batch_size_sum_by_cross_group: List[int] = [] + for local_rank in range(local_size): + batch_size_per_rank_: List[int] = [] + batch_size_sum = 0 + for node in range(nodes): + batch_size_per_rank_.append( + batch_size_per_rank[local_rank + node * local_size] + ) + batch_size_sum += batch_size_per_rank[local_rank + node * local_size] + batch_size_per_rank_by_cross_group.append(batch_size_per_rank_) + batch_size_sum_by_cross_group.append(batch_size_sum) + + return batch_size_per_rank_by_cross_group, batch_size_sum_by_cross_group + + def _create_output_dist_modules( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + self._intra_dist = PooledEmbeddingsReduceScatter( + pg=self._intra_pg, + codecs=self._intra_codecs, + ) + self._cross_dist = PooledEmbeddingsAllToAll( + pg=self._cross_pg, + dim_sum_per_rank=self._dim_sum_per_node, + device=self._device, + codecs=self._cross_codecs, + callbacks=self._callbacks, + ) + + +class GridPooledEmbeddingSharding( + BaseGridEmbeddingSharding[ + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ] +): + """ + Shards embedding bags in a grid: each column-wise shard lives on one node and is split row-wise across that node's ranks. 
+ """ + + def create_input_dist( + self, device: Optional[torch.device] = None + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + features_per_rank = self._features_per_rank( + self._grouped_embedding_configs_per_rank + ) + feature_hash_sizes = self._get_feature_hash_sizes() + assert self._pg is not None + assert self._intra_pg is not None + return TwRwSparseFeaturesDist( + pg=self._pg, + local_size=self._intra_pg.size(), + features_per_rank=features_per_rank, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + has_feature_processor=self._has_feature_processor, + need_pos=self._need_pos, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs_per_rank[self._rank], + pg=self._pg, + device=device if device is not None else self._device, + feature_processor=feature_processor, + sharding_type=ShardingType.TABLE_ROW_WISE, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]: + embedding_permute_op: Optional[PermutePooledEmbeddingsSplit] = None + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None + if self._embedding_order != list(range(len(self._embedding_order))): + assert len(self._embedding_order) == len(self._embedding_dims) + embedding_permute_op = PermutePooledEmbeddingsSplit( + self._embedding_dims, self._embedding_order, device=self._device + ) + callbacks = [embedding_permute_op] + return GridPooledEmbeddingDist( + rank=self._rank, + cross_pg=cast(dist.ProcessGroup, self._cross_pg), + intra_pg=cast(dist.ProcessGroup, self._intra_pg), + dim_sum_per_node=self._dim_sum_per_node(), + 
emb_dim_per_node_per_feature=self._emb_dim_per_node_per_feature(), + device=device if device is not None else self._device, + qcomm_codecs_registry=self.qcomm_codecs_registry, + callbacks=callbacks, + ) diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index dcc31a9a6..c9bc22ab3 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -662,3 +662,153 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None: variable_batch_per_feature=True, has_weighted_tables=False, ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=4, + local_size=2, + backend=self.backend, + qcomms_config=qcomms_config, + 
constraints={ + "table_0": ParameterConstraints(min_partition=8), + "table_1": ParameterConstraints(min_partition=12), + "table_2": ParameterConstraints(min_partition=16), + "table_3": ParameterConstraints(min_partition=20), + "table_4": ParameterConstraints(min_partition=8), + "table_5": ParameterConstraints(min_partition=12), + "weighted_table_0": ParameterConstraints(min_partition=8), + "weighted_table_1": ParameterConstraints(min_partition=12), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_8gpu( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=8, + local_size=2, + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints(min_partition=8), + "table_1": 
ParameterConstraints(min_partition=12), + "table_2": ParameterConstraints(min_partition=8), + "table_3": ParameterConstraints(min_partition=10), + "table_4": ParameterConstraints(min_partition=4), + "table_5": ParameterConstraints(min_partition=6), + "weighted_table_0": ParameterConstraints(min_partition=2), + "weighted_table_1": ParameterConstraints(min_partition=3), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index a5edb56f7..02fafafeb 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -369,6 +369,7 @@ def sharding_single_rank_test( in { ShardingType.TABLE_ROW_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.GRID_SHARD.value, } and ctx.device.type != "cpu" ): diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 19b48f5f8..a73deab0c 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -135,6 +135,8 @@ class ShardingType(Enum): TABLE_ROW_WISE = "table_row_wise" # Column-wise on the same node and table-wise across nodes TABLE_COLUMN_WISE = "table_column_wise" + # Grid sharding, where each rank gets a subset of columns and rows in a CW and TWRW style + GRID_SHARD = "grid_shard" class PipelineType(Enum):