@@ -242,9 +242,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
242242 )
243243 ssd_tbe_params ["cache_sets" ] = int (max_cache_sets )
244244
245- if "kvzch_eviction_trigger_mode " in fused_params and config .is_using_virtual_table :
246- ssd_tbe_params ["kvzch_eviction_trigger_mode " ] = fused_params .get (
247- "kvzch_eviction_trigger_mode "
245+ if "kvzch_eviction_tbe_config " in fused_params and config .is_using_virtual_table :
246+ ssd_tbe_params ["kvzch_eviction_tbe_config " ] = fused_params .get (
247+ "kvzch_eviction_tbe_config "
248248 )
249249
250250 ssd_tbe_params ["table_names" ] = [table .name for table in config .embedding_tables ]
@@ -337,11 +337,40 @@ def _populate_zero_collision_tbe_params(
337337 eviction_strategy = - 1
338338 table_names = [table .name for table in config .embedding_tables ]
339339 l2_cache_size = tbe_params ["l2_cache_size" ]
340- if "kvzch_eviction_trigger_mode" in tbe_params :
341- eviction_trigger_mode = tbe_params ["kvzch_eviction_trigger_mode" ]
342- tbe_params .pop ("kvzch_eviction_trigger_mode" )
343- else :
344- eviction_trigger_mode = 2 # 2 means mem_util based eviction
340+
341+ # Eviction tbe config default values
342+ eviction_trigger_mode = 2 # 2 means mem_util based eviction
343+ eviction_free_mem_threshold_gb = (
344+ 200 # Eviction free memory trigger threshold in GB
345+ )
346+ eviction_free_mem_check_interval_batch = (
347+ 1000
348+ ) # number of batches between free-memory checks when the trigger mode is free_mem
349+ threshold_calculation_bucket_stride = 0.2
350+ threshold_calculation_bucket_num = 1000000 # 1M
351+ if "kvzch_eviction_tbe_config" in tbe_params :
352+ eviction_tbe_config = tbe_params ["kvzch_eviction_tbe_config" ]
353+ tbe_params .pop ("kvzch_eviction_tbe_config" )
354+
355+ if eviction_tbe_config .kvzch_eviction_trigger_mode is not None :
356+ eviction_trigger_mode = eviction_tbe_config .kvzch_eviction_trigger_mode
357+ if eviction_tbe_config .eviction_free_mem_threshold_gb is not None :
358+ eviction_free_mem_threshold_gb = (
359+ eviction_tbe_config .eviction_free_mem_threshold_gb
360+ )
361+ if eviction_tbe_config .eviction_free_mem_check_interval_batch is not None :
362+ eviction_free_mem_check_interval_batch = (
363+ eviction_tbe_config .eviction_free_mem_check_interval_batch
364+ )
365+ if eviction_tbe_config .threshold_calculation_bucket_stride is not None :
366+ threshold_calculation_bucket_stride = (
367+ eviction_tbe_config .threshold_calculation_bucket_stride
368+ )
369+ if eviction_tbe_config .threshold_calculation_bucket_num is not None :
370+ threshold_calculation_bucket_num = (
371+ eviction_tbe_config .threshold_calculation_bucket_num
372+ )
373+
345374 for i , table in enumerate (config .embedding_tables ):
346375 policy_t = table .virtual_table_eviction_policy
347376 if policy_t is not None :
@@ -421,6 +450,10 @@ def _populate_zero_collision_tbe_params(
421450 training_id_keep_count = training_id_keep_count ,
422451 l2_weight_thresholds = l2_weight_thresholds ,
423452 meta_header_lens = meta_header_lens ,
453+ eviction_free_mem_threshold_gb = eviction_free_mem_threshold_gb ,
454+ eviction_free_mem_check_interval_batch = eviction_free_mem_check_interval_batch ,
455+ threshold_calculation_bucket_stride = threshold_calculation_bucket_stride ,
456+ threshold_calculation_bucket_num = threshold_calculation_bucket_num ,
424457 )
425458 else :
426459 eviction_policy = EvictionPolicy (meta_header_lens = meta_header_lens )
@@ -1768,6 +1801,7 @@ def __init__(
17681801 feature_table_map = self ._feature_table_map ,
17691802 ssd_cache_location = embedding_location ,
17701803 pooling_mode = PoolingMode .NONE ,
1804+ pg = pg ,
17711805 ** ssd_tbe_params ,
17721806 ).to (device )
17731807
@@ -2000,6 +2034,7 @@ def __init__(
20002034 ssd_cache_location = embedding_location ,
20012035 pooling_mode = PoolingMode .NONE ,
20022036 backend_type = backend_type ,
2037+ pg = pg ,
20032038 ** ssd_tbe_params ,
20042039 ).to (device )
20052040
@@ -2680,6 +2715,7 @@ def __init__(
26802715 feature_table_map = self ._feature_table_map ,
26812716 ssd_cache_location = embedding_location ,
26822717 pooling_mode = self ._pooling ,
2718+ pg = pg ,
26832719 ** ssd_tbe_params ,
26842720 ).to (device )
26852721
@@ -2900,6 +2936,7 @@ def __init__(
29002936 ssd_cache_location = embedding_location ,
29012937 pooling_mode = self ._pooling ,
29022938 backend_type = backend_type ,
2939+ pg = pg ,
29032940 ** ssd_tbe_params ,
29042941 ).to (device )
29052942
0 commit comments