Skip to content

Commit 0480a9b

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Delta Store
Summary: # Summary Introducing DeltaStore class which efficiently manages embedding table updates with the following features: * Tracks embedding table updates by table FQN with batch indexing * Supports multiple embedding update modes (NONE, FIRST, LAST) * Provides compaction functionality for calculating unique * Allows retrieval of unique/delta IDs per table with optional embedding values ## How lookups are preserved and fetched? In DeltaStore, lookups are preserved in the `per_fqn_lookups` dictionary, which maps table FQNs to lists of `IndexedLookup` objects. Each `IndexedLookup` contains: 1. `idx`: The batch index 2. `ids`: Tensor of embedding IDs 3. `embeddings`: Optional tensor of embedding values Lookups are added via the `append` method and can be: * Deleted with the `delete` method (up to a specific index) * Compacted with the `compact` method (merges lookups within a range) * Retrieved as unique/delta rows with the `get_delta` method ## This diffs: 1. delta_store.py includes all main logic to preserve, fetch, compact and delete 2. types.py includes required datatypes and enums 3. test_delta_store.py Includes test cases for compute, delete and compact methods Reviewed By: TroyGarden Differential Revision: D71130002
1 parent 558f476 commit 0480a9b

File tree

3 files changed

+1095
-0
lines changed

3 files changed

+1095
-0
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
# pyre-strict
10+
from bisect import bisect_left
11+
from typing import Dict, List, Optional
12+
13+
import torch
14+
from torchrec.distributed.model_tracker.types import (
15+
DeltaRows,
16+
EmbdUpdateMode,
17+
IndexedLookup,
18+
)
19+
from torchrec.distributed.utils import none_throws
20+
21+
22+
def _compute_unique_rows(
23+
ids: List[torch.Tensor],
24+
embeddings: Optional[List[torch.Tensor]],
25+
mode: EmbdUpdateMode,
26+
) -> DeltaRows:
27+
r"""
28+
To calculate unique ids and embeddings
29+
"""
30+
if mode == EmbdUpdateMode.NONE:
31+
assert (
32+
embeddings is None
33+
), f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
34+
unique_ids = torch.cat(ids).unique(return_inverse=False)
35+
return DeltaRows(ids=unique_ids, embeddings=None)
36+
else:
37+
assert (
38+
embeddings is not None
39+
), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"
40+
41+
cat_ids = torch.cat(ids)
42+
cat_embeddings = torch.cat(embeddings)
43+
44+
if mode == EmbdUpdateMode.LAST:
45+
cat_ids = cat_ids.flip(dims=[0])
46+
cat_embeddings = cat_embeddings.flip(dims=[0])
47+
48+
# Get unique ids and inverse mapping (each element's index in unique_ids).
49+
unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True)
50+
51+
# Create a tensor of original indices. This will be used to find first occurrences of ids.
52+
all_indices = torch.arange(cat_ids.size(0), device=cat_ids.device)
53+
54+
# Initialize tensor for first occurrence indices (filled with a high value).
55+
first_occurrence = torch.full(
56+
(unique_ids.size(0),),
57+
cat_ids.size(0),
58+
dtype=torch.int64,
59+
device=cat_ids.device,
60+
)
61+
62+
# Scatter indices using inverse mapping and reduce with "amin" to get first or last (if reversed) occurrence per unique id.
63+
first_occurrence = first_occurrence.scatter_reduce(
64+
0, inverse, all_indices, reduce="amin"
65+
)
66+
67+
# Use first occurrence indices to select corresponding embedding row.
68+
unique_embedings = cat_embeddings[first_occurrence]
69+
return DeltaRows(ids=unique_ids, embeddings=unique_embedings)
70+
71+
72+
class DeltaStore:
73+
"""
74+
DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
75+
various batches during training, designed to be used with TorchRecs ModelDeltaTracker.
76+
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
77+
providing a way to compact and get delta updates for each embedding table.
78+
79+
The class supports different embedding update modes (NONE, FIRST, LAST) to determine
80+
how to handle duplicate ids when compacting or retrieving embeddings.
81+
82+
"""
83+
84+
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
85+
self.embdUpdateMode = embdUpdateMode
86+
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
87+
88+
def append(
89+
self,
90+
batch_idx: int,
91+
table_fqn: str,
92+
ids: torch.Tensor,
93+
embeddings: Optional[torch.Tensor],
94+
) -> None:
95+
table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
96+
table_fqn_lookup.append(
97+
IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings)
98+
)
99+
self.per_fqn_lookups[table_fqn] = table_fqn_lookup
100+
101+
def delete(self, up_to_idx: Optional[int] = None) -> None:
102+
"""
103+
Delete all idx from the store up to `up_to_idx`
104+
"""
105+
if up_to_idx is None:
106+
# If up_to_idx is None, delete all lookups
107+
self.per_fqn_lookups = {}
108+
else:
109+
# lookups are sorted by idx.
110+
up_to_idx = none_throws(up_to_idx)
111+
for table_fqn, lookups in self.per_fqn_lookups.items():
112+
# remove all lookups up to up_to_idx
113+
self.per_fqn_lookups[table_fqn] = [
114+
lookup for lookup in lookups if lookup.batch_idx >= up_to_idx
115+
]
116+
117+
def compact(self, start_idx: int, end_idx: int) -> None:
118+
r"""
119+
Compact (ids, embeddings) in batch index range from start_idx, curr_batch_idx.
120+
"""
121+
assert (
122+
start_idx < end_idx
123+
), f"start_idx {start_idx} must be smaller then end_idx, but got {end_idx}"
124+
125+
new_per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
126+
for table_fqn, lookups in self.per_fqn_lookups.items():
127+
indexices = [h.batch_idx for h in lookups]
128+
index_l = bisect_left(indexices, start_idx)
129+
index_r = bisect_left(indexices, end_idx)
130+
lookups_to_compact = lookups[index_l:index_r]
131+
if len(lookups_to_compact) <= 1:
132+
new_per_fqn_lookups[table_fqn] = lookups
133+
continue
134+
ids = [lookup.ids for lookup in lookups_to_compact]
135+
embeddings = (
136+
[none_throws(lookup.embeddings) for lookup in lookups_to_compact]
137+
if self.embdUpdateMode != EmbdUpdateMode.NONE
138+
else None
139+
)
140+
delta_rows = _compute_unique_rows(
141+
ids=ids, embeddings=embeddings, mode=self.embdUpdateMode
142+
)
143+
new_per_fqn_lookups[table_fqn] = (
144+
lookups[:index_l]
145+
+ [
146+
IndexedLookup(
147+
batch_idx=start_idx,
148+
ids=delta_rows.ids,
149+
embeddings=delta_rows.embeddings,
150+
)
151+
]
152+
+ lookups[index_r:]
153+
)
154+
self.per_fqn_lookups = new_per_fqn_lookups
155+
156+
def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
157+
r"""
158+
Return all unique/delta ids per table from the Delta Store.
159+
"""
160+
161+
delta_per_table_fqn: Dict[str, DeltaRows] = {}
162+
for table_fqn, lookups in self.per_fqn_lookups.items():
163+
compact_ids = [
164+
lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx
165+
]
166+
compact_embeddings = (
167+
[
168+
none_throws(lookup.embeddings)
169+
for lookup in lookups
170+
if lookup.batch_idx >= from_idx
171+
]
172+
if self.embdUpdateMode != EmbdUpdateMode.NONE
173+
else None
174+
)
175+
176+
delta_per_table_fqn[table_fqn] = _compute_unique_rows(
177+
ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode
178+
)
179+
return delta_per_table_fqn

0 commit comments

Comments
 (0)