
Commit c087238

Kfir Wolfson committed
add threshold to request logger and fix some calls to encode
Signed-off-by: Kfir Wolfson <kfirw@pliops.com>
1 parent 521b173 commit c087238

File tree

vllm/entrypoints/logger.py
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_engine.py

5 files changed: +38 -22 lines changed


vllm/entrypoints/logger.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@ def log_inputs(
         prompt_embeds: Optional[torch.Tensor],
         params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
         lora_request: Optional[LoRARequest],
+        cache_hit_threshold: Optional[float],
     ) -> None:
         max_log_len = self.max_log_len
         if max_log_len is not None:
@@ -46,6 +47,7 @@ def log_inputs(
             prompt_token_ids,
             prompt_embeds.shape if prompt_embeds is not None else None,
             lora_request,
+            cache_hit_threshold,
         )

     def log_outputs(
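
For context, after this change a caller passes the threshold straight through to the logger. Below is a hedged usage sketch, not part of the commit: the argument order follows the signature shown in the diff, while the `RequestLogger` constructor argument and all sample values are assumptions for illustration.

from vllm import SamplingParams
from vllm.entrypoints.logger import RequestLogger

# Assumed construction; max_log_len only truncates what gets logged.
request_logger = RequestLogger(max_log_len=2048)

# Log a request together with its optional cache-hit threshold
# (the new argument added by this commit).
request_logger.log_inputs(
    "cmpl-123",                            # request_id (example value)
    "Hello, world",                        # prompt
    [9906, 11, 1917],                      # prompt_token_ids (example ids)
    None,                                  # prompt_embeds
    params=SamplingParams(max_tokens=16),
    lora_request=None,
    cache_hit_threshold=0.8,
)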

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 0 deletions
@@ -324,11 +324,14 @@ async def create_chat_completion(
                     self.default_sampling_params,
                 )

+                cache_hit_threshold = request.cache_hit_threshold
+
                 self._log_inputs(
                     request_id,
                     request_prompts[i],
                     params=sampling_params,
                     lora_request=lora_request,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 trace_headers = (
@@ -352,6 +355,7 @@ async def create_chat_completion(
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     priority=request.priority,
+                    cache_hit_threshold=request.cache_hit_threshold,
                 )

                 generator = self.engine_client.generate(
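
On the client side, the handler above reads `request.cache_hit_threshold`, so the value would normally arrive as an extra field on the chat completion request body. A hedged sketch using the OpenAI Python client; the server URL, model name, and the assumption that the field is accepted by the request schema are illustrative and not confirmed by this diff.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    # Non-standard fields are sent via extra_body and, assuming the request
    # schema defines it, surface server-side as request.cache_hit_threshold.
    extra_body={"cache_hit_threshold": 0.8},
)
print(response.choices[0].message.content)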

vllm/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 4 deletions
@@ -182,12 +182,14 @@ async def create_completion(
                 )

                 request_id_item = f"{request_id}-{i}"
+                cache_hit_threshold = request.cache_hit_threshold

                 self._log_inputs(
                     request_id_item,
                     engine_prompt,
                     params=sampling_params,
                     lora_request=lora_request,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 trace_headers = (
@@ -215,6 +217,7 @@ async def create_completion(
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     priority=request.priority,
+                    cache_hit_threshold=cache_hit_threshold,
                 )

                 generator = self.engine_client.generate(
@@ -224,10 +227,7 @@ async def create_completion(
                     lora_request=lora_request,
                     trace_headers=trace_headers,
                     priority=request.priority,
-                    prompt_text=prompt_text,
-                    tokenization_kwargs=tokenization_kwargs,
-                    cache_hit_threshold=request.cache_hit_threshold
-                )
+                    cache_hit_threshold=request.cache_hit_threshold)

                 generators.append(generator)
         except ValueError as e:

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 7 additions & 1 deletion
@@ -205,12 +205,15 @@ async def _process_chunked_request(
                 prompt=chunk_text, prompt_token_ids=chunk_tokens
             )

+            cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold",
+                                          None)
             # Log the chunk
             self._log_inputs(
                 chunk_request_id,
                 chunk_request_prompt,
                 params=pooling_params,
                 lora_request=ctx.lora_request,
+                cache_hit_threshold=cache_hit_threshold,
             )

             # Create generator for this chunk and wrap it to return indices
@@ -221,6 +224,7 @@ async def _process_chunked_request(
                 lora_request=ctx.lora_request,
                 trace_headers=trace_headers,
                 priority=getattr(ctx.request, "priority", 0),
+                cache_hit_threshold=cache_hit_threshold,
             )

             generators.append(original_generator)
@@ -320,12 +324,14 @@ async def _create_single_prompt_generator(
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         """Create a generator for a single prompt using standard processing."""
         request_id_item = f"{ctx.request_id}-{prompt_index}"
+        cache_hit_threshold = getattr(ctx.request, "cache_hit_threshold", None)

         self._log_inputs(
             request_id_item,
             engine_prompt,
             params=pooling_params,
             lora_request=ctx.lora_request,
+            cache_hit_threshold=cache_hit_threshold,
         )

         # Return the original generator without wrapping
@@ -336,7 +342,7 @@ async def _create_single_prompt_generator(
             lora_request=ctx.lora_request,
             trace_headers=trace_headers,
             priority=getattr(ctx.request, "priority", 0),
-        )
+            cache_hit_threshold=cache_hit_threshold)

     @override
     async def _prepare_generators(
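
The embedding path fetches the threshold with `getattr(..., None)` rather than plain attribute access, presumably because not every request type routed through here declares the field (the same pattern is already used for `priority` above). A tiny standalone illustration of that defaulting behaviour; the classes below are made up for the example and are not part of the codebase.

class RequestWithThreshold:
    cache_hit_threshold = 0.8

class RequestWithoutThreshold:
    pass  # no cache_hit_threshold attribute

for req in (RequestWithThreshold(), RequestWithoutThreshold()):
    # getattr falls back to None instead of raising AttributeError
    # when the request model does not define the field.
    threshold = getattr(req, "cache_hit_threshold", None)
    print(type(req).__name__, threshold)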

vllm/entrypoints/openai/serving_engine.py

Lines changed: 21 additions & 17 deletions
@@ -420,13 +420,13 @@ async def _prepare_generators(

         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
-
-            self._log_inputs(
-                request_id_item,
-                engine_prompt,
-                params=pooling_params,
-                lora_request=ctx.lora_request,
-            )
+            cache_hit_threshold = getattr(ctx.request,
+                                          "cache_hit_threshold", None)
+            self._log_inputs(request_id_item,
+                             engine_prompt,
+                             params=pooling_params,
+                             lora_request=ctx.lora_request,
+                             cache_hit_threshold=cache_hit_threshold)

             generator = self.engine_client.encode(
                 engine_prompt,
@@ -435,7 +435,7 @@ async def _prepare_generators(
                 lora_request=ctx.lora_request,
                 trace_headers=trace_headers,
                 priority=getattr(ctx.request, "priority", 0),
-            )
+                cache_hit_threshold=cache_hit_threshold)

             generators.append(generator)

@@ -935,6 +935,7 @@ async def _process_inputs(
         lora_request: Optional[LoRARequest],
         trace_headers: Optional[Mapping[str, str]],
         priority: int,
+        cache_hit_threshold: Optional[float] = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
@@ -951,7 +952,7 @@ async def _process_inputs(
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
-        )
+            cache_hit_threshold=cache_hit_threshold)
         return engine_request, tokenization_kwargs

     async def _generate_with_builtin_tools(
@@ -968,11 +969,13 @@ async def _generate_with_builtin_tools(
         prompt_text, _, _ = self._get_prompt_components(request_prompt)
         orig_priority = priority
         while True:
+            cache_hit_threshold = kwargs.get("cache_hit_threshold")
             self._log_inputs(
                 request_id,
                 request_prompt,
                 params=sampling_params,
                 lora_request=lora_request,
+                cache_hit_threshold=cache_hit_threshold,
             )
             trace_headers = kwargs.get("trace_headers")
             engine_request, tokenization_kwargs = await self._process_inputs(
@@ -982,6 +985,7 @@ async def _generate_with_builtin_tools(
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 priority=priority,
+                cache_hit_threshold=cache_hit_threshold,
             )

             generator = self.engine_client.generate(
@@ -1036,20 +1040,20 @@ def _log_inputs(
         inputs: Union[RequestPrompt, PromptType],
         params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]],
         lora_request: Optional[LoRARequest],
+        cache_hit_threshold: Optional[float] = None,
     ) -> None:
         if self.request_logger is None:
             return

         prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)

-        self.request_logger.log_inputs(
-            request_id,
-            prompt,
-            prompt_token_ids,
-            prompt_embeds,
-            params=params,
-            lora_request=lora_request,
-        )
+        self.request_logger.log_inputs(request_id,
+                                       prompt,
+                                       prompt_token_ids,
+                                       prompt_embeds,
+                                       params=params,
+                                       lora_request=lora_request,
+                                       cache_hit_threshold=cache_hit_threshold)

     async def _get_trace_headers(
         self,
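
Note that on `_process_inputs` and `_log_inputs` the new parameter defaults to None, so call sites not updated in this commit keep working unchanged. A minimal standalone sketch of that pattern; the function below is illustrative only, not from the codebase.

from typing import Optional

def log_inputs_sketch(request_id: str,
                      cache_hit_threshold: Optional[float] = None) -> None:
    # Old-style callers can omit the argument; updated callers thread the
    # per-request threshold through explicitly.
    print(f"{request_id}: cache_hit_threshold={cache_hit_threshold}")

log_inputs_sketch("req-old")                           # req-old: cache_hit_threshold=None
log_inputs_sketch("req-new", cache_hit_threshold=0.5)  # req-new: cache_hit_threshold=0.5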
