Skip to content

Commit e193fa6

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Free mem trigger with all2all for sync trigger eviction
Summary: X-link: facebookresearch/FBGEMM#2067 X-link: meta-pytorch/torchrec#3442 Before KVZCH is using ID_COUNT and MEM_UTIL eviction trigger mode, both are very tricky and hard for model engineer to decide what num to use for the id count or mem util threshold. Besides that, the eviction start time is out of sync after some time in training, which can cause great qps drop during eviction. This diff is adding support for free memory trigger eviction. It will check how many free memory left every N batch in every rank and if free memory below the threshold, it will trigger eviction in all tbes of all ranks using all reduce. In this way, we can force the start time of eviction in all ranks. Reviewed By: emlin Differential Revision: D83896528
1 parent 70a3f3d commit e193fa6

File tree

4 files changed

+172
-26
lines changed

4 files changed

+172
-26
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,19 @@ class EvictionPolicy(NamedTuple):
8686
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
8787
)
8888
training_id_eviction_trigger_count: Optional[list[int]] = (
89-
None # training_id_eviction_trigger_count for each table
89+
None # Number of training IDs that, when exceeded, will trigger eviction for each table.
9090
)
9191
training_id_keep_count: Optional[list[int]] = (
92-
None # training_id_keep_count for each table
92+
None # Target number of training IDs to retain in each table after eviction.
9393
)
9494
l2_weight_thresholds: Optional[list[float]] = (
9595
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
9696
)
9797
threshold_calculation_bucket_stride: Optional[float] = (
98-
0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score
98+
0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction.
9999
)
100100
threshold_calculation_bucket_num: Optional[int] = (
101-
1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score
101+
1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction.
102102
)
103103
interval_for_insufficient_eviction_s: int = (
104104
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
@@ -114,10 +114,16 @@ class EvictionPolicy(NamedTuple):
114114
24 * 3600 # 1 day, interval for feature statistics decay
115115
)
116116
meta_header_lens: Optional[list[int]] = None # metaheader length for each table
117+
eviction_free_mem_threshold_gb: Optional[int] = (
118+
None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
119+
)
120+
eviction_free_mem_check_interval_batch: Optional[int] = (
121+
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
122+
)
117123

118124
def validate(self) -> None:
119-
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], (
120-
"eviction_trigger_mode must be 0, 1, 2, 3 or 4 "
125+
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
126+
"eviction_trigger_mode must be 0, 1, 2, 3, 4, 5"
121127
f"actual {self.eviction_trigger_mode}"
122128
)
123129
if self.eviction_trigger_mode == 0:
@@ -143,6 +149,13 @@ def validate(self) -> None:
143149
assert (
144150
self.training_id_eviction_trigger_count is not None
145151
), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
152+
elif self.eviction_trigger_mode == 5:
153+
assert (
154+
self.eviction_free_mem_threshold_gb is not None
155+
), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5"
156+
assert (
157+
self.eviction_free_mem_check_interval_batch is not None
158+
), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5"
146159

147160
if self.eviction_strategy == 0:
148161
assert self.ttls_in_mins is not None, (
@@ -240,6 +253,19 @@ def validate(self) -> None:
240253
), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"
241254

242255

256+
class KVZCHEvictionTBEConfig(NamedTuple):
257+
# Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem
258+
kvzch_eviction_trigger_mode: Optional[int] = None
259+
# Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode.
260+
eviction_free_mem_threshold_gb: Optional[int] = None
261+
# Number of batches between checks for free memory threshold when using free_mem trigger mode.
262+
eviction_free_mem_check_interval_batch: Optional[int] = None
263+
# The width of each feature score bucket used for threshold calculation in feature score-based eviction.
264+
threshold_calculation_bucket_stride: Optional[float] = None
265+
# Total number of feature score buckets used for threshold calculation in feature score-based eviction.
266+
threshold_calculation_bucket_num: Optional[int] = None
267+
268+
243269
class BackendType(enum.IntEnum):
244270
SSD = 0
245271
DRAM = 1

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 128 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
import time
1919
from functools import cached_property
2020
from math import floor, log2
21-
from typing import Any, Callable, Optional, Union
21+
from typing import Any, Callable, ClassVar, Optional, Union
2222
import torch # usort:skip
23+
import weakref
2324

2425
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
2526
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -34,6 +35,7 @@
3435
BoundsCheckMode,
3536
CacheAlgorithm,
3637
EmbeddingLocation,
38+
EvictionPolicy,
3739
get_bounds_check_version_for_platform,
3840
KVZCHParams,
3941
PoolingMode,
@@ -54,6 +56,8 @@
5456
from torch import distributed as dist, nn, Tensor # usort:skip
5557
from dataclasses import dataclass
5658

59+
import psutil
60+
5761
from torch.autograd.profiler import record_function
5862

5963
from ..cache import get_unique_indices_v2
@@ -100,6 +104,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
100104
_local_instance_index: int = -1
101105
res_params: RESParams
102106
table_names: list[str]
107+
_all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()
108+
_first_instance_ref: ClassVar[weakref.ref] = None
109+
_eviction_triggered: ClassVar[bool] = False
103110

104111
def __init__(
105112
self,
@@ -179,6 +186,7 @@ def __init__(
179186
table_names: Optional[list[str]] = None,
180187
use_rowwise_bias_correction: bool = False, # For Adam use
181188
optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
189+
pg: Optional[dist.ProcessGroup] = None,
182190
) -> None:
183191
super(SSDTableBatchedEmbeddingBags, self).__init__()
184192

@@ -567,6 +575,10 @@ def __init__(
567575
# loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend
568576
self.load_state_dict: bool = False
569577

578+
SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self)
579+
if SSDTableBatchedEmbeddingBags._first_instance_ref is None:
580+
SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self)
581+
570582
# create tbe unique id using rank index | local tbe idx
571583
if tbe_unique_id == -1:
572584
SSDTableBatchedEmbeddingBags._local_instance_index += 1
@@ -584,6 +596,7 @@ def __init__(
584596
self.tbe_unique_id = tbe_unique_id
585597
self.l2_cache_size = l2_cache_size
586598
logging.info(f"tbe_unique_id: {tbe_unique_id}")
599+
self.enable_free_mem_trigger_eviction: bool = False
587600
if self.backend_type == BackendType.SSD:
588601
logging.info(
589602
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
@@ -688,25 +701,31 @@ def __init__(
688701
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
689702
else self.l2_cache_size
690703
)
704+
kv_zch_params = self.kv_zch_params
705+
eviction_policy = self.kv_zch_params.eviction_policy
706+
if eviction_policy.eviction_trigger_mode == 5:
707+
# If trigger mode is free_mem(5), populate config
708+
self.set_free_mem_eviction_trigger_config(eviction_policy)
709+
691710
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
692711
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
693-
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
694-
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
695-
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
712+
eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
713+
eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
714+
eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
696715
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
697-
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
698-
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
699-
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
700-
self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
701-
self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
702-
self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table
703-
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
716+
eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
717+
eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
718+
eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
719+
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
720+
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
721+
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
722+
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
704723
table_dims.tolist() if table_dims is not None else None,
705-
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
706-
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
707-
self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
708-
self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
709-
self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
724+
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
725+
eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
726+
eviction_policy.interval_for_insufficient_eviction_s,
727+
eviction_policy.interval_for_sufficient_eviction_s,
728+
eviction_policy.interval_for_feature_statistics_decay_s,
710729
)
711730
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
712731
self.cache_row_dim,
@@ -1065,6 +1084,8 @@ def __init__(
10651084

10661085
self.bounds_check_version: int = get_bounds_check_version_for_platform()
10671086

1087+
self._pg = pg
1088+
10681089
@cached_property
10691090
def cache_row_dim(self) -> int:
10701091
"""
@@ -2042,6 +2063,9 @@ def _prefetch( # noqa C901
20422063
if dist.get_rank() == 0:
20432064
self._report_kv_backend_stats()
20442065

2066+
# May trigger eviction if free mem trigger mode enabled before get cuda
2067+
self.may_trigger_eviction()
2068+
20452069
# Fetch data from SSD
20462070
if linear_cache_indices.numel() > 0:
20472071
self.record_function_via_dummy_profile(
@@ -4650,3 +4674,91 @@ def direct_write_embedding(
46504674
)
46514675

46524676
# Return control to the main stream without waiting for the backend operation to complete
4677+
4678+
def get_free_cpu_memory_gb(self) -> float:
4679+
mem = psutil.virtual_memory()
4680+
return mem.available / (1024**3)
4681+
4682+
@classmethod
4683+
def trigger_evict_in_all_tbes(cls) -> None:
4684+
for tbe in cls._all_tbe_instances:
4685+
tbe.ssd_db.trigger_feature_evict()
4686+
4687+
@classmethod
4688+
def tbe_has_ongoing_eviction(cls) -> bool:
4689+
for tbe in cls._all_tbe_instances:
4690+
if tbe.ssd_db.is_evicting():
4691+
return True
4692+
return False
4693+
4694+
def set_free_mem_eviction_trigger_config(
4695+
self, eviction_policy: EvictionPolicy
4696+
) -> None:
4697+
self.enable_free_mem_trigger_eviction = True
4698+
self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode
4699+
assert (
4700+
eviction_policy.eviction_free_mem_check_interval_batch is not None
4701+
), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode"
4702+
self.eviction_free_mem_check_interval_batch: int = (
4703+
eviction_policy.eviction_free_mem_check_interval_batch
4704+
)
4705+
assert (
4706+
eviction_policy.eviction_free_mem_threshold_gb is not None
4707+
), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode"
4708+
self.eviction_free_mem_threshold_gb: int = (
4709+
eviction_policy.eviction_free_mem_threshold_gb
4710+
)
4711+
logging.info(
4712+
f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}"
4713+
)
4714+
4715+
def may_trigger_eviction(self) -> None:
4716+
def is_first_tbe() -> bool:
4717+
first = SSDTableBatchedEmbeddingBags._first_instance_ref
4718+
return first is not None and first() is self
4719+
4720+
# We assume that the eviction time is less than free mem check interval time
4721+
# So every time we reach this check, all evictions in all tbes should be finished.
4722+
# We only need to check the first tbe because all tbes share the same free mem,
4723+
# once the first tbe detect need to trigger eviction, it will call trigger func
4724+
# in all tbes from _all_tbe_instances
4725+
if (
4726+
self.enable_free_mem_trigger_eviction
4727+
and self.step % self.eviction_free_mem_check_interval_batch == 0
4728+
and self.training
4729+
and is_first_tbe()
4730+
):
4731+
if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction():
4732+
SSDTableBatchedEmbeddingBags._eviction_triggered = False
4733+
4734+
free_cpu_mem_gb = self.get_free_cpu_memory_gb()
4735+
local_evict_trigger = int(
4736+
free_cpu_mem_gb < self.eviction_free_mem_threshold_gb
4737+
)
4738+
tensor_flag = torch.tensor(
4739+
local_evict_trigger,
4740+
device=self.current_device,
4741+
dtype=torch.int,
4742+
)
4743+
world_size = dist.get_world_size(self._pg)
4744+
if world_size > 1:
4745+
dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg)
4746+
global_evict_trigger = tensor_flag.item()
4747+
else:
4748+
global_evict_trigger = local_evict_trigger
4749+
if (
4750+
global_evict_trigger >= 1
4751+
and SSDTableBatchedEmbeddingBags._eviction_triggered
4752+
):
4753+
logging.info(
4754+
f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true"
4755+
)
4756+
if (
4757+
global_evict_trigger >= 1
4758+
and not SSDTableBatchedEmbeddingBags._eviction_triggered
4759+
):
4760+
SSDTableBatchedEmbeddingBags._eviction_triggered = True
4761+
SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes()
4762+
logging.info(
4763+
f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction"
4764+
)

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
12141214
}
12151215
break;
12161216
}
1217+
case EvictTriggerMode::FREE_MEM: {
1218+
// For free mem eviction, all conditions checked in frontend, no check
1219+
// option in backend
1220+
return;
1221+
}
12171222
default:
12181223
break;
12191224
}

fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ enum class EvictTriggerMode {
3434
ITERATION, // Trigger based on iteration steps
3535
MEM_UTIL, // Trigger based on memory usage
3636
MANUAL, // Manually triggered by upstream
37-
ID_COUNT // Trigger based on id count
37+
ID_COUNT, // Trigger based on id count
38+
FREE_MEM, // Trigger based on free memory
3839
};
3940
inline std::string to_string(EvictTriggerMode mode) {
4041
switch (mode) {
@@ -48,6 +49,8 @@ inline std::string to_string(EvictTriggerMode mode) {
4849
return "MANUAL";
4950
case EvictTriggerMode::ID_COUNT:
5051
return "ID_COUNT";
52+
case EvictTriggerMode::FREE_MEM:
53+
return "FREE_MEM";
5154
}
5255
}
5356

@@ -184,6 +187,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
184187
eviction_trigger_stats_log += "]";
185188
break;
186189
}
190+
case EvictTriggerMode::FREE_MEM: {
191+
break;
192+
}
187193
default:
188194
throw std::runtime_error("Unknown evict trigger mode");
189195
}
@@ -202,16 +208,13 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
202208

203209
case EvictTriggerStrategy::BY_FEATURE_SCORE: {
204210
CHECK(feature_score_counter_decay_rates_.has_value());
205-
CHECK(training_id_eviction_trigger_count_.has_value());
206211
CHECK(training_id_keep_count_.has_value());
207212
CHECK(threshold_calculation_bucket_stride_.has_value());
208213
CHECK(threshold_calculation_bucket_num_.has_value());
209214
CHECK(ttls_in_mins_.has_value());
210215
LOG(INFO) << "eviction config, trigger mode:"
211216
<< to_string(trigger_mode_) << eviction_trigger_stats_log
212217
<< ", strategy: " << to_string(trigger_strategy_)
213-
<< ", training_id_eviction_trigger_count: "
214-
<< training_id_eviction_trigger_count_.value()
215218
<< ", training_id_keep_count:"
216219
<< training_id_keep_count_.value()
217220
<< ", ttls_in_mins: " << ttls_in_mins_.value()

0 commit comments

Comments
 (0)