|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +# pyre-strict |
| 9 | +import logging |
| 10 | +from typing import Dict, List, Sequence, Tuple |
| 11 | + |
| 12 | +import torch |
| 13 | + |
| 14 | +from torch.autograd.profiler import record_function |
| 15 | + |
| 16 | +from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo |
| 17 | +from torchrec.distributed.embedding_types import ShardingType |
| 18 | + |
| 19 | +from torchrec.modules.embedding_configs import ( |
| 20 | + EmbeddingConfig, |
| 21 | + FeatureScoreBasedEvictionPolicy, |
| 22 | +) |
| 23 | +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor |
| 24 | + |
| 25 | +logger: logging.Logger = logging.getLogger(__name__) |
| 26 | + |
| 27 | + |
| 28 | +def create_sharding_type_to_feature_score_mapping( |
| 29 | + embedding_configs: Sequence[EmbeddingConfig], |
| 30 | + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]], |
| 31 | +) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]: |
| 32 | + enable_feature_score_weight_accumulation = False |
| 33 | + enabled_feature_score_auto_collection = False |
| 34 | + |
| 35 | + # Validation for virtual table configurations |
| 36 | + virtual_tables = [ |
| 37 | + config for config in embedding_configs if config.use_virtual_table |
| 38 | + ] |
| 39 | + if virtual_tables: |
| 40 | + virtual_tables_with_eviction = [ |
| 41 | + config |
| 42 | + for config in virtual_tables |
| 43 | + if config.virtual_table_eviction_policy is not None |
| 44 | + ] |
| 45 | + if virtual_tables_with_eviction: |
| 46 | + # Check if any virtual table uses FeatureScoreBasedEvictionPolicy |
| 47 | + tables_with_feature_score_policy = [ |
| 48 | + config |
| 49 | + for config in virtual_tables_with_eviction |
| 50 | + if isinstance( |
| 51 | + config.virtual_table_eviction_policy, |
| 52 | + FeatureScoreBasedEvictionPolicy, |
| 53 | + ) |
| 54 | + ] |
| 55 | + |
| 56 | + # If any virtual table uses FeatureScoreBasedEvictionPolicy, |
| 57 | + # then ALL virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy |
| 58 | + if tables_with_feature_score_policy: |
| 59 | + assert all( |
| 60 | + isinstance( |
| 61 | + config.virtual_table_eviction_policy, |
| 62 | + FeatureScoreBasedEvictionPolicy, |
| 63 | + ) |
| 64 | + for config in virtual_tables_with_eviction |
| 65 | + ), "If any virtual table uses FeatureScoreBasedEvictionPolicy, all virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy" |
| 66 | + enable_feature_score_weight_accumulation = True |
| 67 | + |
| 68 | + # Check if any table has enable_auto_feature_score_collection=True |
| 69 | + tables_with_auto_collection = [ |
| 70 | + config |
| 71 | + for config in tables_with_feature_score_policy |
| 72 | + if config.virtual_table_eviction_policy is not None |
| 73 | + and isinstance( |
| 74 | + config.virtual_table_eviction_policy, |
| 75 | + FeatureScoreBasedEvictionPolicy, |
| 76 | + ) |
| 77 | + and config.virtual_table_eviction_policy.enable_auto_feature_score_collection |
| 78 | + ] |
| 79 | + if tables_with_auto_collection: |
| 80 | + # All virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True |
| 81 | + assert all( |
| 82 | + config.virtual_table_eviction_policy is not None |
| 83 | + and isinstance( |
| 84 | + config.virtual_table_eviction_policy, |
| 85 | + FeatureScoreBasedEvictionPolicy, |
| 86 | + ) |
| 87 | + and config.virtual_table_eviction_policy.enable_auto_feature_score_collection |
| 88 | + for config in tables_with_feature_score_policy |
| 89 | + ), "If any virtual table has enable_auto_feature_score_collection=True, all virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True" |
| 90 | + enabled_feature_score_auto_collection = True |
| 91 | + |
| 92 | + sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} |
| 93 | + if enabled_feature_score_auto_collection: |
| 94 | + for ( |
| 95 | + sharding_type, |
| 96 | + sharding_info, |
| 97 | + ) in sharding_type_to_sharding_infos.items(): |
| 98 | + feature_score_mapping: Dict[str, float] = {} |
| 99 | + if sharding_type == ShardingType.DATA_PARALLEL.value: |
| 100 | + sharding_type_feature_score_mapping[sharding_type] = ( |
| 101 | + feature_score_mapping |
| 102 | + ) |
| 103 | + continue |
| 104 | + for config in sharding_info: |
| 105 | + vtep = config.embedding_config.virtual_table_eviction_policy |
| 106 | + if vtep is not None and isinstance( |
| 107 | + vtep, FeatureScoreBasedEvictionPolicy |
| 108 | + ): |
| 109 | + if vtep.eviction_ttl_mins > 0: |
| 110 | + logger.info( |
| 111 | + f"Virtual table eviction policy enabled for table {config.embedding_config.name} {sharding_type} with eviction TTL {vtep.eviction_ttl_mins} mins." |
| 112 | + ) |
| 113 | + feature_score_mapping.update( |
| 114 | + dict.fromkeys(config.embedding_config.feature_names, 0.0) |
| 115 | + ) |
| 116 | + continue |
| 117 | + for k in config.embedding_config.feature_names: |
| 118 | + if ( |
| 119 | + k |
| 120 | + # pyre-ignore [16] |
| 121 | + in config.embedding_config.virtual_table_eviction_policy.feature_score_mapping |
| 122 | + ): |
| 123 | + feature_score_mapping[k] = ( |
| 124 | + config.embedding_config.virtual_table_eviction_policy.feature_score_mapping[ |
| 125 | + k |
| 126 | + ] |
| 127 | + ) |
| 128 | + else: |
| 129 | + assert ( |
| 130 | + # pyre-ignore [16] |
| 131 | + config.embedding_config.virtual_table_eviction_policy.feature_score_default_value |
| 132 | + is not None |
| 133 | + ), f"Table {config.embedding_config.name} eviction policy feature_score_default_value is not set but feature {k} is not in feature_score_mapping." |
| 134 | + feature_score_mapping[k] = ( |
| 135 | + config.embedding_config.virtual_table_eviction_policy.feature_score_default_value |
| 136 | + ) |
| 137 | + sharding_type_feature_score_mapping[sharding_type] = feature_score_mapping |
| 138 | + return ( |
| 139 | + enable_feature_score_weight_accumulation, |
| 140 | + enabled_feature_score_auto_collection, |
| 141 | + sharding_type_feature_score_mapping, |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +@torch.fx.wrap |
| 146 | +def may_collect_feature_scores( |
| 147 | + input_feature_splits: List[KeyedJaggedTensor], |
| 148 | + enabled_feature_score_auto_collection: bool, |
| 149 | + sharding_type_feature_score_mapping: Dict[str, Dict[str, float]], |
| 150 | +) -> List[KeyedJaggedTensor]: |
| 151 | + if not enabled_feature_score_auto_collection: |
| 152 | + return input_feature_splits |
| 153 | + with record_function("## collect_feature_score ##"): |
| 154 | + for features, mapping in zip( |
| 155 | + input_feature_splits, sharding_type_feature_score_mapping.values() |
| 156 | + ): |
| 157 | + assert ( |
| 158 | + features.weights_or_none() is None |
| 159 | + ), f"Auto feature collection: {features.keys()=} has non empty weights" |
| 160 | + if ( |
| 161 | + mapping is None or len(mapping) == 0 |
| 162 | + ): # collection is disabled fir this sharding type |
| 163 | + continue |
| 164 | + feature_score_weights = [] |
| 165 | + device = features.device() |
| 166 | + for f in features.keys(): |
| 167 | + # input dist includes multiple lookups input including both virtual table and non-virtual table features. |
| 168 | + # We needs to attach weights for all features due to KJT weights requirements, so set 0.0 score for non virtual table features |
| 169 | + score = mapping[f] if f in mapping else 0.0 |
| 170 | + feature_score_weights.append( |
| 171 | + torch.ones_like( |
| 172 | + features[f].values(), |
| 173 | + dtype=torch.float32, |
| 174 | + device=device, |
| 175 | + ) |
| 176 | + * score |
| 177 | + ) |
| 178 | + features._weights = ( |
| 179 | + torch.cat(feature_score_weights, dim=0) |
| 180 | + if feature_score_weights |
| 181 | + else None |
| 182 | + ) |
| 183 | + return input_feature_splits |
0 commit comments