corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py (69 additions, 1 deletion)
@@ -45,6 +45,61 @@
)


def _mask_embeddings_by_frequency(
cache: Optional[Cache],
storage: Storage,
unique_keys: torch.Tensor,
unique_embs: torch.Tensor,
frequency_threshold: int,
mask_dims: int,
) -> None:
"""
Mask low-frequency embeddings by setting specific dimensions to zero.

This function queries scores from cache and storage, then masks embeddings
whose scores are below the frequency threshold.

Args:
cache: Optional cache table (can be None if caching is disabled)
storage: Storage table (always present)
unique_keys: Keys to query scores for
unique_embs: Embeddings to mask (modified in-place)
frequency_threshold: Minimum score threshold
mask_dims: Number of dimensions to mask (from the end)
"""
batch = unique_keys.size(0)
if batch == 0:
return
assert hasattr(
storage, "query_scores"
), "If you want to use frequency masking, storage must implement the query_scores method"
# Query scores from cache and storage
if cache is not None:
# 1. Query cache first
cache_scores = cache.query_scores(unique_keys)
cache_founds = cache_scores > 0

# 2. Query storage for cache misses
if (~cache_founds).any():
missing_keys = unique_keys[~cache_founds]
storage_scores = storage.query_scores(missing_keys)
cache_scores[~cache_founds] = storage_scores

scores = cache_scores
else:
# Without cache: query from storage only
scores = storage.query_scores(unique_keys)

# Apply masking
low_freq_mask = scores < frequency_threshold
if low_freq_mask.any():
unique_embs[low_freq_mask, -mask_dims:] = 0.0


# TODO: BatchedDynamicEmbeddingFunction is more concrete.
class DynamicEmbeddingBagFunction(torch.autograd.Function):
@staticmethod
@@ -348,6 +403,8 @@ def forward(
input_dist_dedup: bool = False,
training: bool = True,
frequency_counters: Optional[torch.Tensor] = None,
frequency_threshold: int = 0,
mask_dims: int = 0,
*args,
):
table_num = len(storages)
@@ -426,6 +483,17 @@
lfu_accumulated_frequency_per_table,
)

# Apply frequency-based masking if enabled
if is_lfu_enabled and mask_dims > 0 and frequency_threshold > 0:
_mask_embeddings_by_frequency(
caches[i] if caching else None,
storages[i],
unique_indices_per_table,
unique_embs_per_table,
frequency_threshold,
mask_dims,
)

if training or caching:
output_embs = torch.empty(
indices.shape[0], emb_dim, dtype=output_dtype, device=indices.device
@@ -501,4 +569,4 @@ def backward(ctx, grads):
optimizer,
)

return (None,) * 14
return (None,) * 16
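
A minimal standalone sketch of the masking step introduced above, using small CPU tensors and precomputed scores in place of the cache/storage lookup, to show which rows end up with zeroed trailing dimensions (the names here are illustrative, not part of the library):

import torch

def mask_low_frequency(unique_embs, scores, frequency_threshold, mask_dims):
    # Zero the last `mask_dims` dimensions of every row whose score
    # (lookup frequency) falls below the threshold, mirroring the
    # in-place update in _mask_embeddings_by_frequency.
    low_freq_mask = scores < frequency_threshold
    if low_freq_mask.any():
        unique_embs[low_freq_mask, -mask_dims:] = 0.0
    return unique_embs

embs = torch.ones(3, 8)            # three unique keys, embedding dim 8
scores = torch.tensor([1, 20, 5])  # e.g. LFU counts per key
mask_low_frequency(embs, scores, frequency_threshold=10, mask_dims=4)
# rows 0 and 2 now end in four zeros; row 1 is left untouched
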
corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py (4 additions, 1 deletion)
@@ -483,7 +483,8 @@ def __init__(
self._table_names = table_names
self.bounds_check_mode_int: int = bounds_check_mode.value
self._create_score()

self.frequency_threshold = table_option.frequency_threshold
self.mask_dims = table_option.mask_dims
if device is not None:
self.device_id = int(str(device)[-1])
else:
@@ -984,6 +985,8 @@ def forward(
self.use_index_dedup,
self.training,
per_sample_weights, # Pass frequency counters as weights
self.frequency_threshold,
self.mask_dims,
self._empty_tensor,
)
for cache in self._caches:
corelib/dynamicemb/dynamicemb/dynamicemb_config.py (3 additions, 0 deletions)
@@ -386,6 +386,9 @@ class DynamicEmbTableOptions(_ContextOptions):
global_hbm_for_values: int = 0 # in bytes
external_storage: Storage = None
index_type: Optional[torch.dtype] = None
# Frequency-based masking parameters
frequency_threshold: int = 0 # frequency threshold for masking (0 = disabled)
mask_dims: int = 0 # number of dimensions to mask (0 = disabled)

def __post_init__(self):
assert (
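
A rough sketch of how the two new options are expected to be set, assuming the remaining DynamicEmbTableOptions fields can keep their defaults (the import path and values are illustrative; see the example change at the end of this diff for the settings actually used):

from dynamicemb import DynamicEmbTableOptions

# Zero the last 16 dims of any embedding whose score is below 100.
# Both fields default to 0, which leaves masking disabled.
options = DynamicEmbTableOptions(
    frequency_threshold=100,
    mask_dims=16,
)
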
corelib/dynamicemb/dynamicemb/key_value_table.py (24 additions, 0 deletions)
@@ -40,6 +40,7 @@
erase,
export_batch,
export_batch_matched,
find,
find_pointers,
find_pointers_with_scores,
insert_and_evict,
@@ -347,6 +348,29 @@ def create_scores(
else:
return None

def query_scores(self, unique_keys: torch.Tensor) -> torch.Tensor:
"""Query scores for given keys from the table.

Returns:
scores: Tensor of scores, with 0 for keys not found in table
"""

batch = unique_keys.size(0)
device = unique_keys.device

scores = torch.empty(batch, device=device, dtype=torch.long)
values = torch.empty(
batch, self.value_dim(), device=device, dtype=self.embedding_dtype()
)
founds = torch.empty(batch, device=device, dtype=torch.bool)

find(self.table, batch, unique_keys, values, founds, score=scores)

# Set the score to 0 for keys that were not found
scores[~founds] = 0

return scores

def insert(
self,
unique_keys: torch.Tensor,
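
The intended semantics of query_scores, illustrated with a plain dict standing in for the key-value table (the real method runs find on the device and reads the HKV scores); this is a reference sketch, not library code:

import torch

def query_scores_reference(scores_by_key, unique_keys):
    # Reference behaviour only: return each key's score, and 0 for
    # keys absent from the table, matching query_scores above.
    return torch.tensor(
        [scores_by_key.get(int(k), 0) for k in unique_keys], dtype=torch.long
    )

scores_by_key = {7: 42, 9: 3}            # key -> accumulated score
keys = torch.tensor([7, 8, 9])
print(query_scores_reference(scores_by_key, keys))  # tensor([42,  0,  3])
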
corelib/dynamicemb/example/example.py (3 additions, 1 deletion)
@@ -486,9 +486,11 @@ def get_planner(device, eb_configs, batch_size, optimizer_type, training, caching
initializer_args=DynamicEmbInitializerArgs(
mode=DynamicEmbInitializerMode.NORMAL
),
score_strategy=DynamicEmbScoreStrategy.STEP,
score_strategy=DynamicEmbScoreStrategy.LFU,
caching=caching,
training=training,
frequency_threshold=10,
mask_dims=10,
),
)

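
With the settings above (LFU scores, frequency_threshold=10, mask_dims=10), ids looked up fewer than 10 times should come back with their last 10 embedding dimensions zeroed. A quick way to verify that on a lookup result, where out and cold_rows are placeholders for the looked-up embeddings and the rows expected to fall below the threshold:

import torch

def check_masked(out, cold_rows, mask_dims=10):
    # The trailing mask_dims entries of every low-frequency row
    # should be exactly zero after the masked lookup.
    assert torch.all(out[cold_rows, -mask_dims:] == 0)
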