@@ -241,9 +241,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
241241 )
242242 ssd_tbe_params ["cache_sets" ] = int (max_cache_sets )
243243
244- if "kvzch_eviction_trigger_mode " in fused_params and config .is_using_virtual_table :
245- ssd_tbe_params ["kvzch_eviction_trigger_mode " ] = fused_params .get (
246- "kvzch_eviction_trigger_mode "
244+ if "kvzch_eviction_tbe_config " in fused_params and config .is_using_virtual_table :
245+ ssd_tbe_params ["kvzch_eviction_tbe_config " ] = fused_params .get (
246+ "kvzch_eviction_tbe_config "
247247 )
248248
249249 ssd_tbe_params ["table_names" ] = [table .name for table in config .embedding_tables ]
@@ -336,11 +336,40 @@ def _populate_zero_collision_tbe_params(
336336 eviction_strategy = - 1
337337 table_names = [table .name for table in config .embedding_tables ]
338338 l2_cache_size = tbe_params ["l2_cache_size" ]
339- if "kvzch_eviction_trigger_mode" in tbe_params :
340- eviction_trigger_mode = tbe_params ["kvzch_eviction_trigger_mode" ]
341- tbe_params .pop ("kvzch_eviction_trigger_mode" )
342- else :
343- eviction_trigger_mode = 2 # 2 means mem_util based eviction
339+
340+ # Eviction tbe config default values
341+ eviction_trigger_mode = 2 # 2 means mem_util based eviction
342+ eviction_free_mem_threshold_gb = (
343+ 10 # Eviction free memory trigger threshold in GB
344+ )
345+ eviction_free_mem_check_interval_batch = (
346+ 1000 ,
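347+ ) # interval, in batches, between free-memory checks when the trigger mode is free_mem. NOTE(review): the trailing comma after 1000 makes this default a 1-tuple (1000,), not an int — confirm and drop the comma if an int is intended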
348+ threshold_calculation_bucket_stride = 0.2
349+ threshold_calculation_bucket_num = 1000000 # 1M
350+ if "kvzch_eviction_tbe_config" in tbe_params :
351+ eviction_tbe_config = tbe_params ["kvzch_eviction_tbe_config" ]
352+ tbe_params .pop ("kvzch_eviction_tbe_config" )
353+
354+ if eviction_tbe_config .kvzch_eviction_trigger_mode is not None :
355+ eviction_trigger_mode = eviction_tbe_config .kvzch_eviction_trigger_mode
356+ if eviction_tbe_config .eviction_free_mem_threshold_gb is not None :
357+ eviction_free_mem_threshold_gb = (
358+ eviction_tbe_config .eviction_free_mem_threshold_gb
359+ )
360+ if eviction_tbe_config .eviction_free_mem_check_interval_batch is not None :
361+ eviction_free_mem_check_interval_batch = (
362+ eviction_tbe_config .eviction_free_mem_check_interval_batch
363+ )
364+ if eviction_tbe_config .threshold_calculation_bucket_stride is not None :
365+ threshold_calculation_bucket_stride = (
366+ eviction_tbe_config .threshold_calculation_bucket_stride
367+ )
368+ if eviction_tbe_config .threshold_calculation_bucket_num is not None :
369+ threshold_calculation_bucket_num = (
370+ eviction_tbe_config .threshold_calculation_bucket_num
371+ )
372+
344373 for i , table in enumerate (config .embedding_tables ):
345374 policy_t = table .virtual_table_eviction_policy
346375 if policy_t is not None :
@@ -420,6 +449,10 @@ def _populate_zero_collision_tbe_params(
420449 training_id_keep_count = training_id_keep_count ,
421450 l2_weight_thresholds = l2_weight_thresholds ,
422451 meta_header_lens = meta_header_lens ,
452+ eviction_free_mem_threshold_gb = eviction_free_mem_threshold_gb ,
453+ eviction_free_mem_check_interval_batch = eviction_free_mem_check_interval_batch ,
454+ threshold_calculation_bucket_stride = threshold_calculation_bucket_stride ,
455+ threshold_calculation_bucket_num = threshold_calculation_bucket_num ,
423456 )
424457 else :
425458 eviction_policy = EvictionPolicy (meta_header_lens = meta_header_lens )
@@ -1760,6 +1793,7 @@ def __init__(
17601793 feature_table_map = self ._feature_table_map ,
17611794 ssd_cache_location = embedding_location ,
17621795 pooling_mode = PoolingMode .NONE ,
1796+ pg = pg ,
17631797 ** ssd_tbe_params ,
17641798 ).to (device )
17651799
@@ -1992,6 +2026,7 @@ def __init__(
19922026 ssd_cache_location = embedding_location ,
19932027 pooling_mode = PoolingMode .NONE ,
19942028 backend_type = backend_type ,
2029+ pg = pg ,
19952030 ** ssd_tbe_params ,
19962031 ).to (device )
19972032
@@ -2672,6 +2707,7 @@ def __init__(
26722707 feature_table_map = self ._feature_table_map ,
26732708 ssd_cache_location = embedding_location ,
26742709 pooling_mode = self ._pooling ,
2710+ pg = pg ,
26752711 ** ssd_tbe_params ,
26762712 ).to (device )
26772713
@@ -2892,6 +2928,7 @@ def __init__(
28922928 ssd_cache_location = embedding_location ,
28932929 pooling_mode = self ._pooling ,
28942930 backend_type = backend_type ,
2931+ pg = pg ,
28952932 ** ssd_tbe_params ,
28962933 ).to (device )
28972934
0 commit comments