diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index cbcd0d78b..95913a65a 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -456,16 +456,23 @@ def _get_sharded_local_buckets_for_zero_collision(
     for table in embedding_tables:
         total_num_buckets = none_throws(table.total_num_buckets)
-        assert (
-            total_num_buckets % world_size == 0
-        ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
         assert (
             table.total_num_buckets
             and table.num_embeddings % table.total_num_buckets == 0
         ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
-        bucket_offset_start = total_num_buckets // world_size * local_rank
+        extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
+        extra_bucket_padding = (
+            (total_num_buckets % world_size)
+            if local_rank >= (total_num_buckets % world_size)
+            else 0
+        )
+        bucket_offset_start = (
+            total_num_buckets // world_size + extra_local_buckets
+        ) * local_rank + extra_bucket_padding
         bucket_offset_end = min(
-            total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            total_num_buckets,
+            (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
+            + extra_bucket_padding,
         )
         bucket_size = (
             table.num_embeddings + total_num_buckets - 1
diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index 21776d697..e444f59c8 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -99,9 +99,13 @@ def create_virtual_table_global_metadata(
     # Otherwise it will only set correct size on current rank and
     # virtual PMT will trigger recalc for the correct global size/offset.
     # NOTE this currently only works for row-wise sharding
+    my_rank_shard_size = metadata.shards_metadata[my_rank].shard_sizes[0]
    for rank, shard_metadata in enumerate(metadata.shards_metadata):
        if use_param_size_as_rows:  # respect the param size and treat it as rows
-            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+            # The param size only carries the information for my_rank. To compute the
+            # correct size for the other ranks, scale it by the ratio of that rank's
+            # shard size to my_rank's shard size.
+            curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size  # pyre-ignore[16]
        else:
            curr_rank_rows = (
                weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 1e3abbfcb..202be6b71 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -38,6 +38,10 @@
     ShardingType,
 )
 from torchrec.modules.embedding_configs import DataType
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
 from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection
 
 
@@ -178,7 +182,7 @@ def enumerate(
                 # skip for other device groups
                 if device_group and device_group != self._compute_device:
                     continue
-
+                num_buckets = self._get_num_buckets(name, child_module)
                 sharding_options_per_table: List[ShardingOption] = []
                 for sharding_type in self._filter_sharding_types(
@@ -200,6 +204,7 @@ def enumerate(
                             sharding_type=sharding_type,
                             col_wise_shard_dim=col_wise_shard_dim,
                             device_memory_sizes=self._device_memory_sizes,
+                            num_buckets=num_buckets,
                         )
                     except ZeroDivisionError as e:
                         # Re-raise with additional context about the table and module
@@ -264,6 +269,33 @@ def enumerate(
         self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
+    def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]:
+        """
+        Get the number of buckets for an embedding table.
+
+        Args:
+            parameter (str): name of the embedding table.
+            module (nn.Module): module to be sharded.
+
+        Returns:
+            Optional[int]: number of buckets for the table, or None if the module is
+                neither an EmbeddingBagCollection nor an EmbeddingCollection, or if no
+                virtual table with that name is found.
+        """
+        # Only EmbeddingBagCollection and EmbeddingCollection expose embedding configs
+        if isinstance(module, EmbeddingBagCollection):
+            embedding_configs = module.embedding_bag_configs()
+        elif isinstance(module, EmbeddingCollection):
+            embedding_configs = module.embedding_configs()
+        else:
+            return None
+
+        # Find the embedding config for the virtual table whose name matches the parameter
+        for config in embedding_configs:
+            if config.name == parameter and config.use_virtual_table:
+                return config.total_num_buckets
+
+        # No virtual table with a matching name was found
+        return None
+
     @property
     def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
         # NOTE: This is the last search space stored by enumerate(...), do not use
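For reference, here is a minimal standalone sketch (not part of the patch; the helper name is hypothetical) of the uneven bucket split that the `extra_local_buckets` / `extra_bucket_padding` arithmetic in `_get_sharded_local_buckets_for_zero_collision` above produces when `total_num_buckets` is not divisible by `world_size`:

```python
from typing import List, Tuple


def split_buckets(total_num_buckets: int, world_size: int) -> List[Tuple[int, int]]:
    """Return (bucket_offset_start, bucket_offset_end) for every rank."""
    ranges = []
    for local_rank in range(world_size):
        # Ranks below the remainder get one extra bucket each.
        extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
        # Ranks at or above the remainder are shifted past all the extra buckets.
        extra_bucket_padding = (
            (total_num_buckets % world_size)
            if local_rank >= (total_num_buckets % world_size)
            else 0
        )
        start = (
            total_num_buckets // world_size + extra_local_buckets
        ) * local_rank + extra_bucket_padding
        end = min(
            total_num_buckets,
            (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
            + extra_bucket_padding,
        )
        ranges.append((start, end))
    return ranges


# 10 buckets over 4 ranks -> [(0, 3), (3, 6), (6, 8), (8, 10)]: 3 buckets on ranks 0-1
# and 2 buckets on ranks 2-3, the case the removed divisibility assert used to rule out.
print(split_buckets(10, 4))
```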
diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py
index 5adead69a..39a39d9f0 100644
--- a/torchrec/distributed/planner/tests/test_enumerators.py
+++ b/torchrec/distributed/planner/tests/test_enumerators.py
@@ -18,7 +18,10 @@
     EmbeddingTowerSharder,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embeddingbag import (
+    EmbeddingBagCollection,
+    EmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
 )
@@ -45,6 +48,13 @@
     [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
 ]
 
+EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
+    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
+    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
+    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
+    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
+]
+
 EXPECTED_RW_SHARD_OFFSETS = [
     [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
     [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
@@ -52,6 +62,13 @@
     [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
 ]
 
+EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
+    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
+    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
+    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
+    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
+]
+
 
 def get_expected_cache_aux_size(rows: int) -> int:
     # 0.2 is the hardcoded cache load factor assumed in this test
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
+EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+        Storage(hbm=165888, ddr=0),
+    ],
+    [
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+        Storage(hbm=1001472, ddr=0),
+    ],
+    [
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+        Storage(hbm=1003520, ddr=0),
+    ],
+    [
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+        Storage(hbm=2648064, ddr=0),
+    ],
+]
 
 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
     [
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
+EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+    ],
+    [
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+    ],
+    [
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+    ],
+    [
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+    ],
+]
 
 EXPECTED_TWRW_SHARD_SIZES = [
     [[25, 20], [25, 20], [25, 20], [25, 20]],
@@ -248,6 +349,16 @@ def compute_kernels(
         return [EmbeddingComputeKernel.FUSED.value]
 
 
+class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [ShardingType.ROW_WISE.value]
+
+    def compute_kernels(
+        self, sharding_type: str, compute_device_type: str
+    ) -> List[str]:
+        return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]
+
+
 class UVMCachingRWSharder(EmbeddingBagCollectionSharder):
     def sharding_types(self, compute_device_type: str) -> List[str]:
         return [ShardingType.ROW_WISE.value]
@@ -357,6 +468,27 @@ def setUp(self) -> None:
                 min_partition=40, pooling_factors=[2, 1, 3, 7]
             ),
         }
+        self._virtual_table_constraints = {
+            "table_0": ParameterConstraints(
+                min_partition=20,
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_1": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[1, 3, 5],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_2": ParameterConstraints(
+                min_partition=20,
+                pooling_factors=[8, 2],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+            "table_3": ParameterConstraints(
+                min_partition=40,
+                pooling_factors=[2, 1, 3, 7],
+                compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
+            ),
+        }
         self.num_tables = 4
         tables = [
             EmbeddingBagConfig(
@@ -367,6 +499,17 @@ def setUp(self) -> None:
             )
             for i in range(self.num_tables)
         ]
+        tables_with_buckets = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i * 10,
+                embedding_dim=20 + i * 20,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                total_num_buckets=10,
+                use_virtual_table=True,
+            )
+            for i in range(self.num_tables)
+        ]
         weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -377,6 +520,9 @@ def setUp(self) -> None:
             for i in range(4)
         ]
         self.model = TestSparseNN(tables=tables, weighted_tables=[])
+        self.model_with_buckets = EmbeddingBagCollection(
+            tables=tables_with_buckets,
+        )
         self.enumerator = EmbeddingEnumerator(
             topology=Topology(
                 world_size=self.world_size,
@@ -386,6 +532,15 @@ def setUp(self) -> None:
                 compute_device=self.compute_device,
                 local_world_size=self.local_world_size,
             ),
             batch_size=self.batch_size,
             constraints=self.constraints,
         )
+        self.virtual_table_enumerator = EmbeddingEnumerator(
+            topology=Topology(
+                world_size=self.world_size,
+                compute_device=self.compute_device,
+                local_world_size=self.local_world_size,
+            ),
+            batch_size=self.batch_size,
+            constraints=self._virtual_table_constraints,
+        )
         self.tower_model = TestTowerSparseNN(
             tables=tables, weighted_tables=weighted_tables
         )
@@ -514,6 +669,26 @@ def test_rw_sharding(self) -> None:
                 EXPECTED_RW_SHARD_STORAGE[i],
             )
 
+    def test_virtual_table_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.virtual_table_enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_uvm_caching_rw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model,
@@ -535,6 +710,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
                 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
             )
 
+    def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_twrw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]
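As a quick, illustrative cross-check of the new expected constants above (not part of the test file; the helper is hypothetical): each `table_i` has `100 + i * 10` rows and `total_num_buckets=10`, and each expected row lists eight shards, so the bucketed row counts per rank can be recomputed as follows.

```python
from typing import List


def expected_rows_per_rank(num_embeddings: int, num_buckets: int, world_size: int) -> List[int]:
    # Rows are assigned in whole buckets; the first (num_buckets % world_size)
    # ranks receive one extra bucket each.
    bucket_size = num_embeddings // num_buckets
    base, extra = divmod(num_buckets, world_size)
    return [(base + (1 if rank < extra else 0)) * bucket_size for rank in range(world_size)]


# table_0: 100 rows, 10 buckets over 8 ranks -> [20, 20, 10, 10, 10, 10, 10, 10]
assert expected_rows_per_rank(100, 10, 8) == [20, 20, 10, 10, 10, 10, 10, 10]
# table_3: 130 rows, 10 buckets over 8 ranks -> [26, 26, 13, 13, 13, 13, 13, 13]
assert expected_rows_per_rank(130, 10, 8) == [26, 26, 13, 13, 13, 13, 13, 13]
```

These match the first and last entries of EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS.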
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index ebc76b976..9b27de2bd 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -130,6 +130,9 @@ def create_input_dist(
     ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()
+        virtual_table_feature_num_buckets = (
+            self._get_virtual_table_feature_num_buckets()
+        )
         return RwSparseFeaturesDist(
             # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
             #  `Optional[ProcessGroup]`.
@@ -140,6 +143,7 @@ def create_input_dist(
             is_sequence=True,
             has_feature_processor=self._has_feature_processor,
             need_pos=False,
+            virtual_table_feature_num_buckets=virtual_table_feature_num_buckets,
         )
 
     def create_lookup(
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index d310127c0..136052137 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -300,6 +300,20 @@ def _get_writable_feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(group_config.feature_hash_sizes())
         return feature_hash_sizes
 
+    def _get_virtual_table_feature_num_buckets(self) -> List[int]:
+        feature_num_buckets: List[int] = []
+        for group_config in self._grouped_embedding_configs:
+            for embedding_table in group_config.embedding_tables:
+                if (
+                    embedding_table.total_num_buckets is not None
+                    and embedding_table.use_virtual_table
+                ):
+                    feature_num_buckets.extend(
+                        [embedding_table.total_num_buckets]
+                        * embedding_table.num_features()
+                    )
+        return feature_num_buckets
+
 
 class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """
@@ -331,6 +345,7 @@ def __init__(
         has_feature_processor: bool = False,
         need_pos: bool = False,
         keep_original_indices: bool = False,
+        virtual_table_feature_num_buckets: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
         self._world_size: int = pg.size()
@@ -340,11 +355,33 @@ def __init__(
 
         for i, hash_size in enumerate(feature_hash_sizes):
             block_divisor = self._world_size
+            # Using different num_buckets lists for MPZCH and KVZCH lets us keep the two
+            # code paths separate for now, enabling uneven sharding for KVZCH only. MPZCH
+            # could get uneven sharding as well in the future, but that will require
+            # additional testing.
             if feature_total_num_buckets is not None:
-                assert feature_total_num_buckets[i] % self._world_size == 0
+                # MPZCH sharding
+                assert (
+                    feature_total_num_buckets[i] % self._world_size == 0
+                ), f"Number of buckets: {feature_total_num_buckets[i]} should be divisible by world size: {self._world_size} for MPZCH"
                 block_divisor = feature_total_num_buckets[i]
+            elif virtual_table_feature_num_buckets is not None and len(
+                virtual_table_feature_num_buckets
+            ):
+                # KVZCH uneven sharding
+                assert (
+                    virtual_table_feature_num_buckets[i] >= self._world_size
+                ), f"Number of buckets: {virtual_table_feature_num_buckets[i]} for a table cannot be less than the world_size: {self._world_size}"
+                block_divisor = virtual_table_feature_num_buckets[i]
             feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
 
+        self.kvzch_bucketize_row_pos: Optional[List[torch.Tensor]] = (
+            self._get_bucketize_row_pos(
+                virtual_table_feature_num_buckets, feature_block_sizes
+            )
+        )
+
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
@@ -378,6 +415,37 @@ def __init__(
         self.unbucketize_permute_tensor: Optional[torch.Tensor] = None
         self._keep_original_indices = keep_original_indices
 
+    def _get_bucketize_row_pos(
+        self,
+        feature_num_buckets: Optional[List[int]],
+        feature_block_sizes: List[int],
+    ) -> Optional[List[torch.Tensor]]:
+        # Creates the bucketize row positions for uneven sharding with buckets. If the
+        # number of buckets is greater than the world size and num_buckets % world_size != 0,
+        # the bucket count will not be the same on each rank. bucketize_row_pos lays out the
+        # resulting per-rank row boundaries for each feature. For example (block size 1):
+        # bucketize_row_pos = [
+        #     Tensor([0, 4, 8, 12, 15, 18, 21]),  # feature 1: 4 buckets on ranks 0-2, 3 buckets on ranks 3-5
+        #     Tensor([0, 2, 4, 6, 7, 8, 9]),      # feature 2: 2 buckets on ranks 0-2, 1 bucket on ranks 3-5
+        # ]
+        if feature_num_buckets is None or len(feature_num_buckets) == 0:
+            return None
+        bucketize_row_pos = [[0] for _ in range(len(feature_num_buckets))]
+        bucketize_row_pos_tensors = []
+        for feature in range(len(feature_num_buckets)):
+            for rank in range(self._world_size):
+                bucketize_row_pos[feature].append(
+                    bucketize_row_pos[feature][-1]
+                    + (
+                        (feature_num_buckets[feature] // self._world_size)
+                        + int(rank < feature_num_buckets[feature] % self._world_size)
+                    )
+                    * feature_block_sizes[feature]
+                )
+            bucketize_row_pos_tensors.append(torch.tensor(bucketize_row_pos[feature]))
+        return bucketize_row_pos_tensors
+
     def forward(
         self,
         sparse_features: KeyedJaggedTensor,
@@ -413,6 +481,7 @@ def forward(
                 else self._need_pos
             ),
             keep_original_indices=self._keep_original_indices,
+            block_bucketize_row_pos=self.kvzch_bucketize_row_pos,
         )
 
         return self._dist(bucketized_features)
@@ -558,6 +627,9 @@ def create_input_dist(
     ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()
+        virtual_table_feature_num_buckets = (
+            self._get_virtual_table_feature_num_buckets()
+        )
         return RwSparseFeaturesDist(
             # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
             #  `Optional[ProcessGroup]`.
@@ -567,6 +639,7 @@ def create_input_dist(
             device=device if device is not None else self._device,
             is_sequence=False,
             has_feature_processor=self._has_feature_processor,
+            virtual_table_feature_num_buckets=virtual_table_feature_num_buckets,
             need_pos=self._need_pos,
         )
 
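The layout described in the `_get_bucketize_row_pos` comment can be reproduced with a small standalone sketch (illustrative only, not part of the patch); it mirrors the cumulative-boundary computation for a single feature:

```python
import torch


def bucketize_row_pos(num_buckets: int, block_size: int, world_size: int) -> torch.Tensor:
    # Cumulative row boundaries per rank: each rank gets num_buckets // world_size
    # buckets, and the first (num_buckets % world_size) ranks get one extra bucket.
    pos = [0]
    for rank in range(world_size):
        buckets_on_rank = num_buckets // world_size + int(
            rank < num_buckets % world_size
        )
        pos.append(pos[-1] + buckets_on_rank * block_size)
    return torch.tensor(pos)


# 21 buckets of block size 1 over 6 ranks -> tensor([0, 4, 8, 12, 15, 18, 21]),
# i.e. 4 buckets on ranks 0-2 and 3 buckets on ranks 3-5, matching the comment above.
print(bucketize_row_pos(21, 1, 6))
```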
diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index 81e4fad8e..d70750d2e 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -96,6 +96,7 @@ def calculate_shard_sizes_and_offsets(
     sharding_type: str,
     col_wise_shard_dim: Optional[int] = None,
     device_memory_sizes: Optional[List[int]] = None,
+    num_buckets: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     """
     Calculates sizes and offsets for tensor sharded according to provided sharding type.
@@ -122,10 +123,12 @@ def calculate_shard_sizes_and_offsets(
         return [[rows, columns]], [[0, 0]]
     elif sharding_type == ShardingType.ROW_WISE.value:
         return (
-            _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns)
+            _calculate_rw_shard_sizes_and_offsets(
+                rows, world_size, columns, num_buckets
+            )
             if not device_memory_sizes
             else _calculate_uneven_rw_shard_sizes_and_offsets(
-                rows, world_size, columns, device_memory_sizes
+                rows, world_size, columns, device_memory_sizes, num_buckets
             )
         )
     elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
@@ -170,7 +173,7 @@ def _calculate_grid_shard_sizes_and_offsets(
 
 
 def _calculate_rw_shard_sizes_and_offsets(
-    hash_size: int, num_devices: int, columns: int
+    hash_size: int, num_devices: int, columns: int, num_buckets: Optional[int] = None
 ) -> Tuple[List[List[int]], List[List[int]]]:
     """
     Sets prefix of shard_sizes to be `math.ceil(hash_size/num_devices)`.
@@ -183,21 +186,43 @@ def _calculate_rw_shard_sizes_and_offsets(
 
     Also consider the example of hash_size = 5, num_devices = 4. The expected rows per rank is
     [2,2,1,0].
-    """
-    block_size: int = math.ceil(hash_size / num_devices)
-    last_rank: int = hash_size // block_size
-    last_block_size: int = hash_size - block_size * last_rank
+
+    If num_buckets is specified, the sharding methodology changes to adapt to ZCH: rows are
+    assigned in whole buckets. For example, if hash_size = 10, num_devices = 4 and
+    num_buckets = 5, each bucket has 2 rows, and after distributing the buckets as evenly
+    as possible across the ranks the row distribution is [4, 2, 2, 2].
+    """
     shard_sizes: List[List[int]] = []
-
-    for rank in range(num_devices):
-        if rank < last_rank:
-            local_row: int = block_size
-        elif rank == last_rank:
-            local_row: int = last_block_size
-        else:
-            local_row: int = 0
-        shard_sizes.append([local_row, columns])
+    if num_buckets:
+        # A specified number of buckets means ZCH is enabled.
+        assert (
+            hash_size % num_buckets == 0
+        ), "hash_size must be divisible by num_buckets"
+        bucket_size: int = hash_size // num_buckets
+        # number of buckets per rank
+        shard_buckets = num_buckets // num_devices
+        # number of ranks that get one extra bucket
+        extra_bucket_shards = num_buckets % num_devices
+        for rank in range(num_devices):
+            if rank < extra_bucket_shards:
+                shard_size = bucket_size * (shard_buckets + 1)
+            else:
+                shard_size = bucket_size * shard_buckets
+            shard_sizes.append([shard_size, columns])
+    else:
+        block_size: int = math.ceil(hash_size / num_devices)
+        last_rank: int = hash_size // block_size
+        last_block_size: int = hash_size - block_size * last_rank
+
+        for rank in range(num_devices):
+            if rank < last_rank:
+                local_row: int = block_size
+            elif rank == last_rank:
+                local_row: int = last_block_size
+            else:
+                local_row: int = 0
+            shard_sizes.append([local_row, columns])
 
     shard_offsets = [[0, 0]]
     for i in range(num_devices - 1):
@@ -207,7 +232,11 @@ def _calculate_rw_shard_sizes_and_offsets(
 
 
 def _calculate_uneven_rw_shard_sizes_and_offsets(
-    hash_size: int, num_devices: int, columns: int, device_memory_sizes: List[int]
+    hash_size: int,
+    num_devices: int,
+    columns: int,
+    device_memory_sizes: List[int],
+    num_buckets: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     assert num_devices == len(device_memory_sizes), "must provide all the memory size"
     total_size = sum(device_memory_sizes)
@@ -215,10 +244,20 @@ def _calculate_uneven_rw_shard_sizes_and_offsets(
     last_rank = num_devices - 1
 
     processed_total_rows = 0
-
+    if num_buckets is None:
+        # Without buckets, fall back to row granularity.
+        num_buckets = hash_size
+        bucket_size = 1
+    else:
+        assert (
+            hash_size % num_buckets == 0
+        ), "hash_size must be divisible by num_buckets"
+        bucket_size = hash_size // num_buckets
     for rank in range(num_devices):
         if rank < last_rank:
-            local_row: int = int(hash_size * (device_memory_sizes[rank] / total_size))
+            local_row: int = (
+                int(num_buckets * (device_memory_sizes[rank] / total_size))
+                * bucket_size
+            )
             processed_total_rows += local_row
         elif rank == last_rank:
             local_row: int = hash_size - processed_total_rows
diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py
index 02f64e859..6c585b423 100644
--- a/torchrec/distributed/tests/test_sharding_plan.py
+++ b/torchrec/distributed/tests/test_sharding_plan.py
@@ -19,6 +19,8 @@
     QuantManagedCollisionEmbeddingCollectionSharder,
 )
 from torchrec.distributed.sharding_plan import (
+    _calculate_rw_shard_sizes_and_offsets,
+    _calculate_uneven_rw_shard_sizes_and_offsets,
     column_wise,
     construct_module_sharding_plan,
     data_parallel,
@@ -1237,3 +1239,147 @@
             default_sharder_map[QuantManagedCollisionEmbeddingCollection],
             QuantManagedCollisionEmbeddingCollectionSharder,
         )
+
+
+class RowWiseShardingTest(unittest.TestCase):
+    def test_non_zch_rw_sharding(self) -> None:
+        """Test the original row-wise sharding logic (without num_buckets)."""
+        # Test case 1: hash_size = 10, num_devices = 4
+        shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets(
+            hash_size=10, num_devices=4, columns=8
+        )
+
+        # Expected: [3, 3, 3, 1] rows per rank
+        expected_sizes = [[3, 8], [3, 8], [3, 8], [1, 8]]
+        expected_offsets = [[0, 0], [3, 0], [6, 0], [9, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+        # Test case 2: hash_size = 5, num_devices = 4
+        shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets(
+            hash_size=5, num_devices=4, columns=16
+        )
+
+        # Expected: [2, 2, 1, 0] rows per rank
+        expected_sizes = [[2, 16], [2, 16], [1, 16], [0, 16]]
+        expected_offsets = [[0, 0], [2, 0], [4, 0], [5, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+    def test_zch_rw_sharding(self) -> None:
+        """Test the new row-wise sharding logic with num_buckets (ZCH)."""
+        # Test case 1: hash_size = 10, num_devices = 4, num_buckets = 5
+        # Each bucket has 2 rows, buckets are distributed as [2, 1, 1, 1],
+        # so rows are distributed as [4, 2, 2, 2]
+        shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets(
+            hash_size=10, num_devices=4, columns=8, num_buckets=5
+        )
+
+        expected_sizes = [[4, 8], [2, 8], [2, 8], [2, 8]]
+        expected_offsets = [[0, 0], [4, 0], [6, 0], [8, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+        # Test case 2: hash_size = 100, num_devices = 4, num_buckets = 10
+        # Each bucket has 10 rows, buckets are distributed as [3, 3, 2, 2],
+        # so rows are distributed as [30, 30, 20, 20]
+        shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets(
+            hash_size=100, num_devices=4, columns=16, num_buckets=10
+        )
+
+        expected_sizes = [[30, 16], [30, 16], [20, 16], [20, 16]]
+        expected_offsets = [[0, 0], [30, 0], [60, 0], [80, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+        # Test case 3: hash_size = 18, num_devices = 3, num_buckets = 6
+        # Each bucket has 3 rows (18 // 6 = 3), buckets are distributed as [2, 2, 2],
+        # so rows are distributed as [6, 6, 6]
+        shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets(
+            hash_size=18, num_devices=3, columns=32, num_buckets=6
+        )
+
+        expected_sizes = [[6, 32], [6, 32], [6, 32]]
+        expected_offsets = [[0, 0], [6, 0], [12, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+    def test_uneven_rw_sharding_with_buckets(self) -> None:
+        """Test uneven row-wise sharding with num_buckets."""
+        # Test with device memory sizes [2, 1, 1]
+        device_memory_sizes = [2, 1, 1]
+
+        # hash_size = 40, num_buckets = 8, bucket_size = 5
+        # With memory ratio 2:1:1, buckets should be distributed as [4, 2, 2],
+        # so rows are distributed as [20, 10, 10]
+        shard_sizes, shard_offsets = _calculate_uneven_rw_shard_sizes_and_offsets(
+            hash_size=40,
+            num_devices=3,
+            columns=64,
+            device_memory_sizes=device_memory_sizes,
+            num_buckets=8,
+        )
+
+        expected_sizes = [[20, 64], [10, 64], [10, 64]]
+        expected_offsets = [[0, 0], [20, 0], [30, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+        # Test without num_buckets (should fall back to using hash_size as num_buckets)
+        # With memory ratio 2:1:1, rows should be distributed as [20, 10, 10]
+        shard_sizes, shard_offsets = _calculate_uneven_rw_shard_sizes_and_offsets(
+            hash_size=40,
+            num_devices=3,
+            columns=64,
+            device_memory_sizes=device_memory_sizes,
+        )
+
+        expected_sizes = [[20, 64], [10, 64], [10, 64]]
+        expected_offsets = [[0, 0], [20, 0], [30, 0]]
+
+        self.assertEqual(shard_sizes, expected_sizes)
+        self.assertEqual(shard_offsets, expected_offsets)
+
+    def test_rw_sharding_hash_size_not_divisible_by_num_buckets(self) -> None:
+        """_calculate_rw_shard_sizes_and_offsets should raise an AssertionError when hash_size is not divisible by num_buckets."""
+        # Test case: hash_size = 10, num_buckets = 3 (not divisible)
+        with self.assertRaises(AssertionError):
+            _calculate_rw_shard_sizes_and_offsets(
+                hash_size=10, num_devices=4, columns=8, num_buckets=3
+            )
+
+        # Test case: hash_size = 100, num_buckets = 7 (not divisible)
+        with self.assertRaises(AssertionError):
+            _calculate_rw_shard_sizes_and_offsets(
+                hash_size=100, num_devices=4, columns=16, num_buckets=7
+            )
+
+    def test_uneven_rw_sharding_hash_size_not_divisible_by_num_buckets(self) -> None:
+        """_calculate_uneven_rw_shard_sizes_and_offsets should raise an AssertionError when hash_size is not divisible by num_buckets."""
+        device_memory_sizes = [2, 1, 1]
+
+        # Test case: hash_size = 10, num_buckets = 3 (not divisible)
+        with self.assertRaises(AssertionError):
+            _calculate_uneven_rw_shard_sizes_and_offsets(
+                hash_size=10,
+                num_devices=3,
+                columns=64,
+                device_memory_sizes=device_memory_sizes,
+                num_buckets=3,
+            )
+
+        # Test case: hash_size = 100, num_buckets = 7 (not divisible)
+        with self.assertRaises(AssertionError):
+            _calculate_uneven_rw_shard_sizes_and_offsets(
+                hash_size=100,
+                num_devices=3,
+                columns=64,
+                device_memory_sizes=device_memory_sizes,
+                num_buckets=7,
+            )
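To make the bucket-aware uneven split easier to follow, here is an illustrative sketch (not part of the patch; names are hypothetical) of the row distribution exercised by test_uneven_rw_sharding_with_buckets above: ranks receive whole buckets in proportion to their device memory, and the last rank absorbs the remainder.

```python
from typing import List


def uneven_bucketed_rows(hash_size: int, device_memory_sizes: List[int], num_buckets: int) -> List[int]:
    assert hash_size % num_buckets == 0, "hash_size must be divisible by num_buckets"
    bucket_size = hash_size // num_buckets
    total = sum(device_memory_sizes)
    rows, used = [], 0
    for rank, mem in enumerate(device_memory_sizes):
        if rank < len(device_memory_sizes) - 1:
            # Whole buckets proportional to this rank's share of device memory.
            local = int(num_buckets * (mem / total)) * bucket_size
        else:
            # The last rank takes whatever is left.
            local = hash_size - used
        rows.append(local)
        used += local
    return rows


# 40 rows, 8 buckets, memory ratio 2:1:1 -> [20, 10, 10], matching the test above.
print(uneven_bucketed_rows(40, [2, 1, 1], 8))
```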