diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 6fa024334..2f6c6b9ed 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -46,6 +46,11 @@ ShardedEmbeddingModule, ShardingType, ) + +from torchrec.distributed.feature_score_utils import ( + create_sharding_type_to_feature_score_mapping, + may_collect_feature_scores, +) from torchrec.distributed.fused_params import ( FUSED_PARAM_IS_SSD_TABLE, FUSED_PARAM_SSD_TABLE_LIST, @@ -90,7 +95,6 @@ from torchrec.modules.embedding_configs import ( EmbeddingConfig, EmbeddingTableConfig, - FeatureScoreBasedEvictionPolicy, PoolingType, ) from torchrec.modules.embedding_modules import ( @@ -463,12 +467,12 @@ def __init__( ] = { sharding_type: self.create_embedding_sharding( sharding_type=sharding_type, - sharding_infos=embedding_confings, + sharding_infos=embedding_configs, env=env, device=device, qcomm_codecs_registry=self.qcomm_codecs_registry, ) - for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items() + for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items() } self.enable_embedding_update: bool = any( @@ -490,16 +494,20 @@ def __init__( self._has_uninitialized_input_dist: bool = True logger.info(f"EC index dedup enabled: {self._use_index_dedup}.") - for config in self._embedding_configs: - virtual_table_eviction_policy = config.virtual_table_eviction_policy - if virtual_table_eviction_policy is not None and isinstance( - virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy - ): - self._enable_feature_score_weight_accumulation = True - break - + self._enable_feature_score_weight_accumulation: bool = False + self._enabled_feature_score_auto_collection: bool = False + self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} + ( + self._enable_feature_score_weight_accumulation, + self._enabled_feature_score_auto_collection, + self._sharding_type_feature_score_mapping, + ) = create_sharding_type_to_feature_score_mapping( + self._embedding_configs, sharding_type_to_sharding_infos + ) logger.info( - f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}." + f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, " + f"auto collection enabled: {self._enabled_feature_score_auto_collection}, " + f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}" ) # Get all fused optimizers and combine them. @@ -1361,22 +1369,22 @@ def _dedup_indices( source_weights.dtype == torch.float32 ), "Only float32 weights are supported for feature score eviction weights." 
- acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts( - source_weights.view(-1), - reverse_indices, + # Accumulate weights using scatter_add + acc_weights = torch.zeros( unique_indices.numel(), + dtype=torch.float32, + device=source_weights.device, ) + # Use PyTorch's scatter_add to accumulate weights + acc_weights.scatter_add_(0, reverse_indices, source_weights) + dedup_features = KeyedJaggedTensor( keys=input_feature.keys(), lengths=lengths, offsets=offsets, values=unique_indices, - weights=( - acc_weights.view(torch.float64).view(-1) - if acc_weights is not None - else None - ), + weights=(acc_weights.view(-1) if acc_weights is not None else None), ) ctx.input_features.append(input_feature) @@ -1495,6 +1503,11 @@ def input_dist( self._features_order_tensor, ) features_by_shards = features.split(self._feature_splits) + features_by_shards = may_collect_feature_scores( + features_by_shards, + self._enabled_feature_score_auto_collection, + self._sharding_type_feature_score_mapping, + ) if self._use_index_dedup: features_by_shards = self._dedup_indices(ctx, features_by_shards) diff --git a/torchrec/distributed/feature_score_utils.py b/torchrec/distributed/feature_score_utils.py new file mode 100644 index 000000000..4442a6e78 --- /dev/null +++ b/torchrec/distributed/feature_score_utils.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +from typing import Dict, List, Sequence, Tuple + +import torch + +from torch.autograd.profiler import record_function + +from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo +from torchrec.distributed.embedding_types import ShardingType + +from torchrec.modules.embedding_configs import ( + EmbeddingConfig, + FeatureScoreBasedEvictionPolicy, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_sharding_type_to_feature_score_mapping( + embedding_configs: Sequence[EmbeddingConfig], + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]], +) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]: + enable_feature_score_weight_accumulation = False + enabled_feature_score_auto_collection = False + + # Validation for virtual table configurations + virtual_tables = [ + config for config in embedding_configs if config.use_virtual_table + ] + if virtual_tables: + virtual_tables_with_eviction = [ + config + for config in virtual_tables + if config.virtual_table_eviction_policy is not None + ] + if virtual_tables_with_eviction: + # Check if any virtual table uses FeatureScoreBasedEvictionPolicy + tables_with_feature_score_policy = [ + config + for config in virtual_tables_with_eviction + if isinstance( + config.virtual_table_eviction_policy, + FeatureScoreBasedEvictionPolicy, + ) + ] + + # If any virtual table uses FeatureScoreBasedEvictionPolicy, + # then ALL virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy + if tables_with_feature_score_policy: + assert all( + isinstance( + config.virtual_table_eviction_policy, + FeatureScoreBasedEvictionPolicy, + ) + for config in virtual_tables_with_eviction + ), "If any virtual table uses FeatureScoreBasedEvictionPolicy, all virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy" + 
enable_feature_score_weight_accumulation = True + + # Check if any table has enable_auto_feature_score_collection=True + tables_with_auto_collection = [ + config + for config in tables_with_feature_score_policy + if config.virtual_table_eviction_policy is not None + and isinstance( + config.virtual_table_eviction_policy, + FeatureScoreBasedEvictionPolicy, + ) + and config.virtual_table_eviction_policy.enable_auto_feature_score_collection + ] + if tables_with_auto_collection: + # All virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True + assert all( + config.virtual_table_eviction_policy is not None + and isinstance( + config.virtual_table_eviction_policy, + FeatureScoreBasedEvictionPolicy, + ) + and config.virtual_table_eviction_policy.enable_auto_feature_score_collection + for config in tables_with_feature_score_policy + ), "If any virtual table has enable_auto_feature_score_collection=True, all virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True" + enabled_feature_score_auto_collection = True + + sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} + if enabled_feature_score_auto_collection: + for ( + sharding_type, + sharding_info, + ) in sharding_type_to_sharding_infos.items(): + feature_score_mapping: Dict[str, float] = {} + if sharding_type == ShardingType.DATA_PARALLEL.value: + sharding_type_feature_score_mapping[sharding_type] = ( + feature_score_mapping + ) + continue + for config in sharding_info: + vtep = config.embedding_config.virtual_table_eviction_policy + if vtep is not None and isinstance( + vtep, FeatureScoreBasedEvictionPolicy + ): + if vtep.eviction_ttl_mins > 0: + logger.info( + f"Virtual table eviction policy enabled for table {config.embedding_config.name} {sharding_type} with eviction TTL {vtep.eviction_ttl_mins} mins." + ) + feature_score_mapping.update( + dict.fromkeys(config.embedding_config.feature_names, 0.0) + ) + continue + for k in config.embedding_config.feature_names: + if ( + k + # pyre-ignore [16] + in config.embedding_config.virtual_table_eviction_policy.feature_score_mapping + ): + feature_score_mapping[k] = ( + config.embedding_config.virtual_table_eviction_policy.feature_score_mapping[ + k + ] + ) + else: + assert ( + # pyre-ignore [16] + config.embedding_config.virtual_table_eviction_policy.feature_score_default_value + is not None + ), f"Table {config.embedding_config.name} eviction policy feature_score_default_value is not set but feature {k} is not in feature_score_mapping." 
+                            feature_score_mapping[k] = (
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value
+                            )
+            sharding_type_feature_score_mapping[sharding_type] = feature_score_mapping
+    return (
+        enable_feature_score_weight_accumulation,
+        enabled_feature_score_auto_collection,
+        sharding_type_feature_score_mapping,
+    )
+
+
+@torch.fx.wrap
+def may_collect_feature_scores(
+    input_feature_splits: List[KeyedJaggedTensor],
+    enabled_feature_score_auto_collection: bool,
+    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]],
+) -> List[KeyedJaggedTensor]:
+    if not enabled_feature_score_auto_collection:
+        return input_feature_splits
+    with record_function("## collect_feature_score ##"):
+        for features, mapping in zip(
+            input_feature_splits, sharding_type_feature_score_mapping.values()
+        ):
+            assert (
+                features.weights_or_none() is None
+            ), f"Auto feature collection: {features.keys()=} has non-empty weights"
+            if (
+                mapping is None or len(mapping) == 0
+            ):  # collection is disabled for this sharding type
+                continue
+            feature_score_weights = []
+            device = features.device()
+            for f in features.keys():
+                # The input dist covers multiple lookups, including both virtual-table and non-virtual-table features.
+                # KJT requires weights for all features, so attach a 0.0 score to non-virtual-table features.
+                score = mapping[f] if f in mapping else 0.0
+                feature_score_weights.append(
+                    torch.ones_like(
+                        features[f].values(),
+                        dtype=torch.float32,
+                        device=device,
+                    )
+                    * score
+                )
+            features._weights = (
+                torch.cat(feature_score_weights, dim=0)
+                if feature_score_weights
+                else None
+            )
+    return input_feature_splits
diff --git a/torchrec/distributed/tests/test_feature_score_utils.py b/torchrec/distributed/tests/test_feature_score_utils.py
new file mode 100644
index 000000000..3916d3c3e
--- /dev/null
+++ b/torchrec/distributed/tests/test_feature_score_utils.py
@@ -0,0 +1,476 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict + +import unittest +from typing import Dict, List +from unittest.mock import Mock + +import torch +from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo +from torchrec.distributed.embedding_types import ShardingType +from torchrec.distributed.feature_score_utils import ( + create_sharding_type_to_feature_score_mapping, + may_collect_feature_scores, +) +from torchrec.distributed.types import ParameterSharding +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + EmbeddingConfig, + EmbeddingTableConfig, + FeatureScoreBasedEvictionPolicy, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def _convert_to_table_config( + config: EmbeddingConfig | EmbeddingBagConfig, +) -> EmbeddingTableConfig: + """Convert EmbeddingConfig or EmbeddingBagConfig to EmbeddingTableConfig for sharding info""" + pooling = getattr(config, "pooling", None) + if pooling is None: + from torchrec.modules.embedding_configs import PoolingType + + pooling = PoolingType.SUM + + return EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=config.feature_names, + pooling=pooling, + is_weighted=False, + has_feature_processor=False, + embedding_names=[config.name], + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + use_virtual_table=config.use_virtual_table, + virtual_table_eviction_policy=config.virtual_table_eviction_policy, + ) + + +class CreateShardingTypeToFeatureScoreMappingTest(unittest.TestCase): + def test_no_virtual_tables(self) -> None: + # Setup: create embedding configs without virtual tables + embedding_configs = [ + EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + ), + ] + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: both flags should be False and mapping should be empty + self.assertFalse(enable_weight_acc) + self.assertFalse(enable_auto_collection) + self.assertEqual(mapping, {}) + + def test_virtual_table_without_eviction_policy(self) -> None: + # Setup: create virtual table without eviction policy + embedding_configs = [ + EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0"], + use_virtual_table=True, + ), + ] + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: both flags should be False and mapping should be empty + self.assertFalse(enable_weight_acc) + self.assertFalse(enable_auto_collection) + self.assertEqual(mapping, {}) + + def test_virtual_table_with_feature_score_policy_without_auto_collection( + self, + ) -> None: + # Setup: create virtual table with feature score policy but without auto collection + embedding_configs = [ + EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0"], + use_virtual_table=True, + 
virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.0}, + enable_auto_feature_score_collection=False, + ), + ), + ] + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {} + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: weight accumulation is enabled but auto collection is not + self.assertTrue(enable_weight_acc) + self.assertFalse(enable_auto_collection) + self.assertEqual(mapping, {}) + + def test_virtual_table_with_auto_collection_enabled(self) -> None: + # Setup: create virtual table with auto collection enabled + mock_embedding_config = EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.5, "feature_1": 2.0}, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = [mock_embedding_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: both flags are enabled and mapping contains feature scores + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertIn(ShardingType.TABLE_WISE.value, mapping) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 1.5, "feature_1": 2.0}, + ) + + def test_virtual_table_with_default_value(self) -> None: + # Setup: create virtual table with default value for missing features + mock_embedding_config = EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.5}, + feature_score_default_value=0.5, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = [mock_embedding_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: mapping contains explicit score for feature_0 and default for feature_1 + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 1.5, "feature_1": 
0.5}, + ) + + def test_data_parallel_sharding_type(self) -> None: + # Setup: create virtual table with data parallel sharding + mock_embedding_config = EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={"feature_0": 1.0}, + enable_auto_feature_score_collection=True, + ), + ) + + embedding_configs = [mock_embedding_config] + sharding_type_to_sharding_infos = { + ShardingType.DATA_PARALLEL.value: [], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: data parallel sharding has empty mapping + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertIn(ShardingType.DATA_PARALLEL.value, mapping) + self.assertEqual(mapping[ShardingType.DATA_PARALLEL.value], {}) + + def test_eviction_ttl_mins_positive(self) -> None: + # Setup: create virtual table with positive eviction_ttl_mins + mock_embedding_config = EmbeddingConfig( + name="table_0", + embedding_dim=64, + num_embeddings=100, + feature_names=["feature_0", "feature_1"], + use_virtual_table=True, + virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy( + feature_score_mapping={}, + eviction_ttl_mins=60, + enable_auto_feature_score_collection=True, + ), + ) + + mock_param = torch.nn.Parameter(torch.randn(100, 64)) + mock_param_sharding = Mock(spec=ParameterSharding) + + sharding_info = EmbeddingShardingInfo( + embedding_config=_convert_to_table_config(mock_embedding_config), + param_sharding=mock_param_sharding, + param=mock_param, + ) + + embedding_configs = [mock_embedding_config] + sharding_type_to_sharding_infos = { + ShardingType.TABLE_WISE.value: [sharding_info], + } + + # Execute: run create_sharding_type_to_feature_score_mapping + ( + enable_weight_acc, + enable_auto_collection, + mapping, + ) = create_sharding_type_to_feature_score_mapping( + embedding_configs, sharding_type_to_sharding_infos + ) + + # Assert: all features get 0.0 score when eviction_ttl_mins is positive + self.assertTrue(enable_weight_acc) + self.assertTrue(enable_auto_collection) + self.assertEqual( + mapping[ShardingType.TABLE_WISE.value], + {"feature_0": 0.0, "feature_1": 0.0}, + ) + + +class MayCollectFeatureScoresTest(unittest.TestCase): + def test_auto_collection_disabled(self) -> None: + # Setup: create input features without auto collection enabled + input_features = KeyedJaggedTensor( + keys=["feature_0"], + values=torch.tensor([1, 2, 3]), + lengths=torch.tensor([3]), + ) + input_feature_splits = [input_features] + enabled_feature_score_auto_collection = False + sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: input should be returned unchanged + self.assertEqual(result, input_feature_splits) + self.assertIsNone(result[0].weights_or_none()) + + def test_auto_collection_with_empty_mapping(self) -> None: + # Setup: create input features with empty mapping + input_features = KeyedJaggedTensor( + keys=["feature_0"], + values=torch.tensor([1, 2, 3]), + lengths=torch.tensor([3]), + ) + input_feature_splits = [input_features] + 
enabled_feature_score_auto_collection = True + sharding_type_feature_score_mapping = {"table_wise": {}} + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: input should be returned without weights added + self.assertEqual(len(result), 1) + self.assertIsNone(result[0].weights_or_none()) + + def test_auto_collection_with_feature_scores(self) -> None: + # Setup: create input features with feature score mapping + input_features = KeyedJaggedTensor( + keys=["feature_0", "feature_1"], + values=torch.tensor([1, 2, 3, 4, 5]), + lengths=torch.tensor([2, 3]), + ) + input_feature_splits = [input_features] + enabled_feature_score_auto_collection = True + sharding_type_feature_score_mapping = { + "table_wise": {"feature_0": 1.5, "feature_1": 2.0} + } + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: weights should be attached with correct scores + self.assertEqual(len(result), 1) + weights = result[0].weights_or_none() + self.assertIsNotNone(weights) + self.assertEqual(weights.shape[0], 5) + self.assertTrue(torch.allclose(weights[:2], torch.tensor([1.5, 1.5]))) + self.assertTrue(torch.allclose(weights[2:], torch.tensor([2.0, 2.0, 2.0]))) + + def test_auto_collection_with_missing_feature_in_mapping(self) -> None: + # Setup: create input features with one feature not in mapping + input_features = KeyedJaggedTensor( + keys=["feature_0", "feature_1"], + values=torch.tensor([1, 2, 3]), + lengths=torch.tensor([1, 2]), + ) + input_feature_splits = [input_features] + enabled_feature_score_auto_collection = True + sharding_type_feature_score_mapping = {"table_wise": {"feature_0": 1.5}} + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: missing feature should get 0.0 score + self.assertEqual(len(result), 1) + weights = result[0].weights_or_none() + self.assertIsNotNone(weights) + self.assertEqual(weights.shape[0], 3) + self.assertEqual(weights[0].item(), 1.5) + self.assertEqual(weights[1].item(), 0.0) + self.assertEqual(weights[2].item(), 0.0) + + def test_auto_collection_with_multiple_feature_splits(self) -> None: + # Setup: create multiple input feature splits + input_features_1 = KeyedJaggedTensor( + keys=["feature_0"], + values=torch.tensor([1, 2]), + lengths=torch.tensor([2]), + ) + input_features_2 = KeyedJaggedTensor( + keys=["feature_1"], + values=torch.tensor([3, 4, 5]), + lengths=torch.tensor([3]), + ) + input_feature_splits = [input_features_1, input_features_2] + enabled_feature_score_auto_collection = True + sharding_type_feature_score_mapping = { + "sharding_1": {"feature_0": 1.0}, + "sharding_2": {"feature_1": 2.0}, + } + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: each split should have appropriate weights + self.assertEqual(len(result), 2) + weights_1 = result[0].weights_or_none() + weights_2 = result[1].weights_or_none() + self.assertIsNotNone(weights_1) + self.assertIsNotNone(weights_2) + self.assertTrue(torch.allclose(weights_1, torch.tensor([1.0, 1.0]))) + 
self.assertTrue(torch.allclose(weights_2, torch.tensor([2.0, 2.0, 2.0]))) + + def test_auto_collection_preserves_device(self) -> None: + # Setup: create input features on GPU if available + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + input_features = KeyedJaggedTensor( + keys=["feature_0"], + values=torch.tensor([1, 2, 3], device=device), + lengths=torch.tensor([3], device=device), + ) + input_feature_splits = [input_features] + enabled_feature_score_auto_collection = True + sharding_type_feature_score_mapping = {"table_wise": {"feature_0": 1.5}} + + # Execute: run may_collect_feature_scores + result = may_collect_feature_scores( + input_feature_splits, + enabled_feature_score_auto_collection, + sharding_type_feature_score_mapping, + ) + + # Assert: weights should be on the same device as input + self.assertEqual(len(result), 1) + weights = result[0].weights_or_none() + self.assertIsNotNone(weights) + self.assertEqual(weights.device, device) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index d13d819c3..dbf91d446 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -15,6 +15,10 @@ import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType from hypothesis import assume, given, settings, Verbosity +from torchrec.distributed.embedding import ( + EmbeddingCollectionContext, + ShardedEmbeddingCollection, +) from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig from torchrec.distributed.planner import ParameterConstraints @@ -27,6 +31,8 @@ ) from torchrec.distributed.types import ShardingType from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.test_utils import seed_and_log, skip_if_asan_class @@ -378,6 +384,144 @@ def _test_sharding( ) +class DedupIndicesWeightAccumulationTest(unittest.TestCase): + """ + Test suite for validating the _dedup_indices method weight accumulation logic. + This tests the correctness of the new scatter_add_along_first_dim implementation. + """ + + # to be deleted + def test_dedup_indices_weight_accumulation(self) -> None: + """ + Test the _dedup_indices method to ensure weight accumulation works correctly + with the new scatter_add_along_first_dim implementation. 
+ """ + # Setup: Create a minimal ShardedEmbeddingCollection for testing + device = torch.device("cuda:0") + + # Create a mock ShardedEmbeddingCollection with minimal setup + class MockShardedEmbeddingCollection: + def __init__(self): + self._enable_feature_score_weight_accumulation = True + self._device = device + # Register required buffers for _dedup_indices + self._buffers = {} + + # Mock hash_size_cumsum_tensor_0 - cumulative sum of embedding table sizes + self._buffers["_hash_size_cumsum_tensor_0"] = torch.tensor( + [0, 10], dtype=torch.int64, device=device + ) + # Mock hash_size_offset_tensor_0 - offset for each feature + self._buffers["_hash_size_offset_tensor_0"] = torch.tensor( + [0], dtype=torch.int64, device=device + ) + + def get_buffer(self, name: str) -> torch.Tensor: + return self._buffers[name] + + def _dedup_indices( + self, + ctx: EmbeddingCollectionContext, + input_feature_splits: List[KeyedJaggedTensor], + ) -> List[KeyedJaggedTensor]: + # Copy the actual _dedup_indices logic for testing + features_by_shards = [] + for i, input_feature in enumerate(input_feature_splits): + hash_size_cumsum = self.get_buffer(f"_hash_size_cumsum_tensor_{i}") + hash_size_offset = self.get_buffer(f"_hash_size_offset_tensor_{i}") + ( + lengths, + offsets, + unique_indices, + reverse_indices, + ) = torch.ops.fbgemm.jagged_unique_indices( + hash_size_cumsum, + hash_size_offset, + input_feature.offsets().to(torch.int64), + input_feature.values().to(torch.int64), + ) + acc_weights = None + if ( + self._enable_feature_score_weight_accumulation + and input_feature.weights_or_none() is not None + ): + source_weights = input_feature.weights() + assert ( + source_weights.dtype == torch.float32 + ), "Only float32 weights are supported for feature score eviction weights." 
+
+                        # Accumulate weights using scatter_add
+                        acc_weights = torch.zeros(
+                            unique_indices.numel(),
+                            dtype=torch.float32,
+                            device=source_weights.device,
+                        )
+
+                        # Use PyTorch's scatter_add to accumulate weights
+                        acc_weights.scatter_add_(0, reverse_indices, source_weights)
+
+                    features_by_shards.append(
+                        KeyedJaggedTensor(
+                            keys=input_feature.keys(),
+                            values=unique_indices,
+                            weights=acc_weights,
+                            lengths=lengths,
+                            offsets=offsets,
+                        )
+                    )
+                return features_by_shards
+
+        # Create mock ShardedEmbeddingCollection instance
+        sharded_ec = MockShardedEmbeddingCollection()
+
+        # Create test input with duplicate indices and varying weights
+        values = torch.tensor(
+            [0, 1, 0, 2, 1, 0], dtype=torch.int64, device=device
+        )  # Indices with duplicates
+        weights = torch.tensor(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=torch.float32, device=device
+        )  # Corresponding weights
+        lengths = torch.tensor(
+            [6], dtype=torch.int64, device=device
+        )  # Single feature with 6 values
+
+        kjt_input = KeyedJaggedTensor(
+            keys=["feature_0"],
+            values=values,
+            weights=weights,
+            lengths=lengths,
+        )
+
+        # Execute: Run _dedup_indices method
+        ctx = EmbeddingCollectionContext()
+        features_by_shards = sharded_ec._dedup_indices(ctx, [kjt_input])
+
+        # Assert: Validate accumulated weights
+        dedup_feature = features_by_shards[0]
+        self.assertIsNotNone(dedup_feature.weights_or_none())
+
+        # Accumulated weights are stored as a flat float32 tensor; view as a column for indexing
+        acc_weights = dedup_feature.weights().view(-1, 1)
+
+        # Expected results based on duplicate indices:
+        # Index 0 appears 3 times with weights [1.0, 3.0, 6.0] -> sum = 10.0
+        # Index 1 appears 2 times with weights [2.0, 5.0] -> sum = 7.0
+        # Index 2 appears 1 time with weight [4.0] -> sum = 4.0
+
+        unique_values = dedup_feature.values()
+        self.assertEqual(len(unique_values), 3)  # Should have 3 unique indices
+
+        # Find positions of each unique index (order may vary after deduplication)
+        idx_0_pos = (unique_values == 0).nonzero(as_tuple=True)[0][0]
+        idx_1_pos = (unique_values == 1).nonzero(as_tuple=True)[0][0]
+        idx_2_pos = (unique_values == 2).nonzero(as_tuple=True)[0][0]
+
+        # Validate the accumulated weight for each unique index
+        self.assertAlmostEqual(acc_weights[idx_0_pos, 0].item(), 10.0, places=5)
+        self.assertAlmostEqual(acc_weights[idx_1_pos, 0].item(), 7.0, places=5)
+        self.assertAlmostEqual(acc_weights[idx_2_pos, 0].item(), 4.0, places=5)
+
+
 @skip_if_asan_class
 class TDSequenceModelParallelTest(SequenceModelParallelTest):
diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index f5d099db4..e8c01f70a 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -240,6 +240,9 @@ class FeatureScoreBasedEvictionPolicy(VirtualTableEvictionPolicy):
         None  # 0 means no eviction
     )
     inference_eviction_ttl_mins: Optional[int] = None  # 0 means no eviction
+    feature_score_mapping: Optional[Dict[str, float]] = None  # per-feature eviction score mapping
+    feature_score_default_value: Optional[float] = None  # score for features not in the mapping
+    enable_auto_feature_score_collection: bool = False
 
     def __post_init__(self) -> None:
         if self.inference_eviction_feature_score_threshold is None:
@@ -248,6 +251,19 @@ def __post_init__(self) -> None:
             self.inference_eviction_ttl_mins = self.eviction_ttl_mins
         if self.max_inference_id_num_per_rank == 0:
             self.max_inference_id_num_per_rank = self.training_id_keep_count
+        if self.enable_auto_feature_score_collection:
+            if self.feature_score_mapping is None:
+                self.feature_score_mapping = {}
+
+
+@dataclass
+class FeatureScoreMapping:
+    """
+    Feature score mapping for virtual table.
+    """
+
+    feature_score_mapping: Dict[str, float] = field(default_factory=dict)
+    eviction_enabled: bool = False
 
 
 @dataclass