
Commit 2c4025d

emlin authored and facebook-github-bot committed
add auto feature score collection to EC (#3474)
Summary:
X-link: pytorch/FBGEMM#5030
X-link: facebookresearch/FBGEMM#2043

Enable automatic feature score collection in ShardedEmbeddingCollection based on a static feature-to-score mapping. Users who need a custom score for specific ids can disable auto collection and change their model code to collect a score for each id explicitly.

Here is a sample eviction policy config in the embedding_table config that enables auto score collection:

virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
    training_id_eviction_trigger_count=260_000_000,  # 260M
    training_id_keep_count=160_000_000,  # 160M
    enable_auto_feature_score_collection=True,
    feature_score_mapping={
        "sparse_public_original_content_creator": 1.0,
    },
    feature_score_default_value=0.5,
),

Additionally, the counter previously collected during EC dedup is not used by the kvzch backend, so this diff removes that counter and lets the KJT transfer a single float32 weight tensor to the backend. This also enables feature score collection for EBC, since EBC may already carry another float weight for pooling.

Reviewed By: EddyLXJ

Differential Revision: D83945722
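For orientation, a minimal sketch of how that policy might be attached to a virtual-table config. Only the policy fields come from the summary above; the table name, embedding_dim, and num_embeddings are placeholders:

from torchrec.modules.embedding_configs import (
    EmbeddingConfig,
    FeatureScoreBasedEvictionPolicy,
)

# Hypothetical table config; name, dimensions, and row count are illustrative.
creator_table = EmbeddingConfig(
    name="creator_table",
    embedding_dim=128,
    num_embeddings=1_000_000_000,
    feature_names=["sparse_public_original_content_creator"],
    use_virtual_table=True,
    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),
)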
1 parent acdecd7 commit 2c4025d

File tree

5 files changed: +852 −20 lines changed

torchrec/distributed/embedding.py

Lines changed: 33 additions & 20 deletions
@@ -46,6 +46,11 @@
     ShardedEmbeddingModule,
     ShardingType,
 )
+
+from torchrec.distributed.feature_score_utils import (
+    create_sharding_type_to_feature_score_mapping,
+    may_collect_feature_scores,
+)
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
@@ -90,7 +95,6 @@
 from torchrec.modules.embedding_configs import (
     EmbeddingConfig,
     EmbeddingTableConfig,
-    FeatureScoreBasedEvictionPolicy,
     PoolingType,
 )
 from torchrec.modules.embedding_modules import (
@@ -463,12 +467,12 @@ def __init__(
         ] = {
             sharding_type: self.create_embedding_sharding(
                 sharding_type=sharding_type,
-                sharding_infos=embedding_confings,
+                sharding_infos=embedding_configs,
                 env=env,
                 device=device,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
+            for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items()
         }

         self.enable_embedding_update: bool = any(
@@ -490,16 +494,20 @@ def __init__(
         self._has_uninitialized_input_dist: bool = True
         logger.info(f"EC index dedup enabled: {self._use_index_dedup}.")

-        for config in self._embedding_configs:
-            virtual_table_eviction_policy = config.virtual_table_eviction_policy
-            if virtual_table_eviction_policy is not None and isinstance(
-                virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
-            ):
-                self._enable_feature_score_weight_accumulation = True
-                break
-
+        self._enable_feature_score_weight_accumulation: bool = False
+        self._enabled_feature_score_auto_collection: bool = False
+        self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+        (
+            self._enable_feature_score_weight_accumulation,
+            self._enabled_feature_score_auto_collection,
+            self._sharding_type_feature_score_mapping,
+        ) = create_sharding_type_to_feature_score_mapping(
+            self._embedding_configs, sharding_type_to_sharding_infos
+        )
         logger.info(
-            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
+            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
+            f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
+            f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
         )

         # Get all fused optimizers and combine them.
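For a setup like the sample config in the commit summary, the tuple unpacked above would look roughly like this; the sharding-type key is illustrative and the real keys come from sharding_type_to_sharding_infos:

# Illustrative only; not part of the diff.
expected = (
    True,  # _enable_feature_score_weight_accumulation
    True,  # _enabled_feature_score_auto_collection
    {"row_wise": {"sparse_public_original_content_creator": 1.0}},
)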
@@ -1361,22 +1369,22 @@ def _dedup_indices(
                    source_weights.dtype == torch.float32
                ), "Only float32 weights are supported for feature score eviction weights."

-                acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts(
-                    source_weights.view(-1),
-                    reverse_indices,
+                # Accumulate weights using scatter_add
+                acc_weights = torch.zeros(
                     unique_indices.numel(),
+                    dtype=torch.float32,
+                    device=source_weights.device,
                 )

+                # Use PyTorch's scatter_add to accumulate weights
+                acc_weights.scatter_add_(0, reverse_indices, source_weights)
+
             dedup_features = KeyedJaggedTensor(
                 keys=input_feature.keys(),
                 lengths=lengths,
                 offsets=offsets,
                 values=unique_indices,
-                weights=(
-                    acc_weights.view(torch.float64).view(-1)
-                    if acc_weights is not None
-                    else None
-                ),
+                weights=(acc_weights.view(-1) if acc_weights is not None else None),
             )

             ctx.input_features.append(input_feature)
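Since the fused FBGEMM op is replaced with plain PyTorch above, here is a small self-contained sketch of the scatter_add_ pattern; the ids and weights are toy values, not taken from the diff:

import torch

# Toy ids with duplicates and one float32 feature-score weight per occurrence.
values = torch.tensor([5, 3, 5, 7, 3, 5])
source_weights = torch.tensor([1.0, 0.5, 1.0, 0.5, 0.5, 1.0], dtype=torch.float32)

# Dedup the ids; reverse_indices maps each occurrence back to its unique slot.
unique_indices, reverse_indices = torch.unique(values, return_inverse=True)

# Accumulate weights per unique id, mirroring the scatter_add_ call in _dedup_indices.
acc_weights = torch.zeros(unique_indices.numel(), dtype=torch.float32)
acc_weights.scatter_add_(0, reverse_indices, source_weights)

print(unique_indices)  # tensor([3, 5, 7])
print(acc_weights)     # tensor([1.0000, 3.0000, 0.5000]), summed per unique id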
@@ -1495,6 +1503,11 @@ def input_dist(
                    self._features_order_tensor,
                )
            features_by_shards = features.split(self._feature_splits)
+           features_by_shards = may_collect_feature_scores(
+               features_by_shards,
+               self._enabled_feature_score_auto_collection,
+               self._sharding_type_feature_score_mapping,
+           )
            if self._use_index_dedup:
                features_by_shards = self._dedup_indices(ctx, features_by_shards)

torchrec/distributed/feature_score_utils.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import logging
+from typing import Dict, List, Sequence, Tuple
+
+import torch
+
+from torch.autograd.profiler import record_function
+
+from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
+from torchrec.distributed.embedding_types import ShardingType
+
+from torchrec.modules.embedding_configs import (
+    EmbeddingConfig,
+    FeatureScoreBasedEvictionPolicy,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def create_sharding_type_to_feature_score_mapping(
+    embedding_configs: Sequence[EmbeddingConfig],
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]],
+) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]:
+    enable_feature_score_weight_accumulation = False
+    enabled_feature_score_auto_collection = False
+
+    # Validation for virtual table configurations
+    virtual_tables = [
+        config for config in embedding_configs if config.use_virtual_table
+    ]
+    if virtual_tables:
+        virtual_tables_with_eviction = [
+            config
+            for config in virtual_tables
+            if config.virtual_table_eviction_policy is not None
+        ]
+        if virtual_tables_with_eviction:
+            # Check if any virtual table uses FeatureScoreBasedEvictionPolicy
+            tables_with_feature_score_policy = [
+                config
+                for config in virtual_tables_with_eviction
+                if isinstance(
+                    config.virtual_table_eviction_policy,
+                    FeatureScoreBasedEvictionPolicy,
+                )
+            ]
+
+            # If any virtual table uses FeatureScoreBasedEvictionPolicy,
+            # then ALL virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy
+            if tables_with_feature_score_policy:
+                assert all(
+                    isinstance(
+                        config.virtual_table_eviction_policy,
+                        FeatureScoreBasedEvictionPolicy,
+                    )
+                    for config in virtual_tables_with_eviction
+                ), "If any virtual table uses FeatureScoreBasedEvictionPolicy, all virtual tables with eviction policies must use FeatureScoreBasedEvictionPolicy"
+                enable_feature_score_weight_accumulation = True
+
+                # Check if any table has enable_auto_feature_score_collection=True
+                tables_with_auto_collection = [
+                    config
+                    for config in tables_with_feature_score_policy
+                    if config.virtual_table_eviction_policy is not None
+                    and isinstance(
+                        config.virtual_table_eviction_policy,
+                        FeatureScoreBasedEvictionPolicy,
+                    )
+                    and config.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                ]
+                if tables_with_auto_collection:
+                    # All virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True
+                    assert all(
+                        config.virtual_table_eviction_policy is not None
+                        and isinstance(
+                            config.virtual_table_eviction_policy,
+                            FeatureScoreBasedEvictionPolicy,
+                        )
+                        and config.virtual_table_eviction_policy.enable_auto_feature_score_collection
+                        for config in tables_with_feature_score_policy
+                    ), "If any virtual table has enable_auto_feature_score_collection=True, all virtual tables with FeatureScoreBasedEvictionPolicy must have enable_auto_feature_score_collection=True"
+                    enabled_feature_score_auto_collection = True
+
+    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
+    if enabled_feature_score_auto_collection:
+        for (
+            sharding_type,
+            sharding_info,
+        ) in sharding_type_to_sharding_infos.items():
+            feature_score_mapping: Dict[str, float] = {}
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
+                sharding_type_feature_score_mapping[sharding_type] = (
+                    feature_score_mapping
+                )
+                continue
+            for config in sharding_info:
+                vtep = config.embedding_config.virtual_table_eviction_policy
+                if vtep is not None and isinstance(
+                    vtep, FeatureScoreBasedEvictionPolicy
+                ):
+                    if vtep.eviction_ttl_mins > 0:
+                        logger.info(
+                            f"Virtual table eviction policy enabled for table {config.embedding_config.name} {sharding_type} with eviction TTL {vtep.eviction_ttl_mins} mins."
+                        )
+                        feature_score_mapping.update(
+                            dict.fromkeys(config.embedding_config.feature_names, 0.0)
+                        )
+                        continue
+                    for k in config.embedding_config.feature_names:
+                        if (
+                            k
+                            # pyre-ignore [16]
+                            in config.embedding_config.virtual_table_eviction_policy.feature_score_mapping
+                        ):
+                            feature_score_mapping[k] = (
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_mapping[
+                                    k
+                                ]
+                            )
+                        else:
+                            assert (
+                                # pyre-ignore [16]
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value
+                                is not None
+                            ), f"Table {config.embedding_config.name} eviction policy feature_score_default_value is not set but feature {k} is not in feature_score_mapping."
+                            feature_score_mapping[k] = (
+                                config.embedding_config.virtual_table_eviction_policy.feature_score_default_value
+                            )
+            sharding_type_feature_score_mapping[sharding_type] = feature_score_mapping
+    return (
+        enable_feature_score_weight_accumulation,
+        enabled_feature_score_auto_collection,
+        sharding_type_feature_score_mapping,
+    )
+
+
+@torch.fx.wrap
+def may_collect_feature_scores(
+    input_feature_splits: List[KeyedJaggedTensor],
+    enabled_feature_score_auto_collection: bool,
+    sharding_type_feature_score_mapping: Dict[str, Dict[str, float]],
+) -> List[KeyedJaggedTensor]:
+    if not enabled_feature_score_auto_collection:
+        return input_feature_splits
+    with record_function("## collect_feature_score ##"):
+        for features, mapping in zip(
+            input_feature_splits, sharding_type_feature_score_mapping.values()
+        ):
+            assert (
+                features.weights_or_none() is None
+            ), f"Auto feature collection: {features.keys()=} has non empty weights"
+            if (
+                mapping is None or len(mapping) == 0
+            ):  # collection is disabled for this sharding type
+                continue
+            feature_score_weights = []
+            device = features.device()
+            for f in features.keys():
+                # Input dist includes inputs for multiple lookups, covering both virtual table and non-virtual table features.
+                # We need to attach weights for all features due to KJT weights requirements, so set a 0.0 score for non-virtual-table features.
+                score = mapping[f] if f in mapping else 0.0
+                feature_score_weights.append(
+                    torch.ones_like(
+                        features[f].values(),
+                        dtype=torch.float32,
+                        device=device,
+                    )
+                    * score
+                )
+            features._weights = (
+                torch.cat(feature_score_weights, dim=0)
+                if feature_score_weights
+                else None
+            )
+    return input_feature_splits
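A minimal usage sketch of may_collect_feature_scores, assuming the module path shown in the embedding.py import above; the feature names, id values, and the "row_wise" key are made up for illustration. "f1" gets its mapped score of 1.0 per id, while "f2", absent from the mapping, gets the 0.0 filler used for non-virtual-table features:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.feature_score_utils import may_collect_feature_scores

# One KJT per sharding type; "f1" has 3 ids, "f2" has 2 ids, and no weights yet.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 20, 21]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

out = may_collect_feature_scores(
    [kjt],
    enabled_feature_score_auto_collection=True,
    sharding_type_feature_score_mapping={"row_wise": {"f1": 1.0}},
)
print(out[0].weights())  # tensor([1., 1., 1., 0., 0.])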

0 commit comments
