|
34 | 34 | from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
35 | 35 | log_tracing_disabled_warning)
|
36 | 36 | from vllm.utils import random_uuid
|
37 | | - from llama_tools import preprocess_input, postprocess_output
| 37 | + from rubra_tools import preprocess_input, postprocess_output
38 | 38 |
|
39 | 39 | logger = init_logger(__name__)
|
40 | 40 |
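
Note on the import swap above: the two helpers are exercised later in this same file. The sketch below shows how they appear to be used, going only by the call sites in this diff; the example tool and the raw output string are illustrative assumptions, and the return shape of `postprocess_output` is inferred from how the serving code indexes it, not from rubra_tools documentation.

```python
# Hedged sketch of the rubra_tools call sites added in this PR.
# "get_weather" and the raw_output format are made-up examples.
from rubra_tools import preprocess_input, postprocess_output

msgs = [{"role": "user", "content": "What is the weather in Boston?"}]
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}]

# Before templating: fold the tool schemas into the message list.
conversation = preprocess_input(msgs=msgs, tools=tools)

# After generation: parse the raw model text back into tool-call dicts,
# read by the serving code as fc["function"]["name"] / ["arguments"].
raw_output = 'starttoolcall{"name": "get_weather", "arguments": {"city": "Boston"}}'  # assumed format
calls = postprocess_output(output_str=raw_output)
```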
|
@@ -211,21 +211,18 @@ async def create_chat_completion(
|
211 | 211 | try:
|
212 | 212 | conversation: List[ConversationMessage] = []
|
213 | 213 | image_futures: List[Awaitable[ImagePixelData]] = []
|
214 | | - print("==================create chat completion====================")
215 | 214 |
|
216 | 215 | for msg in request.messages:
|
217 | 216 | chat_parsed_result = self._parse_chat_message_content(msg)
|
218 | 217 |
|
219 | 218 | conversation.extend(chat_parsed_result.messages)
|
220 | 219 | image_futures.extend(chat_parsed_result.image_futures)
|
221 | | -
222 | | -
| 220 | +
223 | 221 | if request.tools:
|
224 | 222 | raw_msgs = request.messages
|
225 | 223 | tools = [t.model_dump() for t in request.tools]
|
226 | 224 | raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools)
|
227 | 225 | conversation = raw_msgs
|
228 | | -
229 | 226 | prompt = self.tokenizer.apply_chat_template(
|
230 | 227 | conversation=conversation,
|
231 | 228 | tokenize=False,
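
Side note on the `request.tools` handling in this hunk: the tool entries are Pydantic models, and `model_dump()` turns each one into a plain dict before it is handed to `preprocess_input`. A self-contained illustration with hypothetical stand-in models (not the actual vLLM protocol classes):

```python
# Illustration only: stand-in Pydantic models showing what
# "[t.model_dump() for t in request.tools]" produces.
from pydantic import BaseModel


class FunctionDef(BaseModel):  # simplified assumption, not vLLM's real type
    name: str
    parameters: dict = {}


class ToolDef(BaseModel):  # simplified assumption, not vLLM's real type
    type: str = "function"
    function: FunctionDef


request_tools = [
    ToolDef(function=FunctionDef(
        name="get_weather",
        parameters={"type": "object",
                    "properties": {"city": {"type": "string"}}}))
]

tools = [t.model_dump() for t in request_tools]
# tools[0] == {"type": "function",
#              "function": {"name": "get_weather",
#                           "parameters": {"type": "object",
#                                          "properties": {"city": {"type": "string"}}}}}
```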
|
@@ -385,82 +382,106 @@ async def chat_completion_stream_generator(
|
385 | 382 | yield f"data: {data}\n\n"
|
386 | 383 | first_iteration = False
|
387 | 384 |
|
| 385 | + is_function_call = False
| 386 | + checked_function_call = False
388 | 387 | for output in res.outputs:
|
389 | 388 | i = output.index
|
390 | 389 |
|
391 | 390 | if finish_reason_sent[i]:
|
392 | 391 | continue
|
393 | 392 |
|
394 | | - delta_token_ids = output.token_ids[previous_num_tokens[i]:]
395 | | - out_logprobs = output.logprobs[
396 | | - previous_num_tokens[i]:] if output.logprobs else None
397 | | -
398 | | - if request.logprobs and request.top_logprobs is not None:
399 | | - assert out_logprobs is not None, (
400 | | - "Did not output logprobs")
401 | | - logprobs = self._create_chat_logprobs(
402 | | - token_ids=delta_token_ids,
403 | | - top_logprobs=out_logprobs,
404 | | - num_output_top_logprobs=request.top_logprobs,
405 | | - )
406 | | - else:
407 | | - logprobs = None
408 | | -
409 | | - delta_text = output.text[len(previous_texts[i]):]
410 | | - previous_texts[i] = output.text
411 | | - previous_num_tokens[i] = len(output.token_ids)
412 | | -
413 | | - if request.tool_choice and type(
414 | | - request.tool_choice
415 | | - ) is ChatCompletionNamedToolChoiceParam:
416 | | - delta_message = DeltaMessage(tool_calls=[
417 | | - ToolCall(function=FunctionCall(
418 | | - name=request.tool_choice.function.name,
419 | | - arguments=delta_text))
420 | | - ])
421 | | - else:
422 | | - delta_message = DeltaMessage(content=delta_text)
423 | | -
424 | | - if output.finish_reason is None:
425 | | - # Send token-by-token response for each request.n
426 | | -
427 | | - choice_data = ChatCompletionResponseStreamChoice(
428 | | - index=i,
429 | | - delta=delta_message,
430 | | - logprobs=logprobs,
431 | | - finish_reason=None)
432 | | - chunk = ChatCompletionStreamResponse(
433 | | - id=request_id,
434 | | - object=chunk_object_type,
435 | | - created=created_time,
436 | | - choices=[choice_data],
437 | | - model=model_name)
438 | | - if (request.stream_options
439 | | - and request.stream_options.include_usage):
440 | | - chunk.usage = None
441 | | - data = chunk.model_dump_json(exclude_unset=True)
442 | | - yield f"data: {data}\n\n"
443 | | - else:
444 | | - # Send the finish response for each request.n only once
445 | | - prompt_tokens = len(res.prompt_token_ids)
446 | | - choice_data = ChatCompletionResponseStreamChoice(
447 | | - index=i,
448 | | - delta=delta_message,
449 | | - logprobs=logprobs,
450 | | - finish_reason=output.finish_reason,
451 | | - stop_reason=output.stop_reason)
452 | | - chunk = ChatCompletionStreamResponse(
453 | | - id=request_id,
454 | | - object=chunk_object_type,
455 | | - created=created_time,
456 | | - choices=[choice_data],
457 | | - model=model_name)
458 | | - if (request.stream_options
459 | | - and request.stream_options.include_usage):
460 | | - chunk.usage = None
461 | | - data = chunk.model_dump_json(exclude_unset=True)
462 | | - yield f"data: {data}\n\n"
463 | | - finish_reason_sent[i] = True
| 393 | + if (not checked_function_call and len(output.text) >= 15):
| 394 | + if "starttoolcall" in output.text:
| 395 | + is_function_call = True
| 396 | + checked_function_call = True
| 397 | +
| 398 | + if (checked_function_call and not is_function_call) or output.finish_reason is not None:
| 399 | +
| 400 | + delta_token_ids = output.token_ids[previous_num_tokens[i]:]
| 401 | + out_logprobs = output.logprobs[
| 402 | + previous_num_tokens[i]:] if output.logprobs else None
| 403 | +
| 404 | + if request.logprobs and request.top_logprobs is not None:
| 405 | + assert out_logprobs is not None, (
| 406 | + "Did not output logprobs")
| 407 | + logprobs = self._create_chat_logprobs(
| 408 | + token_ids=delta_token_ids,
| 409 | + top_logprobs=out_logprobs,
| 410 | + num_output_top_logprobs=request.top_logprobs,
| 411 | + )
| 412 | + else:
| 413 | + logprobs = None
| 414 | +
| 415 | + delta_text = output.text[len(previous_texts[i]):]
| 416 | + previous_texts[i] = output.text
| 417 | + previous_num_tokens[i] = len(output.token_ids)
| 418 | +
| 419 | + if request.tool_choice and type(
| 420 | + request.tool_choice
| 421 | + ) is ChatCompletionNamedToolChoiceParam:
| 422 | + delta_message = DeltaMessage(tool_calls=[
| 423 | + ToolCall(function=FunctionCall(
| 424 | + name=request.tool_choice.function.name,
| 425 | + arguments=delta_text))
| 426 | + ])
| 427 | + else:
| 428 | + content = delta_text
| 429 | + function_output = postprocess_output(output_str=content)
| 430 | + tool_calls = []
| 431 | + if function_output:
| 432 | + try:
| 433 | + for fc in function_output:
| 434 | + function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
| 435 | + call = ToolCall(function=function)
| 436 | + tool_calls.append(call)
| 437 | + content = ""
| 438 | + except Exception as e:
| 439 | + content = str(function_output)
| 440 | + print(f"Error extracting functions from output: {e}")
| 441 | + delta_message = DeltaMessage(content=content, tool_calls=tool_calls)
| 442 | +
| 443 | + if output.finish_reason is None:
| 444 | + # Send token-by-token response for each request.n
| 445 | +
| 446 | + choice_data = ChatCompletionResponseStreamChoice(
| 447 | + index=i,
| 448 | + delta=delta_message,
| 449 | + logprobs=logprobs,
| 450 | + finish_reason=None)
| 451 | +
| 452 | + chunk = ChatCompletionStreamResponse(
| 453 | + id=request_id,
| 454 | + object=chunk_object_type,
| 455 | + created=created_time,
| 456 | + choices=[choice_data],
| 457 | + model=model_name)
| 458 | +
| 459 | + if (request.stream_options
| 460 | + and request.stream_options.include_usage):
| 461 | + chunk.usage = None
| 462 | + data = chunk.model_dump_json(exclude_unset=True)
| 463 | + yield f"data: {data}\n\n"
| 464 | + else:
| 465 | + # Send the finish response for each request.n only once
| 466 | + prompt_tokens = len(res.prompt_token_ids)
| 467 | + choice_data = ChatCompletionResponseStreamChoice(
| 468 | + index=i,
| 469 | + delta=delta_message,
| 470 | + logprobs=logprobs,
| 471 | + finish_reason=output.finish_reason,
| 472 | + stop_reason=output.stop_reason)
| 473 | + chunk = ChatCompletionStreamResponse(
| 474 | + id=request_id,
| 475 | + object=chunk_object_type,
| 476 | + created=created_time,
| 477 | + choices=[choice_data],
| 478 | + model=model_name)
| 479 | + if (request.stream_options
| 480 | + and request.stream_options.include_usage):
| 481 | + chunk.usage = None
| 482 | + data = chunk.model_dump_json(exclude_unset=False)
| 483 | + yield f"data: {data}\n\n"
| 484 | + finish_reason_sent[i] = True
464 | 485 |
|
465 | 486 | if (request.stream_options
|
466 | 487 | and request.stream_options.include_usage):
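
The streaming logic added in this hunk boils down to a small state machine: buffer the first few characters, decide once whether the output is a tool call, and only then either stream deltas or hold everything back until the finish chunk. Below is a minimal sketch of that decision, under the assumption (taken from this diff) that Rubra-tuned models prefix tool calls with the literal marker `starttoolcall`; the function and state dict are illustrative names, not vLLM API.

```python
# Minimal sketch of the detection added in chat_completion_stream_generator.
# TOOL_CALL_MARKER and PROBE_LENGTH mirror the literals in the diff.
TOOL_CALL_MARKER = "starttoolcall"
PROBE_LENGTH = 15  # enough generated text to know whether the marker appears


def classify_stream(generated_text: str, state: dict) -> str:
    """Return 'undecided', 'plain', or 'tool_call' for the text seen so far."""
    if not state.get("checked") and len(generated_text) >= PROBE_LENGTH:
        state["is_tool_call"] = TOOL_CALL_MARKER in generated_text
        state["checked"] = True
    if not state.get("checked"):
        return "undecided"  # keep buffering, emit no delta yet
    return "tool_call" if state["is_tool_call"] else "plain"


state: dict = {}
assert classify_stream("Hello", state) == "undecided"
assert classify_stream("Hello, the weather", state) == "plain"
```

In the diff, the "plain" case keeps the original token-by-token streaming, while the "tool_call" case defers everything to the final chunk, where `postprocess_output` turns the buffered text into `tool_calls`.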
|
@@ -542,7 +563,7 @@ async def chat_completion_full_generator(
|
542 | 563 | function_output = postprocess_output(output_str=content)
|
543 | 564 | tool_calls = []
|
544 | 565 | if function_output:
|
545 | | - print(f"Parsed function output: {function_output}\n\n")
| 566 | +
546 | 567 | try:
|
547 | 568 | for fc in function_output:
|
548 | 569 | function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"])
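
For readers following the non-streaming path, the loop above maps each parsed entry onto the OpenAI-style response objects. Here is a hedged, standalone version of that mapping; the `[{"function": {"name": ..., "arguments": ...}}]` shape is inferred from this diff rather than from rubra_tools documentation, and the import path is assumed to be the vLLM OpenAI protocol module used elsewhere in this file.

```python
# Hedged sketch of the tool-call conversion done in chat_completion_full_generator.
# parsed_calls below is a made-up example of what postprocess_output appears to return.
from typing import List

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall  # assumed import path


def to_tool_calls(parsed_calls: list) -> List[ToolCall]:
    tool_calls: List[ToolCall] = []
    for fc in parsed_calls or []:
        tool_calls.append(
            ToolCall(function=FunctionCall(
                name=fc["function"]["name"],
                arguments=fc["function"]["arguments"])))
    return tool_calls


parsed_calls = [{"function": {"name": "get_weather",
                              "arguments": '{"city": "Boston"}'}}]
print(to_tool_calls(parsed_calls))
```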
|
|