Skip to content

Commit 4838bfa

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Free mem trigger with all2all for sync trigger eviction (meta-pytorch#3442)
Summary: Before KVZCH is using ID_COUNT and MEM_UTIL eviction trigger mode, both are very tricky and hard for model engineer to decide what num to use for the id count or mem util threshold. Besides that, the eviction start time is out of sync after some time in training, which can cause great qps drop during eviction. This diff is adding support for free memory trigger eviction. It will check how many free memory left every N batch in every rank and if free memory below the threshold, it will trigger eviction in all tbes of all ranks using all reduce. In this way, we can force the start time of eviction in all ranks. Reviewed By: emlin Differential Revision: D83896528
1 parent 3a6cf2e commit 4838bfa

File tree

2 files changed

+49
-11
lines changed

2 files changed

+49
-11
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
242242
)
243243
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
244244

245-
if "kvzch_eviction_trigger_mode" in fused_params and config.is_using_virtual_table:
246-
ssd_tbe_params["kvzch_eviction_trigger_mode"] = fused_params.get(
247-
"kvzch_eviction_trigger_mode"
245+
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table:
246+
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
247+
"kvzch_eviction_tbe_config"
248248
)
249249

250250
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
@@ -337,11 +337,40 @@ def _populate_zero_collision_tbe_params(
337337
eviction_strategy = -1
338338
table_names = [table.name for table in config.embedding_tables]
339339
l2_cache_size = tbe_params["l2_cache_size"]
340-
if "kvzch_eviction_trigger_mode" in tbe_params:
341-
eviction_trigger_mode = tbe_params["kvzch_eviction_trigger_mode"]
342-
tbe_params.pop("kvzch_eviction_trigger_mode")
343-
else:
344-
eviction_trigger_mode = 2 # 2 means mem_util based eviction
340+
341+
# Eviction tbe config default values
342+
eviction_trigger_mode = 2 # 2 means mem_util based eviction
343+
eviction_free_mem_threshold_gb = (
344+
200 # Eviction free memory trigger threshold in GB
345+
)
346+
eviction_free_mem_check_interval_batch = (
347+
1000
348+
) # how many batchs to check free memory when trigger model is free_mem
349+
threshold_calculation_bucket_stride = 0.2
350+
threshold_calculation_bucket_num = 1000000 # 1M
351+
if "kvzch_eviction_tbe_config" in tbe_params:
352+
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
353+
tbe_params.pop("kvzch_eviction_tbe_config")
354+
355+
if eviction_tbe_config.kvzch_eviction_trigger_mode is not None:
356+
eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
357+
if eviction_tbe_config.eviction_free_mem_threshold_gb is not None:
358+
eviction_free_mem_threshold_gb = (
359+
eviction_tbe_config.eviction_free_mem_threshold_gb
360+
)
361+
if eviction_tbe_config.eviction_free_mem_check_interval_batch is not None:
362+
eviction_free_mem_check_interval_batch = (
363+
eviction_tbe_config.eviction_free_mem_check_interval_batch
364+
)
365+
if eviction_tbe_config.threshold_calculation_bucket_stride is not None:
366+
threshold_calculation_bucket_stride = (
367+
eviction_tbe_config.threshold_calculation_bucket_stride
368+
)
369+
if eviction_tbe_config.threshold_calculation_bucket_num is not None:
370+
threshold_calculation_bucket_num = (
371+
eviction_tbe_config.threshold_calculation_bucket_num
372+
)
373+
345374
for i, table in enumerate(config.embedding_tables):
346375
policy_t = table.virtual_table_eviction_policy
347376
if policy_t is not None:
@@ -421,6 +450,10 @@ def _populate_zero_collision_tbe_params(
421450
training_id_keep_count=training_id_keep_count,
422451
l2_weight_thresholds=l2_weight_thresholds,
423452
meta_header_lens=meta_header_lens,
453+
eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb,
454+
eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch,
455+
threshold_calculation_bucket_stride=threshold_calculation_bucket_stride,
456+
threshold_calculation_bucket_num=threshold_calculation_bucket_num,
424457
)
425458
else:
426459
eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens)
@@ -1768,6 +1801,7 @@ def __init__(
17681801
feature_table_map=self._feature_table_map,
17691802
ssd_cache_location=embedding_location,
17701803
pooling_mode=PoolingMode.NONE,
1804+
pg=pg,
17711805
**ssd_tbe_params,
17721806
).to(device)
17731807

@@ -2000,6 +2034,7 @@ def __init__(
20002034
ssd_cache_location=embedding_location,
20012035
pooling_mode=PoolingMode.NONE,
20022036
backend_type=backend_type,
2037+
pg=pg,
20032038
**ssd_tbe_params,
20042039
).to(device)
20052040

@@ -2680,6 +2715,7 @@ def __init__(
26802715
feature_table_map=self._feature_table_map,
26812716
ssd_cache_location=embedding_location,
26822717
pooling_mode=self._pooling,
2718+
pg=pg,
26832719
**ssd_tbe_params,
26842720
).to(device)
26852721

@@ -2900,6 +2936,7 @@ def __init__(
29002936
ssd_cache_location=embedding_location,
29012937
pooling_mode=self._pooling,
29022938
backend_type=backend_type,
2939+
pg=pg,
29032940
**ssd_tbe_params,
29042941
).to(device)
29052942

torchrec/distributed/types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3131
BoundsCheckMode,
3232
CacheAlgorithm,
33+
KVZCHEvictionTBEConfig,
3334
MultiPassPrefetchConfig,
3435
)
3536

@@ -662,7 +663,7 @@ class KeyValueParams:
662663
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
663664
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
664665
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
665-
kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH
666+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
666667
667668
# Parameter Server (PS) Attributes
668669
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -688,7 +689,7 @@ class KeyValueParams:
688689
None # enable raw embedding streaming for SSD TBE
689690
)
690691
res_store_shards: Optional[int] = None # shards to store the raw embeddings
691-
kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH
692+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
692693

693694
# Parameter Server (PS) Attributes
694695
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -717,7 +718,7 @@ def __hash__(self) -> int:
717718
self.lazy_bulk_init_enabled,
718719
self.enable_raw_embedding_streaming,
719720
self.res_store_shards,
720-
self.kvzch_eviction_trigger_mode,
721+
self.kvzch_eviction_tbe_config,
721722
)
722723
)
723724

0 commit comments

Comments
 (0)