|
34 | 34 | from vllm.tracing import (contains_trace_headers, extract_trace_headers, |
35 | 35 | log_tracing_disabled_warning) |
36 | 36 | from vllm.utils import random_uuid |
37 | | -from llama_tools import preprocess_input, postprocess_output |
| 37 | +from rubra_tools import preprocess_input, postprocess_output |
38 | 38 |
|
39 | 39 | logger = init_logger(__name__) |
40 | 40 |
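The import swap is the only dependency change: the handler now relies on `rubra_tools` for `preprocess_input` and `postprocess_output`, which are called later with the keyword signatures `preprocess_input(msgs=..., tools=...)` and `postprocess_output(output_str=...)`. The diff imports the package unconditionally; a minimal sketch of a guarded import, assuming `rubra_tools` were treated as an optional extra, could look like this:

```python
# Sketch only, not part of the diff: fall back gracefully when the optional
# rubra_tools dependency is missing, so the server still starts for models
# that never use tool calls.
try:
    from rubra_tools import preprocess_input, postprocess_output
except ImportError:
    preprocess_input = None    # tool preprocessing disabled
    postprocess_output = None  # tool-call parsing disabled
```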
|
@@ -211,21 +211,18 @@ async def create_chat_completion( |
211 | 211 | try: |
212 | 212 | conversation: List[ConversationMessage] = [] |
213 | 213 | image_futures: List[Awaitable[ImagePixelData]] = [] |
214 | | - print("==================create chat completion====================") |
215 | 214 |
|
216 | 215 | for msg in request.messages: |
217 | 216 | chat_parsed_result = self._parse_chat_message_content(msg) |
218 | 217 |
|
219 | 218 | conversation.extend(chat_parsed_result.messages) |
220 | 219 | image_futures.extend(chat_parsed_result.image_futures) |
221 | | - |
222 | | - |
| 220 | + |
223 | 221 | if request.tools: |
224 | 222 | raw_msgs = request.messages |
225 | 223 | tools = [t.model_dump() for t in request.tools] |
226 | 224 | raw_msgs = preprocess_input(msgs=raw_msgs, tools=tools) |
227 | 225 | conversation = raw_msgs |
228 | | - |
229 | 226 | prompt = self.tokenizer.apply_chat_template( |
230 | 227 | conversation=conversation, |
231 | 228 | tokenize=False, |
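In `create_chat_completion`, when the request carries `tools`, the parsed conversation is replaced by the raw request messages run through `rubra_tools.preprocess_input`, which is expected to fold the tool schemas into the message list before the chat template renders the prompt. A simplified sketch of that path, where `tokenizer` and `request` stand in for the server's attributes and `add_generation_prompt=True` is an assumption rather than something shown in the diff:

```python
# Illustrative only: the tools branch in isolation. Only the
# preprocess_input(msgs=..., tools=...) call mirrors the diff; the rest is a
# stand-in for the surrounding serving code.
def build_prompt_with_tools(tokenizer, request):
    tools = [t.model_dump() for t in request.tools]  # pydantic tool schemas -> dicts
    msgs = preprocess_input(msgs=request.messages, tools=tools)
    return tokenizer.apply_chat_template(conversation=msgs,
                                         tokenize=False,
                                         add_generation_prompt=True)
```
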
@@ -385,82 +382,106 @@ async def chat_completion_stream_generator( |
385 | 382 | yield f"data: {data}\n\n" |
386 | 383 | first_iteration = False |
387 | 384 |
|
| 385 | + is_function_call = False |
| 386 | + checked_function_call = False |
388 | 387 | for output in res.outputs: |
389 | 388 | i = output.index |
390 | 389 |
|
391 | 390 | if finish_reason_sent[i]: |
392 | 391 | continue |
393 | 392 |
|
394 | | - delta_token_ids = output.token_ids[previous_num_tokens[i]:] |
395 | | - out_logprobs = output.logprobs[ |
396 | | - previous_num_tokens[i]:] if output.logprobs else None |
397 | | - |
398 | | - if request.logprobs and request.top_logprobs is not None: |
399 | | - assert out_logprobs is not None, ( |
400 | | - "Did not output logprobs") |
401 | | - logprobs = self._create_chat_logprobs( |
402 | | - token_ids=delta_token_ids, |
403 | | - top_logprobs=out_logprobs, |
404 | | - num_output_top_logprobs=request.top_logprobs, |
405 | | - ) |
406 | | - else: |
407 | | - logprobs = None |
408 | | - |
409 | | - delta_text = output.text[len(previous_texts[i]):] |
410 | | - previous_texts[i] = output.text |
411 | | - previous_num_tokens[i] = len(output.token_ids) |
412 | | - |
413 | | - if request.tool_choice and type( |
414 | | - request.tool_choice |
415 | | - ) is ChatCompletionNamedToolChoiceParam: |
416 | | - delta_message = DeltaMessage(tool_calls=[ |
417 | | - ToolCall(function=FunctionCall( |
418 | | - name=request.tool_choice.function.name, |
419 | | - arguments=delta_text)) |
420 | | - ]) |
421 | | - else: |
422 | | - delta_message = DeltaMessage(content=delta_text) |
423 | | - |
424 | | - if output.finish_reason is None: |
425 | | - # Send token-by-token response for each request.n |
426 | | - |
427 | | - choice_data = ChatCompletionResponseStreamChoice( |
428 | | - index=i, |
429 | | - delta=delta_message, |
430 | | - logprobs=logprobs, |
431 | | - finish_reason=None) |
432 | | - chunk = ChatCompletionStreamResponse( |
433 | | - id=request_id, |
434 | | - object=chunk_object_type, |
435 | | - created=created_time, |
436 | | - choices=[choice_data], |
437 | | - model=model_name) |
438 | | - if (request.stream_options |
439 | | - and request.stream_options.include_usage): |
440 | | - chunk.usage = None |
441 | | - data = chunk.model_dump_json(exclude_unset=True) |
442 | | - yield f"data: {data}\n\n" |
443 | | - else: |
444 | | - # Send the finish response for each request.n only once |
445 | | - prompt_tokens = len(res.prompt_token_ids) |
446 | | - choice_data = ChatCompletionResponseStreamChoice( |
447 | | - index=i, |
448 | | - delta=delta_message, |
449 | | - logprobs=logprobs, |
450 | | - finish_reason=output.finish_reason, |
451 | | - stop_reason=output.stop_reason) |
452 | | - chunk = ChatCompletionStreamResponse( |
453 | | - id=request_id, |
454 | | - object=chunk_object_type, |
455 | | - created=created_time, |
456 | | - choices=[choice_data], |
457 | | - model=model_name) |
458 | | - if (request.stream_options |
459 | | - and request.stream_options.include_usage): |
460 | | - chunk.usage = None |
461 | | - data = chunk.model_dump_json(exclude_unset=True) |
462 | | - yield f"data: {data}\n\n" |
463 | | - finish_reason_sent[i] = True |
|     | 393 | +                if not checked_function_call and len(output.text) >= 15:
| 394 | + if "starttoolcall" in output.text: |
| 395 | + is_function_call = True |
| 396 | + checked_function_call = True |
| 397 | + |
| 398 | + if (checked_function_call and not is_function_call) or output.finish_reason is not None: |
| 399 | + |
| 400 | + delta_token_ids = output.token_ids[previous_num_tokens[i]:] |
| 401 | + out_logprobs = output.logprobs[ |
| 402 | + previous_num_tokens[i]:] if output.logprobs else None |
| 403 | + |
| 404 | + if request.logprobs and request.top_logprobs is not None: |
| 405 | + assert out_logprobs is not None, ( |
| 406 | + "Did not output logprobs") |
| 407 | + logprobs = self._create_chat_logprobs( |
| 408 | + token_ids=delta_token_ids, |
| 409 | + top_logprobs=out_logprobs, |
| 410 | + num_output_top_logprobs=request.top_logprobs, |
| 411 | + ) |
| 412 | + else: |
| 413 | + logprobs = None |
| 414 | + |
| 415 | + delta_text = output.text[len(previous_texts[i]):] |
| 416 | + previous_texts[i] = output.text |
| 417 | + previous_num_tokens[i] = len(output.token_ids) |
| 418 | + |
| 419 | + if request.tool_choice and type( |
| 420 | + request.tool_choice |
| 421 | + ) is ChatCompletionNamedToolChoiceParam: |
| 422 | + delta_message = DeltaMessage(tool_calls=[ |
| 423 | + ToolCall(function=FunctionCall( |
| 424 | + name=request.tool_choice.function.name, |
| 425 | + arguments=delta_text)) |
| 426 | + ]) |
| 427 | + else: |
| 428 | + content = delta_text |
| 429 | + function_output = postprocess_output(output_str=content) |
| 430 | + tool_calls = [] |
| 431 | + if function_output: |
| 432 | + try: |
| 433 | + for fc in function_output: |
| 434 | + function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"]) |
| 435 | + call = ToolCall(function=function) |
| 436 | + tool_calls.append(call) |
| 437 | + content = "" |
| 438 | + except Exception as e: |
| 439 | + content = str(function_output) |
|     | 440 | +                                logger.error("Error extracting tool calls from output: %s", e)
| 441 | + delta_message = DeltaMessage(content=content, tool_calls=tool_calls) |
| 442 | + |
| 443 | + if output.finish_reason is None: |
| 444 | + # Send token-by-token response for each request.n |
| 445 | + |
| 446 | + choice_data = ChatCompletionResponseStreamChoice( |
| 447 | + index=i, |
| 448 | + delta=delta_message, |
| 449 | + logprobs=logprobs, |
| 450 | + finish_reason=None) |
| 451 | + |
| 452 | + chunk = ChatCompletionStreamResponse( |
| 453 | + id=request_id, |
| 454 | + object=chunk_object_type, |
| 455 | + created=created_time, |
| 456 | + choices=[choice_data], |
| 457 | + model=model_name) |
| 458 | + |
| 459 | + if (request.stream_options |
| 460 | + and request.stream_options.include_usage): |
| 461 | + chunk.usage = None |
| 462 | + data = chunk.model_dump_json(exclude_unset=True) |
| 463 | + yield f"data: {data}\n\n" |
| 464 | + else: |
| 465 | + # Send the finish response for each request.n only once |
| 466 | + prompt_tokens = len(res.prompt_token_ids) |
| 467 | + choice_data = ChatCompletionResponseStreamChoice( |
| 468 | + index=i, |
| 469 | + delta=delta_message, |
| 470 | + logprobs=logprobs, |
| 471 | + finish_reason=output.finish_reason, |
| 472 | + stop_reason=output.stop_reason) |
| 473 | + chunk = ChatCompletionStreamResponse( |
| 474 | + id=request_id, |
| 475 | + object=chunk_object_type, |
| 476 | + created=created_time, |
| 477 | + choices=[choice_data], |
| 478 | + model=model_name) |
| 479 | + if (request.stream_options |
| 480 | + and request.stream_options.include_usage): |
| 481 | + chunk.usage = None |
| 482 | + data = chunk.model_dump_json(exclude_unset=False) |
| 483 | + yield f"data: {data}\n\n" |
| 484 | + finish_reason_sent[i] = True |
464 | 485 |
|
465 | 486 | if (request.stream_options |
466 | 487 | and request.stream_options.include_usage): |
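The streaming change gates token deltas on whether the model is emitting a tool call: once at least 15 characters have been generated, the text is checked once for the `starttoolcall` marker. Plain responses keep streaming incrementally as before; tool-call output is held back until `finish_reason` arrives and is then parsed by `postprocess_output` into `ToolCall` deltas. A standalone sketch of that gating, where `chunks` is a stand-in for the growing `output.text` values rather than vLLM's real streaming API, and `postprocess_output` is the helper imported above:

```python
# Minimal sketch of the buffering logic added above; the marker string and
# 15-character threshold come from the diff, everything else is simplified.
def gate_stream(chunks, marker="starttoolcall", check_after=15):
    text = ""
    sent = 0                  # characters already flushed to the client
    is_function_call = False
    checked = False
    for piece in chunks:
        text += piece
        if not checked and len(text) >= check_after:
            is_function_call = marker in text
            checked = True
        if checked and not is_function_call:
            yield text[sent:]        # stream plain content (plus any buffered prefix)
            sent = len(text)
    if is_function_call:
        # parse the buffered tool-call markup once generation has finished
        yield postprocess_output(output_str=text)
    elif sent < len(text):
        yield text[sent:]            # short outputs that never hit the threshold
```
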
@@ -542,7 +563,7 @@ async def chat_completion_full_generator( |
542 | 563 | function_output = postprocess_output(output_str=content) |
543 | 564 | tool_calls = [] |
544 | 565 | if function_output: |
545 | | - print(f"Parsed function output: {function_output}\n\n") |
| 566 | + |
546 | 567 | try: |
547 | 568 | for fc in function_output: |
548 | 569 | function = FunctionCall(name=fc["function"]["name"], arguments=fc["function"]["arguments"]) |
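Both the streaming and non-streaming paths convert the parser result the same way: `postprocess_output` is expected to return a list of dicts shaped like OpenAI tool calls (`{"function": {"name": ..., "arguments": ...}}`), or something falsy when the output contains no call. A small illustration of that conversion with a made-up parse result; only the dict shape matches what the code above indexes into, and the import path assumes vLLM's OpenAI protocol module, which defines `FunctionCall` and `ToolCall`:

```python
# Hypothetical parse result; only its shape matches what postprocess_output
# is assumed to return.
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall

parsed = [{"function": {"name": "get_weather",
                        "arguments": '{"city": "Berlin"}'}}]

tool_calls = [
    ToolCall(function=FunctionCall(name=fc["function"]["name"],
                                   arguments=fc["function"]["arguments"]))
    for fc in parsed
]
```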
|