Commit b0dffd3

emlin authored and meta-codesync[bot] committed
add auto feature score collection to EC (#5030)
Summary:
Pull Request resolved: #5030

X-link: meta-pytorch/torchrec#3474
X-link: https://github.com/facebookresearch/FBGEMM/pull/2043

Enable automatic feature score collection in ShardedEmbeddingCollection, based on a static feature-to-score mapping. If a user needs a custom score for specific ids, they can disable auto collection and change the model code to collect the score for each id explicitly.

Here is a sample eviction policy config, set in the embedding table config, that enables auto score collection:

    virtual_table_eviction_policy=FeatureScoreBasedEvictionPolicy(
        training_id_eviction_trigger_count=260_000_000,  # 260M
        training_id_keep_count=160_000_000,  # 160M
        enable_auto_feature_score_collection=True,
        feature_score_mapping={
            "sparse_public_original_content_creator": 1.0,
        },
        feature_score_default_value=0.5,
    ),

Additionally, the counter previously collected during EC dedup is not used by the kvzch backend, so this diff removes that counter and allows KJT to transfer a single float32 weight tensor to the backend. This makes feature score collection possible for EBC, since EBC may already carry another float weight for pooling.

Reviewed By: RachelZheng, EddyLXJ

Differential Revision: D83945722

fbshipit-source-id: 2dc71f6601de055b982f62ca3d73cdbe5fba2dce
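To make the static mapping concrete, here is a minimal Python sketch of how one score per id could be derived from a feature-to-score mapping with a default fallback. The helper `auto_feature_scores` and its arguments are hypothetical and written only for illustration; they are not the TorchRec or FBGEMM API, and the actual collection logic in ShardedEmbeddingCollection is not shown on this page.

```python
from typing import Dict, List


def auto_feature_scores(
    feature_names: List[str],
    ids_per_feature: List[List[int]],
    feature_score_mapping: Dict[str, float],
    feature_score_default_value: float,
) -> List[float]:
    """Hypothetical helper: return one score per id, chosen by feature name."""
    scores: List[float] = []
    for name, ids in zip(feature_names, ids_per_feature):
        # Features missing from the mapping fall back to the default value.
        score = feature_score_mapping.get(name, feature_score_default_value)
        scores.extend([score] * len(ids))
    return scores


# Values mirror the sample config above; "some_other_feature" is made up.
scores = auto_feature_scores(
    feature_names=["sparse_public_original_content_creator", "some_other_feature"],
    ids_per_feature=[[11, 12], [21, 22, 23]],
    feature_score_mapping={"sparse_public_original_content_creator": 1.0},
    feature_score_default_value=0.5,
)
print(scores)  # [1.0, 1.0, 0.5, 0.5, 0.5]
```

Under this reading, every id of a mapped feature shares that feature's score, which is why a user who needs per-id scores would have to disable auto collection.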
1 parent 9d38b8d commit b0dffd3

2 files changed: 2 additions and 4 deletions

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 1 deletion
@@ -2089,7 +2089,7 @@ def _prefetch( # noqa C901
                     torch.tensor(
                         [weights.shape[0]], device="cpu", dtype=torch.long
                     ),
-                    weights.cpu().view(torch.float32).view(-1, 2),
+                    weights.cpu(),
                 )
 
         # Generate row addresses (pointing to either L1 or the current

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 1 addition & 3 deletions
@@ -770,7 +770,6 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
           CHECK_EQ(indices.size(0), engege_rates.size(0));
           auto indices_data_ptr = indices.data_ptr<index_t>();
           auto engage_rate_ptr = engege_rates.data_ptr<float>();
-          int64_t stride = 2;
           {
             auto before_write_lock_ts =
                 facebook::WallClockUtil::NowInUsecFast();
@@ -785,8 +784,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
                  index_iter++) {
               const auto& id_index = *index_iter;
               auto id = int64_t(indices_data_ptr[id_index]);
-              float engege_rate =
-                  float(engage_rate_ptr[id_index * stride + 0]);
+              float engege_rate = float(engage_rate_ptr[id_index]);
               // use mempool
               weight_type* block = nullptr;
               auto before_lookup_cache_ts =
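
The two hunks above change how the per-id score is addressed in the weight buffer. The following Python sketch models the old and new indexing with plain lists. It assumes the old buffer interleaved two float32 slots per id, with the score in slot 0 and the now-removed dedup counter in slot 1; the counter's position is an inference from the commit message, and `read_scores_old` and `read_scores_new` are hypothetical names, not functions in FBGEMM.

```python
from typing import List


def read_scores_old(buf: List[float], id_indices: List[int]) -> List[float]:
    """Old layout: two floats per id, score read at id_index * stride + 0."""
    stride = 2
    return [buf[i * stride + 0] for i in id_indices]


def read_scores_new(buf: List[float], id_indices: List[int]) -> List[float]:
    """New layout: one float per id, score read at id_index."""
    return [buf[i] for i in id_indices]


scores = [1.0, 0.5, 0.5]
counters = [3.0, 1.0, 7.0]  # assumed content of the second slot (illustrative values)

# Old buffer interleaves (score, counter) pairs; new buffer holds scores only.
old_buf = [v for pair in zip(scores, counters) for v in pair]
new_buf = list(scores)

id_indices = [2, 0, 1]
assert read_scores_old(old_buf, id_indices) == read_scores_new(new_buf, id_indices)
print(read_scores_new(new_buf, id_indices))  # [0.5, 1.0, 0.5]
```

Both readers return the same scores, which illustrates why the Python side and the C++ side have to change together: a flat buffer read with the old stride of 2 would skip every other score and run past the end of the buffer.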

0 commit comments
