
Fix Ollama streamed tool calls. Set finish_reason to tool_calls for all tool_calls responses #3469

Merged

Conversation

@jackmpcollins (Contributor) commented May 6, 2024

Title

Fix Ollama streamed tool calls. Set finish_reason to tool_calls for all tool_calls responses

Relevant issues

Fixes #3333
Fixes #2209
Closes #2597

Type

🐛 Bug Fix

Changes

  • Parse tool calls for streamed ollama/ and ollama_chat/ responses
  • Set finish_reason to "tool_calls" for all ollama/ and ollama_chat/ responses
  • Use data instead of optional_params so that the JSON response format is correctly detected in ollama completion.

For the streamed tool calls, the entire response is joined before parsing, so only one chunk/delta is yielded. This makes it much easier to parse out the function name and arguments. In the future it would be nice to update this to stream the tool calls the same way OpenAI does.
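
From the caller's side, the practical effect is that a tool call can be detected the same way for streamed and non-streamed requests by checking finish_reason. The snippet below is only an illustrative consumer-side sketch, not part of this PR; the model name, prompt, and trimmed tool schema are placeholders.

import json

import litellm

messages = [{"role": "user", "content": "What's the weather like in Paris?"}]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Non-streamed: finish_reason is now "tool_calls" whenever a tool call is returned.
response = litellm.completion(model="ollama_chat/llama2", messages=messages, tools=tools)
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
    call = choice.message.tool_calls[0]
    print(call.function.name, json.loads(call.function.arguments))

# Streamed: the whole response is joined before parsing, so the tool call arrives
# in a single chunk/delta, also with finish_reason "tool_calls".
stream = litellm.completion(model="ollama_chat/llama2", messages=messages, tools=tools, stream=True)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        call = delta.tool_calls[0]
        print(call.function.name, json.loads(call.function.arguments))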

Testing

Adapted the test cases from #3333

import litellm

# Note: the top-level `await` calls below assume an async-capable REPL (e.g. IPython);
# in a plain script, wrap each call in an async function and run it with asyncio.run().

messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]


print("\n---\n")
response = litellm.completion(model="ollama/llama2", messages=messages, tools=tools)
print(response)

print("\n---\n")
response = litellm.completion(model="ollama_chat/llama2", messages=messages, tools=tools)
print(response)

print("\n---\n")
aresponse = await litellm.acompletion(model="ollama/llama2", messages=messages, tools=tools)
print(aresponse)

print("\n---\n")
aresponse = await litellm.acompletion(model="ollama_chat/llama2", messages=messages, tools=tools)
print(aresponse)


print("\n---\n")
response = litellm.completion(model="ollama/llama2", messages=messages, tools=tools, stream=True)
for chunk in response:
    print(chunk)

print("\n---\n")
response = litellm.completion(model="ollama_chat/llama2", messages=messages, tools=tools, stream=True)
for chunk in response:
    print(chunk)

print("\n---\n")
aresponse = await litellm.acompletion(model="ollama/llama2", messages=messages, tools=tools, stream=True)
async for chunk in aresponse:
    print(chunk)

print("\n---\n")
aresponse = await litellm.acompletion(model="ollama_chat/llama2", messages=messages, tools=tools, stream=True)
async for chunk in aresponse:
    print(chunk)

which produced the following output:


---

ModelResponse(id='chatcmpl-853e92fb-3da0-4852-a502-ce412381cd2a', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content=None, role='assistant', tool_calls=[ChatCompletionMessageToolCall(function=Function(arguments='{"location": "San Francisco, CA", "unit": "celsius"}', name='get_current_weather'), id='call_f65b5464-9101-4bcf-a115-8ad7a263f069', type='function')]))], created=1714961534, model='ollama/llama2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=150, completion_tokens=42, total_tokens=192))

---

ModelResponse(id='chatcmpl-f07ed597-34d3-43e8-962c-2bc90582ed4e', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content=None, role='assistant', tool_calls=[ChatCompletionMessageToolCall(function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), id='call_049f2949-c57e-44bf-a0bb-921fa652c56b', type='function')]))], created=1714961541, model='ollama/llama2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=314, completion_tokens=33, total_tokens=347))

---

ModelResponse(id='chatcmpl-bbcefe79-bb4e-4e1b-8037-b179dea6caf2', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content=None, role='assistant', tool_calls=[ChatCompletionMessageToolCall(function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), id='call_48665cb1-2a2d-46ec-a2d5-1e9061f4d5b2', type='function')]))], created=1714961547, model='ollama/llama2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=440, completion_tokens=33, total_tokens=473))

---

ModelResponse(id='chatcmpl-8b5ad4bd-c748-4b38-b3ea-993b6b43de84', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content=None, role='assistant', tool_calls=[ChatCompletionMessageToolCall(function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), id='call_b289a5ef-4fc7-4a21-b823-f84b8ce11f75', type='function')]))], created=1714961557, model='ollama_chat/llama2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=604, completion_tokens=24, total_tokens=628))

---

ModelResponse(id='chatcmpl-16f7e445-1172-4b6a-9216-cebe9c26264e', choices=[StreamingChoices(finish_reason='tool_calls', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=[ChatCompletionDeltaToolCall(id='call_a0c3083f-e257-4125-9fa2-73f8d09421c1', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function', index=0)]), logprobs=None)], created=1714961563, model='llama2', object='chat.completion.chunk', system_fingerprint=None)

---

ModelResponse(id='chatcmpl-497bb5e8-e950-4d7f-985e-ab25f98a4d43', choices=[StreamingChoices(finish_reason='tool_calls', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=[ChatCompletionDeltaToolCall(id='call_b65f37f4-7c9e-488f-a80e-04147e5e4d47', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function', index=0)]), logprobs=None)], created=1714961573, model='llama2', object='chat.completion.chunk', system_fingerprint=None)

---

ModelResponse(id='chatcmpl-8c36e05d-b5b1-4522-a0fd-8658ba20305a', choices=[StreamingChoices(finish_reason='tool_calls', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=[ChatCompletionDeltaToolCall(id='call_69d13fbf-60f7-469c-9015-af4e4aa6b1c5', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function', index=0)]), logprobs=None)], created=1714961586, model='llama2', object='chat.completion.chunk', system_fingerprint=None)

---

ModelResponse(id='chatcmpl-22648686-8618-463a-b1e5-ba23c34ea275', choices=[StreamingChoices(finish_reason='tool_calls', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=[ChatCompletionDeltaToolCall(id='call_85bf2a5c-5c18-4d24-98d5-89a82d122d0c', function=Function(arguments='{"location": "San Francisco, CA"}', name='get_current_weather'), type='function', index=0)]), logprobs=None)], created=1714961598, model='llama2', object='chat.completion.chunk', system_fingerprint=None)
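
For completeness, here is a rough sketch of what a caller might do with one of the non-streamed responses above: parse the arguments and dispatch to a local function. get_current_weather below is a stub invented for illustration, and response and messages refer to the objects from the testing script above.

import json

def get_current_weather(location, unit="celsius"):
    # Stub for illustration only; a real implementation would call a weather API.
    return json.dumps({"location": location, "temperature": "22", "unit": unit})

# `response` is one of the non-streamed ModelResponse objects shown above.
tool_call = response.choices[0].message.tool_calls[0]
if tool_call.function.name == "get_current_weather":
    kwargs = json.loads(tool_call.function.arguments)
    result = get_current_weather(**kwargs)
    # Append the result as an OpenAI-style "tool" message for a follow-up completion.
    messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result})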

Pre-Submission Checklist (optional but appreciated):

  • I have included relevant documentation updates (stored in /docs/my-website)

OS Tests (optional but appreciated):

  • Tested on Windows
  • Tested on MacOS
  • Tested on Linux


@jackmpcollins (Contributor, Author) commented:

8 tests added here for (ollama, ollama_chat) x (stream=False, stream=True) x (def, async def). All passing locally for me (macOS, with Ollama running in the background); I'm not sure how your automated testing will handle this.

Command to run just these tests:

poetry run pytest -vv \
    test_completion.py::test_completion_ollama_function_call \
    test_completion.py::test_completion_ollama_function_call_stream \
    test_completion.py::test_acompletion_ollama_function_call \
    test_completion.py::test_acompletion_ollama_function_call_stream

and output for me showing all 8 passing:

[Screenshot 2024-05-06 at 12:12:57 AM: pytest output with all 8 tests passing]

@krrishdholakia merged commit 5f119f2 into BerriAI:main May 6, 2024
2 checks passed
@jackmpcollins deleted the fix-ollama-streamed-tool-calls branch May 6, 2024 17:03