From 664ad30fb91c8dc814bc2a8e0b94feb339b90013 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Mon, 30 Sep 2024 11:29:20 -0700 Subject: [PATCH] grid based sharding for EBC (#2445) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2445 ## Introducing Grid based sharding for EBC This is a form of CW sharding and then TWRW sharding the respective CW shards. One of the key changes is how the metadata from sharding placements is constructed in grid sharding. We leverage the concept of `per_node` from TWRW and combine it with the permutations and concatenation required in CW. #### Comms Design Combining optimizations from both CW and TWRW to have highly performant code. Sparse feature distribution is done akin to the TWRW process where the input KJT is bucketized and permuted (according to the stagger indices) and AlltoAll to the correponding ranks that require their part of the input. Similarly for the embedding all to all, we first all reduce within the node to get the embedding lookup for the CW shard on that node. We're able to leverage the intra node comms for this stage. Then the reduce scatter, just like TWRW, where the embeddings are split to each rank to hold it's and its corresponding cross group ranks embeddings which are then shared in a AlltoAll call. Lastly, the `PermutePooledEmbeddingsSplit` callback is called to rearrange the embedding lookup appropriately (cat the CW lookups in the right order). #### Optimizer Sharding Fused optimizer sharding is also updated, we needed to fix how row wise optimizer states are constructed since optimizers for CW shards are row wise sharded. For grid sharding, this doesn't work since the row wise shards are repeated for each CW shard as well as we can encounter the uneven row wise case which is not possible in CW sharding. For grid shards the approach is to use a rolling offset from the previous shard which solves for both uneven row wise shards and the repeated CW shards. NOTE: Bypasses added in planner to pass CI, which are to be removed in forthcoming diff Reviewed By: dstaay-fb Differential Revision: D62594442 --- .../distributed/batched_embedding_kernel.py | 43 +- torchrec/distributed/embeddingbag.py | 8 + .../planner/tests/test_proposers.py | 3 +- .../distributed/sharding/grid_sharding.py | 480 ++++++++++++++++++ .../test_utils/test_model_parallel.py | 150 ++++++ .../distributed/test_utils/test_sharding.py | 1 + torchrec/distributed/types.py | 2 + 7 files changed, 674 insertions(+), 13 deletions(-) create mode 100644 torchrec/distributed/sharding/grid_sharding.py diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 090544a48..0da8df7d8 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -43,7 +43,7 @@ ) from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags from torch import nn -from torchrec.distributed.comm import get_local_rank +from torchrec.distributed.comm import get_local_rank, get_local_size from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, ) @@ -215,29 +215,33 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata( table_global_metadata: ShardedTensorMetadata, optimizer_state: torch.Tensor, sharding_dim: int, + is_grid_sharded: bool = False, ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: - table_global_shards_metadata: List[ShardMetadata] = ( table_global_metadata.shards_metadata ) - # column-wise sharding - # sort the metadata based on column offset and - # we construct the momentum tensor in row-wise sharded way if sharding_dim == 1: + # column-wise sharding + # sort the metadata based on column offset and + # we construct the momentum tensor in row-wise sharded way table_global_shards_metadata = sorted( table_global_shards_metadata, key=lambda shard: shard.shard_offsets[1], ) table_shard_metadata_to_optimizer_shard_metadata = {} - + rolling_offset = 0 for idx, table_shard_metadata in enumerate(table_global_shards_metadata): offset = table_shard_metadata.shard_offsets[0] - # for column-wise sharding, we still create row-wise sharded metadata for optimizer - # manually create a row-wise offset - if sharding_dim == 1: + if is_grid_sharded: + # we use a rolling offset to calculate the current offset for shard to account for uneven row wise case for our shards + offset = rolling_offset + rolling_offset += table_shard_metadata.shard_sizes[0] + elif sharding_dim == 1: + # for column-wise sharding, we still create row-wise sharded metadata for optimizer + # manually create a row-wise offset offset = idx * table_shard_metadata.shard_sizes[0] table_shard_metadata_to_optimizer_shard_metadata[ @@ -255,14 +259,22 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata( ) len_rw_shards = ( len(table_shard_metadata_to_optimizer_shard_metadata) - if sharding_dim == 1 + if sharding_dim == 1 and not is_grid_sharded + else 1 + ) + # for grid sharding, the row dimension is replicated CW shard times + grid_shard_nodes = ( + len(table_global_shards_metadata) // get_local_size() + if is_grid_sharded else 1 ) rowwise_optimizer_st_metadata = ShardedTensorMetadata( shards_metadata=list( table_shard_metadata_to_optimizer_shard_metadata.values() ), - size=torch.Size([table_global_metadata.size[0] * len_rw_shards]), + size=torch.Size( + [table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes] + ), tensor_properties=tensor_properties, ) @@ -324,7 +336,6 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( all_optimizer_states = emb_module.get_optimizer_state() optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {} - for ( table_config, optimizer_states, @@ -408,6 +419,13 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( 1 if table_config.local_cols != table_config.embedding_dim else 0 ) + is_grid_sharded: bool = ( + True + if table_config.local_cols != table_config.embedding_dim + and table_config.local_rows != table_config.num_embeddings + else False + ) + if all( opt_state is not None for opt_state in shard_params.optimizer_states ): @@ -431,6 +449,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor: table_config.global_metadata, shard_params.optimizer_states[0][momentum_idx - 1], sharding_dim, + is_grid_sharded, ) else: ( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 31331a4e8..2e0c50067 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -47,6 +47,7 @@ ) from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding +from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding from torchrec.distributed.sharding.twcw_sharding import TwCwPooledEmbeddingSharding @@ -193,6 +194,13 @@ def create_embedding_bag_sharding( permute_embeddings=permute_embeddings, qcomm_codecs_registry=qcomm_codecs_registry, ) + elif sharding_type == ShardingType.GRID_SHARD.value: + return GridPooledEmbeddingSharding( + sharding_infos, + env, + device, + qcomm_codecs_registry=qcomm_codecs_registry, + ) else: raise ValueError(f"Sharding type not supported {sharding_type}") diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index c34e38efc..9d17b1ba9 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -351,7 +351,8 @@ def test_grid_search_three_table(self) -> None: So the total number of pruned options will be: (num_sharding_types - 1) * 3 + 1 = 16 """ - num_pruned_options = (len(ShardingType) - 1) * 3 + 1 + # NOTE - remove -2 from sharding type length once grid sharding in planner is added + num_pruned_options = (len(ShardingType) - 2) * 3 + 1 self.grid_search_proposer.load(search_space) for ( sharding_options diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py new file mode 100644 index 000000000..2b57371d5 --- /dev/null +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist +from fbgemm_gpu.permute_pooled_embedding_modules_split import ( + PermutePooledEmbeddingsSplit, +) +from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torchrec.distributed.dist_data import ( + PooledEmbeddingsAllToAll, + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsReduceScatter, +) +from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torchrec.distributed.embedding_sharding import ( + BaseEmbeddingDist, + BaseEmbeddingLookup, + BaseSparseFeaturesDist, + EmbeddingSharding, + EmbeddingShardingContext, + EmbeddingShardingInfo, + group_tables, +) +from torchrec.distributed.embedding_types import ( + BaseGroupedFeatureProcessor, + EmbeddingComputeKernel, + GroupedEmbeddingConfig, + ShardedEmbeddingTable, +) +from torchrec.distributed.sharding.twrw_sharding import TwRwSparseFeaturesDist +from torchrec.distributed.types import ( + Awaitable, + CommOp, + QuantizedCommCodecs, + ShardedTensorMetadata, + ShardingEnv, + ShardingType, + ShardMetadata, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Multistreamable + +C = TypeVar("C", bound=Multistreamable) +F = TypeVar("F", bound=Multistreamable) +T = TypeVar("T") +W = TypeVar("W") + + +class BaseGridEmbeddingSharding(EmbeddingSharding[C, F, T, W]): + """ + Base class for grid sharding. + """ + + def __init__( + self, + sharding_infos: List[EmbeddingShardingInfo], + env: ShardingEnv, + device: Optional[torch.device] = None, + need_pos: bool = False, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._env = env + self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._world_size: int = self._env.world_size + self._rank: int = self._env.rank + self._device = device + self._need_pos = need_pos + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._embedding_order: List[int] = [] + + self._combined_embedding_names: List[str] = [] + self._combined_embedding_dims: List[int] = [] + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) + self._intra_pg: Optional[dist.ProcessGroup] = intra_pg + self._cross_pg: Optional[dist.ProcessGroup] = cross_pg + self._local_size: int = ( + intra_pg.size() if intra_pg else get_local_size(self._world_size) + ) + + sharded_tables_per_rank = self._shard(sharding_infos) + self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_node: List[List[GroupedEmbeddingConfig]] = ( + [] + ) + self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs_per_node = [ + self._grouped_embedding_configs_per_rank[rank] + for rank in range(self._world_size) + if rank % self._local_size == 0 + ] + self._has_feature_processor: bool = False + for group_config in self._grouped_embedding_configs_per_rank[ + self._rank // self._local_size + ]: + if group_config.has_feature_processor: + self._has_feature_processor = True + + self._init_combined_embeddings() + + def _init_combined_embeddings(self) -> None: + """ + similar to CW sharding, but this time each CW shard is on a node and not rank + """ + embedding_names = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + + embedding_dims = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + for grouped_config in grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + + embedding_shard_metadata = self.embedding_shard_metadata() + + embedding_name_to_index_offset_tuples: Dict[str, List[Tuple[int, int]]] = {} + for i, (name, metadata) in enumerate( + zip(embedding_names, embedding_shard_metadata) + ): + if name not in embedding_name_to_index_offset_tuples: + embedding_name_to_index_offset_tuples[name] = [] + # find index of each of the offset by column (CW sharding so only col dim changes) + embedding_name_to_index_offset_tuples[name].append( + (i, metadata.shard_offsets[1] if metadata is not None else 0) + ) + + # sort the index offset tuples by offset and then grab the associated index + embedding_name_to_index: Dict[str, List[int]] = {} + for name, index_offset_tuples in embedding_name_to_index_offset_tuples.items(): + embedding_name_to_index[name] = [ + idx_off_tuple[0] + for idx_off_tuple in sorted( + index_offset_tuples, + key=lambda idx_off_tuple: idx_off_tuple[1], + ) + ] + + combined_embedding_names: List[str] = [] + seen_embedding_names: Set[str] = set() + + for name in embedding_names: + if name not in seen_embedding_names: + combined_embedding_names.append(name) + seen_embedding_names.add(name) + + combined_embedding_dims: List[int] = [] + + embedding_order: List[int] = [] + for name in combined_embedding_names: + combined_embedding_dims.append( + sum([embedding_dims[idx] for idx in embedding_name_to_index[name]]) + ) + embedding_order.extend(embedding_name_to_index[name]) + + self._embedding_names: List[str] = embedding_names + self._embedding_dims: List[int] = embedding_dims + self._embedding_order: List[int] = embedding_order + + self._combined_embedding_names: List[str] = combined_embedding_names + self._combined_embedding_dims: List[int] = combined_embedding_dims + + def _shard( + self, + sharding_infos: List[EmbeddingShardingInfo], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._world_size + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for info in sharding_infos: + # pyre-fixme [16] + shards = info.param_sharding.sharding_spec.shards + + # construct the global sharded_tensor_metadata + global_metadata = ShardedTensorMetadata( + shards_metadata=shards, + size=torch.Size( + [ + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ] + ), + ) + + # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards + # pyre-fixme [6] + for i, rank in enumerate(info.param_sharding.ranks): + tables_per_rank[rank].append( + ShardedEmbeddingTable( + num_embeddings=info.embedding_config.num_embeddings, + embedding_dim=info.embedding_config.embedding_dim, + name=info.embedding_config.name, + embedding_names=info.embedding_config.embedding_names, + data_type=info.embedding_config.data_type, + feature_names=info.embedding_config.feature_names, + pooling=info.embedding_config.pooling, + is_weighted=info.embedding_config.is_weighted, + has_feature_processor=info.embedding_config.has_feature_processor, + # sharding by row and col + local_rows=shards[i].shard_sizes[0], + local_cols=shards[i].shard_sizes[1], + compute_kernel=EmbeddingComputeKernel( + info.param_sharding.compute_kernel + ), + local_metadata=shards[i], + global_metadata=global_metadata, + weight_init_max=info.embedding_config.weight_init_max, + weight_init_min=info.embedding_config.weight_init_min, + fused_params=info.fused_params, + ) + ) + + return tables_per_rank + + def embedding_dims(self) -> List[int]: + return self._combined_embedding_dims + + def embedding_names(self) -> List[str]: + return self._combined_embedding_names + + def embedding_names_per_rank(self) -> List[List[str]]: + raise NotImplementedError + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + embedding_shard_metadata.extend(config.embedding_shard_metadata()) + return embedding_shard_metadata + + def feature_names(self) -> List[str]: + feature_names = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_names.extend(config.feature_names()) + return feature_names + + def _get_feature_hash_sizes(self) -> List[int]: + feature_hash_sizes: List[int] = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + feature_hash_sizes.extend(config.feature_hash_sizes()) + return feature_hash_sizes + + def _dim_sum_per_node(self) -> List[int]: + dim_sum_per_node = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + dim_sum = 0 + for grouped_config in grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + dim_sum_per_node.append(dim_sum) + return dim_sum_per_node + + def _emb_dim_per_node_per_feature(self) -> List[List[int]]: + emb_dim_per_node_per_feature = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_node: + emb_dim_per_feature = [] + for grouped_config in grouped_embedding_configs: + emb_dim_per_feature += grouped_config.embedding_dims() + emb_dim_per_node_per_feature.append(emb_dim_per_feature) + return emb_dim_per_node_per_feature + + def _features_per_rank( + self, group: List[List[GroupedEmbeddingConfig]] + ) -> List[int]: + features_per_rank = [] + for grouped_embedding_configs in group: + num_features = 0 + for grouped_config in grouped_embedding_configs: + num_features += grouped_config.num_features() + features_per_rank.append(num_features) + return features_per_rank + + +class GridPooledEmbeddingDist( + BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor] +): + def __init__( + self, + rank: int, + cross_pg: dist.ProcessGroup, + intra_pg: dist.ProcessGroup, + dim_sum_per_node: List[int], + emb_dim_per_node_per_feature: List[List[int]], + device: Optional[torch.device] = None, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None, + ) -> None: + super().__init__() + self._rank = rank + self._intra_pg: dist.ProcessGroup = intra_pg + self._cross_pg: dist.ProcessGroup = cross_pg + self._dim_sum_per_node = dim_sum_per_node + self._emb_dim_per_node_per_feature = emb_dim_per_node_per_feature + self._device = device + self._intra_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get( + CommOp.POOLED_EMBEDDINGS_REDUCE_SCATTER.name, None + ) + if qcomm_codecs_registry + else None + ) + self._cross_codecs: Optional[QuantizedCommCodecs] = ( + qcomm_codecs_registry.get(CommOp.POOLED_EMBEDDINGS_ALL_TO_ALL.name, None) + if qcomm_codecs_registry + else None + ) + self._intra_dist: Optional[ + Union[ + PooledEmbeddingsReduceScatter, + VariableBatchPooledEmbeddingsReduceScatter, + ] + ] = None + self._cross_dist: Optional[ + Union[ + PooledEmbeddingsAllToAll, + VariableBatchPooledEmbeddingsAllToAll, + ] + ] = None + self._callbacks = callbacks + + def forward( + self, + local_embs: torch.Tensor, + sharding_ctx: Optional[EmbeddingShardingContext] = None, + ) -> Awaitable[torch.Tensor]: + """ + Performs reduce-scatter pooled operation on pooled embeddings tensor followed by + AlltoAll pooled operation. + + Args: + local_embs (torch.Tensor): pooled embeddings tensor to distribute. + + Returns: + Awaitable[torch.Tensor]: awaitable of pooled embeddings tensor. + """ + if self._intra_dist is None or self._cross_dist is None: + self._create_output_dist_modules(sharding_ctx) + local_rank = self._rank % self._intra_pg.size() + if sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1: + # preprocess batch_size_per_rank + ( + batch_size_per_rank_by_cross_group, + batch_size_sum_by_cross_group, + ) = self._preprocess_batch_size_per_rank( + self._intra_pg.size(), + self._cross_pg.size(), + sharding_ctx.batch_size_per_rank, + ) + # Perform ReduceScatterV within one host + rs_result = cast(PooledEmbeddingsReduceScatter, self._intra_dist)( + local_embs, input_splits=batch_size_sum_by_cross_group + ).wait() + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + rs_result, + batch_size_per_rank=batch_size_per_rank_by_cross_group[local_rank], + ) + else: + return cast(PooledEmbeddingsAllToAll, self._cross_dist)( + cast(PooledEmbeddingsReduceScatter, self._intra_dist)(local_embs).wait() + ) + + def _preprocess_batch_size_per_rank( + self, local_size: int, nodes: int, batch_size_per_rank: List[int] + ) -> Tuple[List[List[int]], List[int]]: + """ + Reorders `batch_size_per_rank` so it's aligned with reordered features after + AlltoAll. + """ + batch_size_per_rank_by_cross_group: List[List[int]] = [] + batch_size_sum_by_cross_group: List[int] = [] + for local_rank in range(local_size): + batch_size_per_rank_: List[int] = [] + batch_size_sum = 0 + for node in range(nodes): + batch_size_per_rank_.append( + batch_size_per_rank[local_rank + node * local_size] + ) + batch_size_sum += batch_size_per_rank[local_rank + node * local_size] + batch_size_per_rank_by_cross_group.append(batch_size_per_rank_) + batch_size_sum_by_cross_group.append(batch_size_sum) + + return batch_size_per_rank_by_cross_group, batch_size_sum_by_cross_group + + def _create_output_dist_modules( + self, sharding_ctx: Optional[EmbeddingShardingContext] = None + ) -> None: + self._intra_dist = PooledEmbeddingsReduceScatter( + pg=self._intra_pg, + codecs=self._intra_codecs, + ) + self._cross_dist = PooledEmbeddingsAllToAll( + pg=self._cross_pg, + dim_sum_per_rank=self._dim_sum_per_node, + device=self._device, + codecs=self._cross_codecs, + callbacks=self._callbacks, + ) + + +class GridPooledEmbeddingSharding( + BaseGridEmbeddingSharding[ + EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor + ] +): + """ + Shards embedding bags table-wise then row-wise. + """ + + def create_input_dist( + self, device: Optional[torch.device] = None + ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: + features_per_rank = self._features_per_rank( + self._grouped_embedding_configs_per_rank + ) + feature_hash_sizes = self._get_feature_hash_sizes() + assert self._pg is not None + assert self._intra_pg is not None + return TwRwSparseFeaturesDist( + pg=self._pg, + local_size=self._intra_pg.size(), + features_per_rank=features_per_rank, + feature_hash_sizes=feature_hash_sizes, + device=device if device is not None else self._device, + has_feature_processor=self._has_feature_processor, + need_pos=self._need_pos, + ) + + def create_lookup( + self, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs_per_rank[self._rank], + pg=self._pg, + device=device if device is not None else self._device, + feature_processor=feature_processor, + sharding_type=ShardingType.TABLE_ROW_WISE, + ) + + def create_output_dist( + self, + device: Optional[torch.device] = None, + ) -> BaseEmbeddingDist[EmbeddingShardingContext, torch.Tensor, torch.Tensor]: + embedding_permute_op: Optional[PermutePooledEmbeddingsSplit] = None + callbacks: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None + if self._embedding_order != list(range(len(self._embedding_order))): + assert len(self._embedding_order) == len(self._embedding_dims) + embedding_permute_op = PermutePooledEmbeddingsSplit( + self._embedding_dims, self._embedding_order, device=self._device + ) + callbacks = [embedding_permute_op] + return GridPooledEmbeddingDist( + rank=self._rank, + cross_pg=cast(dist.ProcessGroup, self._cross_pg), + intra_pg=cast(dist.ProcessGroup, self._intra_pg), + dim_sum_per_node=self._dim_sum_per_node(), + emb_dim_per_node_per_feature=self._emb_dim_per_node_per_feature(), + device=device if device is not None else self._device, + qcomm_codecs_registry=self.qcomm_codecs_registry, + callbacks=callbacks, + ) diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index dcc31a9a6..c9bc22ab3 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -662,3 +662,153 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None: variable_batch_per_feature=True, has_weighted_tables=False, ) + + @unittest.skipIf( + torch.cuda.device_count() <= 3, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=4, + local_size=2, + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints(min_partition=8), + "table_1": ParameterConstraints(min_partition=12), + "table_2": ParameterConstraints(min_partition=16), + "table_3": ParameterConstraints(min_partition=20), + "table_4": ParameterConstraints(min_partition=8), + "table_5": ParameterConstraints(min_partition=12), + "weighted_table_0": ParameterConstraints(min_partition=8), + "weighted_table_1": ParameterConstraints(min_partition=12), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_8gpu( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder( + sharder_type, + ShardingType.GRID_SHARD.value, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ], + world_size=8, + local_size=2, + backend=self.backend, + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints(min_partition=8), + "table_1": ParameterConstraints(min_partition=12), + "table_2": ParameterConstraints(min_partition=8), + "table_3": ParameterConstraints(min_partition=10), + "table_4": ParameterConstraints(min_partition=4), + "table_5": ParameterConstraints(min_partition=6), + "weighted_table_0": ParameterConstraints(min_partition=2), + "weighted_table_1": ParameterConstraints(min_partition=3), + }, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index a5edb56f7..02fafafeb 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -369,6 +369,7 @@ def sharding_single_rank_test( in { ShardingType.TABLE_ROW_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.GRID_SHARD.value, } and ctx.device.type != "cpu" ): diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 19b48f5f8..a73deab0c 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -135,6 +135,8 @@ class ShardingType(Enum): TABLE_ROW_WISE = "table_row_wise" # Column-wise on the same node and table-wise across nodes TABLE_COLUMN_WISE = "table_column_wise" + # Grid sharding, where each rank gets a subset of columns and rows in a CW and TWRW style + GRID_SHARD = "grid_shard" class PipelineType(Enum):