@@ -1143,6 +1143,176 @@ def test_kv_connector_handles_preemption():
11431143 assert scheduler .kv_cache_manager .block_pool .get_num_free_blocks () == NUM_BLOCKS - 1
11441144
11451145
def _iterate_until_done(scheduler: Scheduler):
    """Keep stepping the scheduler until its running queue drains."""
    step_output = scheduler.schedule()
    while scheduler.running:
        # Feed fake model output back so requests make forward progress.
        scheduler.update_from_output(step_output, make_output(scheduler))
        step_output = scheduler.schedule()
1154+
@pytest.mark.parametrize(
    "global_threshold,"
    "request_num_tokens,"
    "request_local_hit_blocks,"
    "request_external_hit_blocks,"
    "request_thresholds,"
    "request_expected_scheduled",
    [
        (
            0.0,
            [57, 34, 28],
            [1, 1, 0],
            [0, 1, 0],
            # expected hit ratio: [0.281, 0.941, 0.0]
            # calculated as (local + external) * BLOCK_SIZE / tokens
            [None, 0.4, 0.1],
            [True, True, False],
        ),
        (
            0.3,
            [157, 134, 128, 20, 150],
            [4, 1, 0, 0, 1],
            [2, 4, 0, 1, 0],
            # expected hit ratio: [0.611, 0.597, 0.0, 0.8, 0.106]
            [0.8, 0.4, 0.1, None, None],
            [False, True, False, True, False],
        ),
    ],
)
def test_cache_hit_threshold(
    # we validate global_threshold is used when request threshold is None
    global_threshold: float,
    # number of tokens in each request
    request_num_tokens: list[int],
    # number of blocks hit in local cache per request
    request_local_hit_blocks: list[int],
    # number of blocks hit in external cache per request
    request_external_hit_blocks: list[int],
    # optional cache_hit_threshold for each request
    request_thresholds: list[Optional[float]],
    # bool per request indicating if it is expected to be scheduled
    request_expected_scheduled: list[bool],
):
    """Requests whose (local + external) cache-hit ratio falls below their
    threshold must be rejected (finished) instead of scheduled.

    Each request's effective threshold is its own ``cache_hit_threshold``,
    falling back to ``global_threshold`` when it is None.
    """
    # NOTE: fixed spelling of "request_expected_scheduled" (was "scehduled")
    # in both the parametrize string and the parameter name.
    assert (
        len(request_num_tokens)
        == len(request_thresholds)
        == len(request_local_hit_blocks)
        == len(request_external_hit_blocks)
        == len(request_expected_scheduled)
    )

    scheduler = create_scheduler(
        enable_prefix_caching=True,
        global_cache_hit_threshold=global_threshold,
        use_kv_connector=True,
    )

    _insert_to_local_cache(request_local_hit_blocks, scheduler)
    _mock_external_cache_hit(request_external_hit_blocks, scheduler)

    requests, scheduler_output = _create_and_schedule_requests(
        request_num_tokens, request_thresholds, scheduler
    )

    # assert all requests expected to be scheduled are indeed scheduled
    assert [
        r.request_id
        for r, expected in zip(requests, request_expected_scheduled)
        if expected
    ] == [s.req_id for s in scheduler_output.scheduled_new_reqs]

    # assert other requests are "finished" due to cache threshold
    requests_expected_not_scheduled = [
        r for r, expected in zip(requests, request_expected_scheduled) if not expected
    ]
    assert all(
        r.status == RequestStatus.FINISHED_CACHE_HIT_BELOW_THRESHOLD
        for r in requests_expected_not_scheduled
    )

    _iterate_until_done(scheduler)
    assert_scheduler_empty(scheduler)
1238+
def _create_and_schedule_requests(
    request_num_tokens: list[int],
    request_thresholds: list[Optional[float]],
    scheduler: Scheduler,
):
    """Create one request per token count, enqueue them all, and run a
    single schedule/update step. Returns (requests, scheduler_output)."""
    new_requests = create_requests(
        num_requests=len(request_num_tokens),
        num_tokens=request_num_tokens,
        block_size=scheduler.cache_config.block_size,
        cache_hit_thresholds=request_thresholds,
    )

    for req in new_requests:
        scheduler.add_request(req)

    step_output = scheduler.schedule()
    scheduler.update_from_output(step_output, make_output(scheduler))
    return new_requests, step_output
1260+
def _mock_external_cache_hit(request_external_hit_blocks, scheduler: Scheduler):
    """Stub the connector so each successive lookup reports the given number
    of externally cached blocks (converted to tokens), never async."""
    block_size = scheduler.cache_config.block_size
    hit_mock = Mock(name="method")
    hit_mock.side_effect = [
        (num_blocks * block_size, False) for num_blocks in request_external_hit_blocks
    ]
    scheduler.connector.get_num_new_matched_tokens = hit_mock
1268+
def _insert_to_local_cache(request_local_hit_blocks, scheduler: Scheduler):
    """Schedule requests to fill in the local cache"""
    block_size = scheduler.cache_config.block_size
    total = len(request_local_hit_blocks)

    warm_indices = [
        idx for idx, blocks in enumerate(request_local_hit_blocks) if blocks > 0
    ]
    if not warm_indices:
        # nothing to do
        return

    # Mock no external Cache Hit for this cache-warmup phase
    scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
    scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)

    # set threshold to 0.0 to ensure all are scheduled
    all_zero: list[Optional[float]] = [0.0] * total

    # Only requests with local hits should run and populate the cache
    # We create all requests to make sure the correct tokens are cached
    # (since the tokens are generated according to request id)
    warmup_requests = create_requests(
        num_requests=total,
        num_tokens=[blocks * block_size for blocks in request_local_hit_blocks],
        block_size=block_size,
        cache_hit_thresholds=all_zero,
    )

    # Only schedule the request we want to run and populate the cache
    for idx in warm_indices:
        scheduler.add_request(warmup_requests[idx])

    first_step = scheduler.schedule()

    # verify all were indeed scheduled
    assert len(first_step.scheduled_new_reqs) == len(warm_indices)

    # iterate until all scheduled requests are done
    scheduler.update_from_output(first_step, make_output(scheduler))
    _iterate_until_done(scheduler)
    assert_scheduler_empty(scheduler)
1315+
11461316def make_output (scheduler : Scheduler ):
11471317 return ModelRunnerOutput (
11481318 req_ids = [req .request_id for req in scheduler .running ],
@@ -1215,13 +1385,7 @@ def test_memory_leak():
12151385 model_runner_output = make_output (scheduler )
12161386 scheduler .update_from_output (scheduler_output , model_runner_output )
12171387
1218- # Iterate until done.
1219- while True :
1220- scheduler_output = scheduler .schedule ()
1221- if len (scheduler .running ) == 0 :
1222- break
1223- model_runner_output = make_output (scheduler )
1224- scheduler .update_from_output (scheduler_output , model_runner_output )
1388+ _iterate_until_done (scheduler )
12251389
12261390 # Confirm no memory leak.
12271391 assert_scheduler_empty (scheduler )
0 commit comments