Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate tool calls #213

Merged
merged 5 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tools

This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.
87 changes: 87 additions & 0 deletions examples/tools/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
import ollama
import asyncio


# Simulates an API call to get flight times
# In a real application, this would fetch data from a live database or API
def get_flight_times(departure: str, arrival: str) -> str:
flights = {
'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
}

key = f'{departure}-{arrival}'.upper()
return json.dumps(flights.get(key, {'error': 'Flight not found'}))


async def run(model: str):
client = ollama.AsyncClient()
# Initialize conversation with a user query
messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]

# First API call: Send the query and function description to the model
response = await client.chat(
model=model,
messages=messages,
tools=[
{
'type': 'function',
'function': {
'name': 'get_flight_times',
'description': 'Get the flight times between two cities',
'parameters': {
'type': 'object',
'properties': {
'departure': {
'type': 'string',
'description': 'The departure city (airport code)',
},
'arrival': {
'type': 'string',
'description': 'The arrival city (airport code)',
},
},
'required': ['departure', 'arrival'],
},
},
},
],
)

# Add the model's response to the conversation history
messages.append(response['message'])

# Check if the model decided to use the provided function
if not response['message'].get('tool_calls'):
print("The model didn't use the function. Its response was:")
print(response['message']['content'])
return

# Process function calls made by the model
if response['message'].get('tool_calls'):
available_functions = {
'get_flight_times': get_flight_times,
}
for tool in response['message']['tool_calls']:
function_to_call = available_functions[tool['function']['name']]
function_response = function_to_call(tool['function']['arguments']['departure'], tool['function']['arguments']['arrival'])
# Add function response to the conversation
messages.append(
{
'role': 'tool',
'content': function_response,
}
)

# Second API call: Get final response from the model
final_response = await client.chat(model=model, messages=messages)
print(final_response['message']['content'])


# Run the async function
asyncio.run(run('mistral'))
22 changes: 9 additions & 13 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
except metadata.PackageNotFoundError:
__version__ = '0.0.0'

from ollama._types import Message, Options, RequestError, ResponseError
from ollama._types import Message, Options, RequestError, ResponseError, Tool


class BaseClient:
Expand Down Expand Up @@ -180,6 +180,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -191,6 +192,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -201,6 +203,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -222,12 +225,6 @@ def chat(
messages = deepcopy(messages)

for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

Expand All @@ -237,6 +234,7 @@ def chat(
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
Expand Down Expand Up @@ -574,6 +572,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -585,6 +584,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -595,6 +595,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -615,12 +616,6 @@ async def chat(
messages = deepcopy(messages)

for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

Expand All @@ -630,6 +625,7 @@ async def chat(
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
Expand Down
51 changes: 50 additions & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict, Sequence, Literal
from typing import Any, TypedDict, Sequence, Literal, Mapping

import sys

Expand Down Expand Up @@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
'Tokenized history up to the point of the response.'


class ToolCallFunction(TypedDict):
"""
Tool call function.
"""

name: str
'Name of the function.'

args: NotRequired[Mapping[str, Any]]
'Arguments of the function.'


class ToolCall(TypedDict):
"""
Model tool calls.
"""

function: ToolCallFunction
'Function to be called.'


class Message(TypedDict):
"""
Chat message.
Expand All @@ -76,6 +97,34 @@ class Message(TypedDict):
Valid image formats depend on the model. See the model card for more information.
"""

tool_calls: NotRequired[Sequence[ToolCall]]
"""
Tools calls to be made by the model.
"""


class Property(TypedDict):
type: str
description: str
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings


class Parameters(TypedDict):
type: str
required: Sequence[str]
properties: Mapping[str, Property]


class ToolFunction(TypedDict):
name: str
description: str
parameters: Parameters


class Tool(TypedDict):
type: str
function: ToolFunction


class ChatResponse(BaseGenerateResponse):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -73,6 +74,7 @@ def generate():
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
Expand Down Expand Up @@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -522,6 +525,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -560,6 +564,7 @@ def generate():
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
Expand Down Expand Up @@ -590,6 +595,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down