From c1254b432802312c56d8f8b43259f7eb406bd9e8 Mon Sep 17 00:00:00 2001
From: Emma Lin
Date: Tue, 4 Nov 2025 19:58:08 -0800
Subject: [PATCH] enable feature score auto collection in EBC (#3475)

Summary:
X-link: https://github.com/pytorch/FBGEMM/pull/5031
X-link: https://github.com/facebookresearch/FBGEMM/pull/2044

Enable feature score auto collection for EBC in the same way as for EC.
The embedding table configuration is the same as for EC:

virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
    training_id_eviction_trigger_count=260_000_000,  # 260M
    training_id_keep_count=160_000_000,  # 160M
    enable_auto_feature_score_collection=True,
    feature_score_mapping={
        "sparse_public_original_content_creator": 1.0,
    },
    feature_score_default_value=0.5,
),

Reviewed By: EddyLXJ

Differential Revision: D85017179
---
 .../distributed/batched_embedding_kernel.py  |  40 ++-
 torchrec/distributed/embedding_lookup.py     |  39 ++-
 torchrec/distributed/embeddingbag.py         |  27 ++
 torchrec/distributed/feature_score_utils.py  |   4 +-
 .../tests/test_feature_score_utils.py        | 231 ++++++++++++++++++
 .../tests/test_sequence_model_parallel.py    |   6 +-
 6 files changed, 339 insertions(+), 8 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 6e5daaaef..51b7bf346 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -326,6 +326,8 @@ def _populate_zero_collision_tbe_params(
         meta_header_lens[i] = table.virtual_table_eviction_policy.get_meta_header_len()
         if not isinstance(table.virtual_table_eviction_policy, NoEvictionPolicy):
             enabled = True
+
+    fs_eviction_enabled: bool = False
     if enabled:
         counter_thresholds = [0] * len(config.embedding_tables)
         ttls_in_mins = [0] * len(config.embedding_tables)
@@ -384,6 +386,7 @@
                 raise ValueError(
                     f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
                 )
+            fs_eviction_enabled = True
         elif isinstance(policy_t, TimestampBasedEvictionPolicy):
             training_id_eviction_trigger_count[i] = (
                 policy_t.training_id_eviction_trigger_count
@@ -459,6 +462,7 @@
         backend_return_whole_row=(backend_type == BackendType.DRAM),
         eviction_policy=eviction_policy,
         embedding_cache_mode=embedding_cache_mode_,
+        feature_score_collection_enabled=fs_eviction_enabled,
     )


@@ -2901,6 +2905,7 @@ def __init__(
         _populate_zero_collision_tbe_params(
             ssd_tbe_params, self._bucket_spec, config, backend_type
         )
+        self._kv_zch_params: KVZCHParams = ssd_tbe_params["kv_zch_params"]

         compute_kernel = config.embedding_tables[0].compute_kernel
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)
@@ -3185,7 +3190,40 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         self._split_weights_res = None
         self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        return super().forward(features)
+        weights = features.weights_or_none()
+        per_sample_weights = None
+        score_weights = None
+        if weights is not None and weights.dtype == torch.float64:
+            fp32_weights = weights.view(torch.float32)
+            per_sample_weights = fp32_weights[:, 0]
+            score_weights = fp32_weights[:, 1]
+        elif weights is not None and weights.dtype == torch.float32:
+            if self._kv_zch_params.feature_score_collection_enabled:
+                score_weights = weights.view(-1)
+            else:
+                per_sample_weights = weights.view(-1)
+        if features.variable_stride_per_key() and isinstance(
self.emb_module, + ( + SplitTableBatchedEmbeddingBagsCodegen, + DenseTableBatchedEmbeddingBagsCodegen, + SSDTableBatchedEmbeddingBags, + ), + ): + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + weights=score_weights, + per_sample_weights=per_sample_weights, + batch_size_per_feature_per_rank=features.stride_per_key_per_rank(), + ) + else: + return self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + weights=score_weights, + per_sample_weights=per_sample_weights, + ) class BatchedFusedEmbeddingBag( diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 9f3ce69c7..53e419cc1 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -66,6 +66,7 @@ QuantBatchedEmbeddingBag, ) from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType +from torchrec.modules.embedding_configs import FeatureScoreBasedEvictionPolicy from torchrec.sparse.jagged_tensor import KeyedJaggedTensor logger: logging.Logger = logging.getLogger(__name__) @@ -515,6 +516,23 @@ def __init__( ) -> None: super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() + self._feature_score_auto_collections: List[bool] = [] + for config in grouped_configs: + collection = False + for table in config.embedding_tables: + if table.use_virtual_table and isinstance( + table.virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy + ): + if ( + table.virtual_table_eviction_policy.enable_auto_feature_score_collection + ): + collection = True + self._feature_score_auto_collections.append(collection) + + logger.info( + f"GroupedPooledEmbeddingsLookup: {self._feature_score_auto_collections=}" + ) + for config in grouped_configs: self._emb_modules.append( self._create_embedding_kernel(config, device, pg, sharding_type) @@ -692,8 +710,11 @@ def forward( features_by_group = sparse_features.split( self._feature_splits, ) - for config, emb_op, features in zip( - self.grouped_configs, self._emb_modules, features_by_group + for config, emb_op, features, fs_auto_collection in zip( + self.grouped_configs, + self._emb_modules, + features_by_group, + self._feature_score_auto_collections, ): if ( config.has_feature_processor @@ -703,9 +724,19 @@ def forward( features = self._feature_processor(features) if config.is_weighted: - features._weights = CommOpGradientScaling.apply( + feature_weights = CommOpGradientScaling.apply( features._weights, self._scale_gradient_factor - ) + ).float() + + if fs_auto_collection and features.weights_or_none() is not None: + score_weights = features.weights().float() + assert ( + feature_weights.numel() == score_weights.numel() + ), f"feature_weights.numel() {feature_weights.numel()} != score_weights.numel() {score_weights.numel()}" + cat_weights = torch.cat( + [feature_weights, score_weights], dim=1 + ).view(torch.float64) + features._weights = cat_weights lookup = emb_op(features) embeddings.append(lookup) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index fd6117884..00d211357 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -51,6 +51,10 @@ KJTList, ShardedEmbeddingModule, ) +from torchrec.distributed.feature_score_utils import ( + create_sharding_type_to_feature_score_mapping, + may_collect_feature_scores, +) from torchrec.distributed.fused_params import ( FUSED_PARAM_IS_SSD_TABLE, FUSED_PARAM_SSD_TABLE_LIST, @@ -565,6 
+569,24 @@ def __init__( # forward pass flow control self._has_uninitialized_input_dist: bool = True self._has_features_permute: bool = True + + self._enable_feature_score_weight_accumulation: bool = False + self._enabled_feature_score_auto_collection: bool = False + self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} + ( + self._enable_feature_score_weight_accumulation, + self._enabled_feature_score_auto_collection, + self._sharding_type_feature_score_mapping, + ) = create_sharding_type_to_feature_score_mapping( + self._embedding_bag_configs, self.sharding_type_to_sharding_infos + ) + + logger.info( + f"EBC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, " + f"auto collection enabled: {self._enabled_feature_score_auto_collection}, " + f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}" + ) + # Get all fused optimizers and combine them. optims = [] for lookup in self._lookups: @@ -1565,6 +1587,11 @@ def input_dist( features_by_shards = features.split( self._feature_splits, ) + features_by_shards = may_collect_feature_scores( + features_by_shards, + self._enabled_feature_score_auto_collection, + self._sharding_type_feature_score_mapping, + ) awaitables = [] for input_dist, features_by_shard, sharding_type in zip( self._input_dists, diff --git a/torchrec/distributed/feature_score_utils.py b/torchrec/distributed/feature_score_utils.py index 4442a6e78..1c6c5ad53 100644 --- a/torchrec/distributed/feature_score_utils.py +++ b/torchrec/distributed/feature_score_utils.py @@ -17,7 +17,7 @@ from torchrec.distributed.embedding_types import ShardingType from torchrec.modules.embedding_configs import ( - EmbeddingConfig, + BaseEmbeddingConfig, FeatureScoreBasedEvictionPolicy, ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -26,7 +26,7 @@ def create_sharding_type_to_feature_score_mapping( - embedding_configs: Sequence[EmbeddingConfig], + embedding_configs: Sequence[BaseEmbeddingConfig], sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]], ) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]: enable_feature_score_weight_accumulation = False diff --git a/torchrec/distributed/tests/test_feature_score_utils.py b/torchrec/distributed/tests/test_feature_score_utils.py index 3916d3c3e..0798a5fb5 100644 --- a/torchrec/distributed/tests/test_feature_score_utils.py +++ b/torchrec/distributed/tests/test_feature_score_utils.py @@ -474,3 +474,234 @@ def test_auto_collection_preserves_device(self) -> None: weights = result[0].weights_or_none() self.assertIsNotNone(weights) self.assertEqual(weights.device, device) + + +class EmbeddingBagConfigSupportTest(unittest.TestCase): + def test_embedding_bag_config_with_auto_collection_enabled(self) -> None: + # Setup: create EmbeddingBagConfig with auto collection enabled + mock_embedding_bag_config = EmbeddingBagConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.5, "feature_1": 2.0}, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_bag_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = 
[mock_embedding_bag_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: both flags are enabled and mapping contains feature scores + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertIn(ShardingType.TABLE_WISE.value, mapping) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 1.5, "feature_1": 2.0}, + ) + + def test_embedding_bag_config_with_default_value(self) -> None: + # Setup: create EmbeddingBagConfig with default value for missing features + mock_embedding_bag_config = EmbeddingBagConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1", "feature_2"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.5, "feature_1": 2.0}, + feature_score_default_value=0.5, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_bag_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = [mock_embedding_bag_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: mapping contains explicit scores and default for feature_2 + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 1.5, "feature_1": 2.0, "feature_2": 0.5}, + ) + + def test_mixed_embedding_config_and_bag_config(self) -> None: + # Setup: create both EmbeddingConfig and EmbeddingBagConfig with auto collection + embedding_config = EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.0}, + enable_auto_feature_score_collection=True, + ), + ) + + embedding_bag_config = EmbeddingBagConfig( + name="table_1", + embedding_dim=32, + num_embeddings=50, + feature_names=["feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_1": 2.0}, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param_0 = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_1 = torch.nn.Parameter(torch.randn(50, 32)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info_0 = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(embedding_config), + param_sharding=mock_param_sharding, + param=mock_param_0, + ) + + sharding_info_1 = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(embedding_bag_config), + param_sharding=mock_param_sharding, + param=mock_param_1, + ) + + embedding_configs = [embedding_config, embedding_bag_config] + 
sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info_0, sharding_info_1], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: mapping contains scores from both config types + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertIn(ShardingType.TABLE_WISE.value, mapping) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 1.0, "feature_1": 2.0}, + ) + + def test_embedding_bag_config_without_virtual_table(self) -> None: + # Setup: create EmbeddingBagConfig without virtual table + embedding_bag_configs = [ + EmbeddingBagConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0"], + ), + ] + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_bag_configs, sharding_type_to_sharding_infos + ) + + # Assert: both flags should be False and mapping should be empty + self.assertFalse(enable_weight_acc) + self.assertFalse(enable_auto_collection) + self.assertEqual(mapping, {}) + + def test_embedding_bag_config_with_eviction_ttl_mins(self) -> None: + # Setup: create EmbeddingBagConfig with positive eviction_ttl_mins + mock_embedding_bag_config = EmbeddingBagConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={}, + eviction_ttl_mins=60, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_bag_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = [mock_embedding_bag_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: all features get 0.0 score when eviction_ttl_mins is positive + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 0.0, "feature_1": 0.0}, + ) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index dbf91d446..488356846 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -390,7 +390,11 @@ class DedupIndicesWeightAccumulationTest(unittest.TestCase): This tests the correctness of the new scatter_add_along_first_dim implementation. 
""" - # to be deleted + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) def test_dedup_indices_weight_accumulation(self) -> None: """ Test the _dedup_indices method to ensure weight accumulation works correctly