|  | 
|  | 1 | +#!/usr/bin/env python3 | 
|  | 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 3 | +# All rights reserved. | 
|  | 4 | +# | 
|  | 5 | +# This source code is licensed under the BSD-style license found in the | 
|  | 6 | +# LICENSE file in the root directory of this source tree. | 
|  | 7 | + | 
|  | 8 | +# pyre-strict | 
|  | 9 | +import logging | 
|  | 10 | +from typing import Dict, List, Sequence, Tuple | 
|  | 11 | + | 
|  | 12 | +import torch | 
|  | 13 | + | 
|  | 14 | +from torch.autograd.profiler import record_function | 
|  | 15 | + | 
|  | 16 | +from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo | 
|  | 17 | +from torchrec.distributed.embedding_types import ShardingType | 
|  | 18 | + | 
|  | 19 | +from torchrec.modules.embedding_configs import ( | 
|  | 20 | +    EmbeddingConfig, | 
|  | 21 | +    FeatureScoreBasedEvictionPolicy, | 
|  | 22 | +) | 
|  | 23 | +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor | 
|  | 24 | + | 
|  | 25 | +logger: logging.Logger = logging.getLogger(__name__) | 
|  | 26 | + | 
|  | 27 | + | 
|  | 28 | +def create_sharding_type_to_feature_score_mapping( | 
|  | 29 | +    embedding_configs: Sequence[EmbeddingConfig], | 
|  | 30 | +    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]], | 
|  | 31 | +) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]: | 
|  | 32 | +    enable_feature_score_weight_accumulation = False | 
|  | 33 | +    enabled_feature_score_auto_collection = False | 
|  | 34 | + | 
|  | 35 | +    # Validation for virtual table configurations | 
|  | 36 | +    virtual_tables = [ | 
|  | 37 | +        config for config in embedding_configs if config.use_virtual_table | 
|  | 38 | +    ] | 
|  | 39 | +    if virtual_tables: | 
|  | 40 | +        virtual_tables_with_eviction = [ | 
|  | 41 | +            config | 
|  | 42 | +            for config in virtual_tables | 
|  | 43 | +            if config.virtual_table_eviction_policy is not None | 
|  | 44 | +        ] | 
|  | 45 | +        if virtual_tables_with_eviction: | 
|  | 46 | +            # Check if any virtual table uses FeatureScoreBasedEvictionPolicy | 
|  | 47 | +            tables_with_feature_score_policy = [ | 
|  | 48 | +                config | 
|  | 49 | +                for config in virtual_tables_with_eviction | 
|  | 50 | +                if isinstance( | 
|  | 51 | +                    config.virtual_table_eviction_policy, | 
|  | 52 | +                    FeatureScoreBasedEvictionPolicy, | 
|  | 53 | +                ) | 
|  | 54 | +            ] | 
|  | 55 | + | 
|  | 56 | +            # If any virtual table uses FeatureScoreBasedEvictionPolicy, | 
|  | 57 | +            # then ALL virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy | 
|  | 58 | +            if tables_with_feature_score_policy: | 
|  | 59 | +                assert all( | 
|  | 60 | +                    isinstance( | 
|  | 61 | +                        config.virtual_table_eviction_policy, | 
|  | 62 | +                        FeatureScoreBasedEvictionPolicy, | 
|  | 63 | +                    ) | 
|  | 64 | +                    for config in virtual_tables_with_eviction | 
|  | 65 | +                ), "If any virtual table uses FeatureScoreBasedEvictionPolicy, all virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy" | 
|  | 66 | +                enable_feature_score_weight_accumulation = True | 
|  | 67 | + | 
|  | 68 | +                # Check if any table has enable_auto_feature_score_collection=True | 
|  | 69 | +                tables_with_auto_collection = [ | 
|  | 70 | +                    config | 
|  | 71 | +                    for config in tables_with_feature_score_policy | 
|  | 72 | +                    if config.virtual_table_eviction_policy is not None | 
|  | 73 | +                    and isinstance( | 
|  | 74 | +                        config.virtual_table_eviction_policy, | 
|  | 75 | +                        FeatureScoreBasedEvictionPolicy, | 
|  | 76 | +                    ) | 
|  | 77 | +                    and config.virtual_table_eviction_policy.enable_auto_feature_score_collection | 
|  | 78 | +                ] | 
|  | 79 | +                if tables_with_auto_collection: | 
|  | 80 | +                    # All virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True | 
|  | 81 | +                    assert all( | 
|  | 82 | +                        config.virtual_table_eviction_policy is not None | 
|  | 83 | +                        and isinstance( | 
|  | 84 | +                            config.virtual_table_eviction_policy, | 
|  | 85 | +                            FeatureScoreBasedEvictionPolicy, | 
|  | 86 | +                        ) | 
|  | 87 | +                        and config.virtual_table_eviction_policy.enable_auto_feature_score_collection | 
|  | 88 | +                        for config in tables_with_feature_score_policy | 
|  | 89 | +                    ), "If any virtual table has enable_auto_feature_score_collection=True, all virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True" | 
|  | 90 | +                    enabled_feature_score_auto_collection = True | 
|  | 91 | + | 
|  | 92 | +    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {} | 
|  | 93 | +    if enabled_feature_score_auto_collection: | 
|  | 94 | +        for ( | 
|  | 95 | +            sharding_type, | 
|  | 96 | +            sharding_info, | 
|  | 97 | +        ) in sharding_type_to_sharding_infos.items(): | 
|  | 98 | +            feature_score_mapping: Dict[str, float] = {} | 
|  | 99 | +            if sharding_type == ShardingType.DATA_PARALLEL.value: | 
|  | 100 | +                sharding_type_feature_score_mapping[sharding_type] = ( | 
|  | 101 | +                    feature_score_mapping | 
|  | 102 | +                ) | 
|  | 103 | +                continue | 
|  | 104 | +            for config in sharding_info: | 
|  | 105 | +                vtep = config.embedding_config.virtual_table_eviction_policy | 
|  | 106 | +                if vtep is not None and isinstance( | 
|  | 107 | +                    vtep, FeatureScoreBasedEvictionPolicy | 
|  | 108 | +                ): | 
|  | 109 | +                    if vtep.eviction_ttl_mins > 0: | 
|  | 110 | +                        logger.info( | 
|  | 111 | +                            f"Virtual table eviction policy enabled for table {config.embedding_config.name} {sharding_type} with eviction TTL {vtep.eviction_ttl_mins} mins." | 
|  | 112 | +                        ) | 
|  | 113 | +                        feature_score_mapping.update( | 
|  | 114 | +                            dict.fromkeys(config.embedding_config.feature_names, 0.0) | 
|  | 115 | +                        ) | 
|  | 116 | +                        continue | 
|  | 117 | +                    for k in config.embedding_config.feature_names: | 
|  | 118 | +                        if ( | 
|  | 119 | +                            k | 
|  | 120 | +                            # pyre-ignore [16] | 
|  | 121 | +                            in config.embedding_config.virtual_table_eviction_policy.feature_score_mapping | 
|  | 122 | +                        ): | 
|  | 123 | +                            feature_score_mapping[k] = ( | 
|  | 124 | +                                config.embedding_config.virtual_table_eviction_policy.feature_score_mapping[ | 
|  | 125 | +                                    k | 
|  | 126 | +                                ] | 
|  | 127 | +                            ) | 
|  | 128 | +                        else: | 
|  | 129 | +                            assert ( | 
|  | 130 | +                                # pyre-ignore [16] | 
|  | 131 | +                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value | 
|  | 132 | +                                is not None | 
|  | 133 | +                            ), f"Table {config.embedding_config.name} eviction policy feature_score_default_value is not set but feature {k} is not in feature_score_mapping." | 
|  | 134 | +                            feature_score_mapping[k] = ( | 
|  | 135 | +                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value | 
|  | 136 | +                            ) | 
|  | 137 | +            sharding_type_feature_score_mapping[sharding_type] = feature_score_mapping | 
|  | 138 | +    return ( | 
|  | 139 | +        enable_feature_score_weight_accumulation, | 
|  | 140 | +        enabled_feature_score_auto_collection, | 
|  | 141 | +        sharding_type_feature_score_mapping, | 
|  | 142 | +    ) | 
|  | 143 | + | 
|  | 144 | + | 
|  | 145 | +@torch.fx.wrap | 
|  | 146 | +def may_collect_feature_scores( | 
|  | 147 | +    input_feature_splits: List[KeyedJaggedTensor], | 
|  | 148 | +    enabled_feature_score_auto_collection: bool, | 
|  | 149 | +    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]], | 
|  | 150 | +) -> List[KeyedJaggedTensor]: | 
|  | 151 | +    if not enabled_feature_score_auto_collection: | 
|  | 152 | +        return input_feature_splits | 
|  | 153 | +    with record_function("## collect_feature_score ##"): | 
|  | 154 | +        for features, mapping in zip( | 
|  | 155 | +            input_feature_splits, sharding_type_feature_score_mapping.values() | 
|  | 156 | +        ): | 
|  | 157 | +            assert ( | 
|  | 158 | +                features.weights_or_none() is None | 
|  | 159 | +            ), f"Auto feature collection: {features.keys()=} has non empty weights" | 
|  | 160 | +            if ( | 
|  | 161 | +                mapping is None or len(mapping) == 0 | 
|  | 162 | +            ):  # collection is disabled fir this sharding type | 
|  | 163 | +                continue | 
|  | 164 | +            feature_score_weights = [] | 
|  | 165 | +            device = features.device() | 
|  | 166 | +            for f in features.keys(): | 
|  | 167 | +                # input dist includes multiple lookups input including both virtual table and non-virtual table features. | 
|  | 168 | +                # We needs to attach weights for all features due to KJT weights requirements, so set 0.0 score for non virtual table features | 
|  | 169 | +                score = mapping[f] if f in mapping else 0.0 | 
|  | 170 | +                feature_score_weights.append( | 
|  | 171 | +                    torch.ones_like( | 
|  | 172 | +                        features[f].values(), | 
|  | 173 | +                        dtype=torch.float32, | 
|  | 174 | +                        device=device, | 
|  | 175 | +                    ) | 
|  | 176 | +                    * score | 
|  | 177 | +                ) | 
|  | 178 | +            features._weights = ( | 
|  | 179 | +                torch.cat(feature_score_weights, dim=0) | 
|  | 180 | +                if feature_score_weights | 
|  | 181 | +                else None | 
|  | 182 | +            ) | 
|  | 183 | +    return input_feature_splits | 
0 commit comments