|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | + |
| 9 | +# pyre-strict |
| 10 | +from bisect import bisect_left |
| 11 | +from typing import Dict, List, Optional |
| 12 | + |
| 13 | +import torch |
| 14 | +from torchrec.distributed.model_tracker.types import ( |
| 15 | + DeltaRows, |
| 16 | + EmbdUpdateMode, |
| 17 | + IndexedLookup, |
| 18 | +) |
| 19 | +from torchrec.distributed.utils import none_throws |
| 20 | + |
| 21 | + |
| 22 | +def _compute_unique_rows( |
| 23 | + ids: List[torch.Tensor], |
| 24 | + embeddings: Optional[List[torch.Tensor]], |
| 25 | + mode: EmbdUpdateMode, |
| 26 | +) -> DeltaRows: |
| 27 | + r""" |
| 28 | + To calculate unique ids and embeddings |
| 29 | + """ |
| 30 | + if mode == EmbdUpdateMode.NONE: |
| 31 | + assert ( |
| 32 | + embeddings is None |
| 33 | + ), f"{mode=} == EmbdUpdateMode.NONE but received embeddings" |
| 34 | + unique_ids = torch.cat(ids).unique(return_inverse=False) |
| 35 | + return DeltaRows(ids=unique_ids, embeddings=None) |
| 36 | + else: |
| 37 | + assert ( |
| 38 | + embeddings is not None |
| 39 | + ), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings" |
| 40 | + |
| 41 | + cat_ids = torch.cat(ids) |
| 42 | + cat_embeddings = torch.cat(embeddings) |
| 43 | + |
| 44 | + if mode == EmbdUpdateMode.LAST: |
| 45 | + cat_ids = cat_ids.flip(dims=[0]) |
| 46 | + cat_embeddings = cat_embeddings.flip(dims=[0]) |
| 47 | + |
| 48 | + # Get unique ids and inverse mapping (each element's index in unique_ids). |
| 49 | + unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True) |
| 50 | + |
| 51 | + # Create a tensor of original indices. This will be used to find first occurrences of ids. |
| 52 | + all_indices = torch.arange(cat_ids.size(0), device=cat_ids.device) |
| 53 | + |
| 54 | + # Initialize tensor for first occurrence indices (filled with a high value). |
| 55 | + first_occurrence = torch.full( |
| 56 | + (unique_ids.size(0),), |
| 57 | + cat_ids.size(0), |
| 58 | + dtype=torch.int64, |
| 59 | + device=cat_ids.device, |
| 60 | + ) |
| 61 | + |
| 62 | + # Scatter indices using inverse mapping and reduce with "amin" to get first or last (if reversed) occurrence per unique id. |
| 63 | + first_occurrence = first_occurrence.scatter_reduce( |
| 64 | + 0, inverse, all_indices, reduce="amin" |
| 65 | + ) |
| 66 | + |
| 67 | + # Use first occurrence indices to select corresponding embedding row. |
| 68 | + unique_embedings = cat_embeddings[first_occurrence] |
| 69 | + return DeltaRows(ids=unique_ids, embeddings=unique_embedings) |
| 70 | + |
| 71 | + |
| 72 | +class DeltaStore: |
| 73 | + """ |
| 74 | + DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across |
| 75 | + various batches during training, designed to be used with TorchRecs ModelDeltaTracker. |
| 76 | + It maintains a CUDA in-memory representation of requested ids and embeddings/states, |
| 77 | + providing a way to compact and get delta updates for each embedding table. |
| 78 | +
|
| 79 | + The class supports different embedding update modes (NONE, FIRST, LAST) to determine |
| 80 | + how to handle duplicate ids when compacting or retrieving embeddings. |
| 81 | +
|
| 82 | + """ |
| 83 | + |
| 84 | + def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: |
| 85 | + self.embdUpdateMode = embdUpdateMode |
| 86 | + self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} |
| 87 | + |
| 88 | + def append( |
| 89 | + self, |
| 90 | + batch_idx: int, |
| 91 | + table_fqn: str, |
| 92 | + ids: torch.Tensor, |
| 93 | + embeddings: Optional[torch.Tensor], |
| 94 | + ) -> None: |
| 95 | + table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, []) |
| 96 | + table_fqn_lookup.append( |
| 97 | + IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings) |
| 98 | + ) |
| 99 | + self.per_fqn_lookups[table_fqn] = table_fqn_lookup |
| 100 | + |
| 101 | + def delete(self, up_to_idx: Optional[int] = None) -> None: |
| 102 | + """ |
| 103 | + Delete all idx from the store up to `up_to_idx` |
| 104 | + """ |
| 105 | + if up_to_idx is None: |
| 106 | + # If up_to_idx is None, delete all lookups |
| 107 | + self.per_fqn_lookups = {} |
| 108 | + else: |
| 109 | + # lookups are sorted by idx. |
| 110 | + up_to_idx = none_throws(up_to_idx) |
| 111 | + for table_fqn, lookups in self.per_fqn_lookups.items(): |
| 112 | + # remove all lookups up to up_to_idx |
| 113 | + self.per_fqn_lookups[table_fqn] = [ |
| 114 | + lookup for lookup in lookups if lookup.batch_idx >= up_to_idx |
| 115 | + ] |
| 116 | + |
| 117 | + def compact(self, start_idx: int, end_idx: int) -> None: |
| 118 | + r""" |
| 119 | + Compact (ids, embeddings) in batch index range from start_idx, curr_batch_idx. |
| 120 | + """ |
| 121 | + assert ( |
| 122 | + start_idx < end_idx |
| 123 | + ), f"start_idx {start_idx} must be smaller then end_idx, but got {end_idx}" |
| 124 | + |
| 125 | + new_per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} |
| 126 | + for table_fqn, lookups in self.per_fqn_lookups.items(): |
| 127 | + indexices = [h.batch_idx for h in lookups] |
| 128 | + index_l = bisect_left(indexices, start_idx) |
| 129 | + index_r = bisect_left(indexices, end_idx) |
| 130 | + lookups_to_compact = lookups[index_l:index_r] |
| 131 | + if len(lookups_to_compact) <= 1: |
| 132 | + new_per_fqn_lookups[table_fqn] = lookups |
| 133 | + continue |
| 134 | + ids = [lookup.ids for lookup in lookups_to_compact] |
| 135 | + embeddings = ( |
| 136 | + [none_throws(lookup.embeddings) for lookup in lookups_to_compact] |
| 137 | + if self.embdUpdateMode != EmbdUpdateMode.NONE |
| 138 | + else None |
| 139 | + ) |
| 140 | + delta_rows = _compute_unique_rows( |
| 141 | + ids=ids, embeddings=embeddings, mode=self.embdUpdateMode |
| 142 | + ) |
| 143 | + new_per_fqn_lookups[table_fqn] = ( |
| 144 | + lookups[:index_l] |
| 145 | + + [ |
| 146 | + IndexedLookup( |
| 147 | + batch_idx=start_idx, |
| 148 | + ids=delta_rows.ids, |
| 149 | + embeddings=delta_rows.embeddings, |
| 150 | + ) |
| 151 | + ] |
| 152 | + + lookups[index_r:] |
| 153 | + ) |
| 154 | + self.per_fqn_lookups = new_per_fqn_lookups |
| 155 | + |
| 156 | + def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: |
| 157 | + r""" |
| 158 | + Return all unique/delta ids per table from the Delta Store. |
| 159 | + """ |
| 160 | + |
| 161 | + delta_per_table_fqn: Dict[str, DeltaRows] = {} |
| 162 | + for table_fqn, lookups in self.per_fqn_lookups.items(): |
| 163 | + compact_ids = [ |
| 164 | + lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx |
| 165 | + ] |
| 166 | + compact_embeddings = ( |
| 167 | + [ |
| 168 | + none_throws(lookup.embeddings) |
| 169 | + for lookup in lookups |
| 170 | + if lookup.batch_idx >= from_idx |
| 171 | + ] |
| 172 | + if self.embdUpdateMode != EmbdUpdateMode.NONE |
| 173 | + else None |
| 174 | + ) |
| 175 | + |
| 176 | + delta_per_table_fqn[table_fqn] = _compute_unique_rows( |
| 177 | + ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode |
| 178 | + ) |
| 179 | + return delta_per_table_fqn |
0 commit comments