Ollama support #162

Merged · 26 commits · Dec 8, 2024
40 changes: 40 additions & 0 deletions docs/api/models/ollama.md
@@ -0,0 +1,40 @@
# `pydantic_ai.models.ollama`

## Setup

For details on how to set up authentication with this model, see [model configuration for Ollama](../../install.md#ollama).

## Example usage

With `ollama` installed, you can run the server with the model you want to use:

```bash title="terminal-run-ollama"
ollama run llama3.2
```
(this will pull the `llama3.2` model if you don't already have it downloaded)

Then run your code; here's a minimal example:

```py title="ollama_example.py"
from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
city: str
country: str


agent = Agent('ollama:llama3.2', result_type=CityLocation)

result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.cost())
#> Cost(request_tokens=56, response_tokens=8, total_tokens=64, details=None)
```

See [`OllamaModel`][pydantic_ai.models.ollama.OllamaModel] for more information.
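If your Ollama server runs somewhere other than the default `http://localhost:11434/v1/`, you can pass `base_url` explicitly. A minimal sketch, assuming a server reachable at the placeholder host `192.168.1.74`:

```py title="ollama_example_remote.py"
from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel

# the host below is a placeholder; use wherever your Ollama server actually runs
ollama_model = OllamaModel('llama3.2', base_url='http://192.168.1.74:11434/v1/')
agent = Agent(ollama_model)

result = agent.run_sync('What is the capital of France?')
print(result.data)
```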

::: pydantic_ai.models.ollama
6 changes: 6 additions & 0 deletions docs/install.md
@@ -323,3 +323,9 @@ model = GroqModel('llama-3.1-70b-versatile', api_key='your-api-key')
agent = Agent(model)
...
```

### Ollama

To use [Ollama](https://ollama.com/), you must first download the Ollama client and then pull the model you want to use.

You must also ensure the Ollama server is running when making requests to it. For more information, see the [Ollama documentation](https://github.com/ollama/ollama/tree/main/docs).
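For example, a typical setup for the `llama3.2` model used in these docs (assuming the standard Ollama CLI):

```bash title="terminal-setup-ollama"
ollama pull llama3.2  # download the model
ollama serve          # start the server, if it isn't already running
```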
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -38,6 +38,7 @@ nav:
- api/exceptions.md
- api/models/base.md
- api/models/openai.md
- api/models/ollama.md
- api/models/gemini.md
- api/models/vertexai.md
- api/models/groq.md
23 changes: 22 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -49,6 +49,23 @@
'gemini-1.5-pro',
'vertexai:gemini-1.5-flash',
'vertexai:gemini-1.5-pro',
'ollama:codellama',
'ollama:gemma',
'ollama:gemma2',
'ollama:llama3',
'ollama:llama3.1',
'ollama:llama3.2',
'ollama:llama3.2-vision',
'ollama:llama3.3',
'ollama:mistral',
'ollama:mistral-nemo',
'ollama:mixtral',
'ollama:phi3',
'ollama:qwq',
'ollama:qwen',
'ollama:qwen2',
'ollama:qwen2.5',
'ollama:starcoder2',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -239,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
elif model.startswith('openai:'):
from .openai import OpenAIModel

return OpenAIModel(model[7:]) # pyright: ignore[reportArgumentType]
return OpenAIModel(model[7:])
elif model.startswith('gemini'):
from .gemini import GeminiModel

@@ -253,6 +270,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .vertexai import VertexAIModel

return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
elif model.startswith('ollama:'):
from .ollama import OllamaModel

return OllamaModel(model[7:])
else:
raise UserError(f'Unknown model: {model}')

118 changes: 118 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/ollama.py
@@ -0,0 +1,118 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Literal, Union

from httpx import AsyncClient as AsyncHTTPClient

from ..tools import ToolDefinition
from . import (
AgentModel,
Model,
cached_async_http_client,
)

try:
from openai import AsyncOpenAI
except ImportError as e:
raise ImportError(
        'Please install `openai` to use the Ollama model, '
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
) from e


from .openai import OpenAIModel

CommonOllamaModelNames = Literal[
'codellama',
'gemma',
'gemma2',
'llama3',
'llama3.1',
'llama3.2',
'llama3.2-vision',
'llama3.3',
'mistral',
'mistral-nemo',
'mixtral',
'phi3',
'qwq',
'qwen',
'qwen2',
'qwen2.5',
'starcoder2',
]
"""This contains just the most common ollama models.

For a full list see [ollama.com/library](https://ollama.com/library).
"""
OllamaModelName = Union[CommonOllamaModelNames, str]
"""Possible ollama models.

Since Ollama supports hundreds of models, we explicitly list the most models but
allow any name in the type hints.
"""


@dataclass(init=False)
class OllamaModel(Model):
"""A model that implements Ollama using the OpenAI API.

Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.

Apart from `__init__`, all methods are private or match those of the base class.
"""

model_name: OllamaModelName
openai_model: OpenAIModel

def __init__(
self,
model_name: OllamaModelName,
*,
base_url: str | None = 'http://localhost:11434/v1/',
openai_client: AsyncOpenAI | None = None,
http_client: AsyncHTTPClient | None = None,
):
"""Initialize an Ollama model.

        Ollama has built-in compatibility with the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
        [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.

        Args:
            model_name: The name of the Ollama model to use; the list of available models is
                [here](https://ollama.com/library). You must first download the model
                (`ollama pull <MODEL-NAME>`) in order to use it.
            base_url: The base url for the Ollama requests. The default value is the Ollama default.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if openai_client is not None:
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
        else:
            # an API key is not required for Ollama, but the OpenAI client requires some value
            oai_client = AsyncOpenAI(
                base_url=base_url, api_key='ollama', http_client=http_client or cached_async_http_client()
            )
            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)

async def agent_model(
self,
*,
function_tools: list[ToolDefinition],
allow_text_result: bool,
result_tools: list[ToolDefinition],
) -> AgentModel:
return await self.openai_model.agent_model(
function_tools=function_tools,
allow_text_result=allow_text_result,
result_tools=result_tools,
)

def name(self) -> str:
return f'ollama:{self.model_name}'
16 changes: 11 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -4,7 +4,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, overload
from typing import Literal, Union, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never
@@ -43,6 +43,12 @@
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
) from e

OpenAIModelName = Union[ChatModel, str]
"""Possible OpenAI model names.

Using this broader type for the model name, instead of the `ChatModel` literal alone,
allows this model class to be reused more easily by other model types (e.g. Ollama).
"""


@dataclass(init=False)
class OpenAIModel(Model):
@@ -53,12 +59,12 @@ class OpenAIModel(Model):
Apart from `__init__`, all methods are private or match those of the base class.
"""

model_name: ChatModel
model_name: OpenAIModelName
client: AsyncOpenAI = field(repr=False)

def __init__(
self,
model_name: ChatModel,
model_name: OpenAIModelName,
*,
api_key: str | None = None,
openai_client: AsyncOpenAI | None = None,
@@ -77,7 +83,7 @@ def __init__(
client to use, if provided, `api_key` and `http_client` must be `None`.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
"""
self.model_name: ChatModel = model_name
self.model_name: OpenAIModelName = model_name
if openai_client is not None:
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -125,7 +131,7 @@ class OpenAIAgentModel(AgentModel):
"""Implementation of `AgentModel` for OpenAI models."""

client: AsyncOpenAI
model_name: ChatModel
model_name: OpenAIModelName
allow_text_result: bool
tools: list[chat.ChatCompletionToolParam]

61 changes: 61 additions & 0 deletions tests/models/test_ollama.py
@@ -0,0 +1,61 @@
from __future__ import annotations as _annotations

from datetime import datetime, timezone

import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent
from pydantic_ai.messages import (
ModelTextResponse,
UserPrompt,
)
from pydantic_ai.result import Cost

from ..conftest import IsNow, try_import

with try_import() as imports_successful:
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from pydantic_ai.models.ollama import OllamaModel

from .test_openai import MockOpenAI, completion_message

pytestmark = [
pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
pytest.mark.anyio,
]


def test_init():
m = OllamaModel('llama3.2', base_url='foobar/')
assert m.openai_model.client.api_key == 'ollama'
assert m.openai_model.client.base_url == 'foobar/'
assert m.name() == 'ollama:llama3.2'


async def test_request_simple_success(allow_model_requests: None):
c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
mock_client = MockOpenAI.create_mock(c)
m = OllamaModel('llama3.2', openai_client=mock_client, base_url=None)
agent = Agent(m)

result = await agent.run('hello')
assert result.data == 'world'
assert result.cost() == snapshot(Cost())

# reset the index so we get the same response again
mock_client.index = 0 # type: ignore

result = await agent.run('hello', message_history=result.new_messages())
assert result.data == 'world'
assert result.cost() == snapshot(Cost())
assert result.all_messages() == snapshot(
[
UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)