
Sambanova models support #215

Closed
wants to merge 18 commits into from
7 changes: 7 additions & 0 deletions docs/api/models/sambanova.md
@@ -0,0 +1,7 @@
# `pydantic_ai.models.sambanova`

## Setup

For details on how to set up authentication with this model, see [model configuration for SambaNova](../../install.md#sambanova).

::: pydantic_ai.models.sambanova
47 changes: 47 additions & 0 deletions docs/install.md
@@ -329,3 +329,50 @@ agent = Agent(model)
To use [Ollama](https://ollama.com/), you must first download the Ollama client, and then download a model.

You must also ensure the Ollama server is running when trying to make requests to it. For more information, please see the [Ollama documentation](https://github.com/ollama/ollama/tree/main/docs)

### SambaNova


You may also need to add the SambaNova LLM API to the list at the top of this document:

> This installs the `pydantic_ai` package, core dependencies, and libraries required to use the following LLM APIs:


To use [SambaNovaCloud](https://cloud.sambanova.ai/) through their API, go to [cloud.sambanova.ai/apis](https://cloud.sambanova.ai/apis) and click on Generate New API key.

[`SambaNovaModelNames`][pydantic_ai.models.sambanova.SambaNovaModelNames] contains a list of available SambaNovaCloud models.

#### Environment variable

Once you have the API key, you can set it as an environment variable:

```bash
export SAMBANOVA_API_KEY='your-api-key'
```

You can then use [`SambaNovaModel`][pydantic_ai.models.sambanova.SambaNovaModel] by name:

```py title="sambanova_model_by_name.py"
from pydantic_ai import Agent

agent = Agent('sambanova:Meta-Llama-3.1-70B-Instruct')
...
```

Or initialize the model directly with just the model name:

```py title="sambanova_model_init.py"
from pydantic_ai import Agent
from pydantic_ai.models.sambanova import SambaNovaModel

model = SambaNovaModel('Meta-Llama-3.1-70B-Instruct')
agent = Agent(model)
...
```

#### `api_key` argument

If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.sambanova.SambaNovaModel.__init__]:

```py title="sambanova_model_api_key.py"
from pydantic_ai import Agent
from pydantic_ai.models.sambanova import SambaNovaModel

model = SambaNovaModel('Meta-Llama-3.1-70B-Instruct', api_key='your-api-key')
agent = Agent(model)
...
```
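The precedence between the `api_key` argument and the `SAMBANOVA_API_KEY` environment variable described above can be sketched as a small standalone helper (`resolve_api_key` is a hypothetical name for illustration, not part of this PR):

```python
import os


def resolve_api_key(explicit: 'str | None' = None) -> 'str | None':
    """Return the explicitly passed key if given, otherwise fall back to the environment."""
    # an api_key passed at runtime always wins over SAMBANOVA_API_KEY
    return explicit if explicit is not None else os.environ.get('SAMBANOVA_API_KEY')
```

If neither source provides a key, the helper returns `None`, mirroring how the model constructor leaves authentication unset.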
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -43,6 +43,7 @@ nav:
- api/models/gemini.md
- api/models/vertexai.md
- api/models/groq.md
- api/models/sambanova.md
- api/models/test.md
- api/models/function.md

14 changes: 14 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -67,6 +67,16 @@
    'ollama:qwen2',
    'ollama:qwen2.5',
    'ollama:starcoder2',
    'sambanova:Meta-Llama-3.1-8B-Instruct',
    'sambanova:Meta-Llama-3.1-70B-Instruct',
    'sambanova:Meta-Llama-3.1-405B-Instruct',
    'sambanova:Meta-Llama-3.2-1B-Instruct',
    'sambanova:Meta-Llama-3.2-3B-Instruct',
    'sambanova:Llama-3.2-11B-Vision-Instruct',
    'sambanova:Llama-3.2-90B-Vision-Instruct',
    'sambanova:Meta-Llama-Guard-3-8B',
    'sambanova:Qwen2.5-Coder-32B-Instruct',
    'sambanova:Qwen2.5-72B-Instruct',
    'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -275,6 +285,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
        from .ollama import OllamaModel

        return OllamaModel(model[7:])
    elif model.startswith('sambanova:'):
        from .sambanova import SambaNovaModel

        return SambaNovaModel(model[10:])  # pyright: ignore[reportArgumentType]
    else:
        raise UserError(f'Unknown model: {model}')

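The `model[10:]` slice above simply strips the `'sambanova:'` prefix (which is 10 characters long) before handing the bare model name to the constructor. The same dispatch can be sketched generically (`split_model_prefix` is an illustrative name, not part of this PR):

```python
def split_model_prefix(model: str) -> 'tuple[str, str]':
    """Split a 'provider:model-name' string into its provider and model parts."""
    provider, sep, name = model.partition(':')
    if not sep:
        # no colon found, so there is no provider prefix to strip
        raise ValueError(f'Unknown model: {model}')
    return provider, name
```

Using `partition` avoids hard-coding the prefix length, at the cost of being slightly less explicit than the `startswith`/slice pair used in `infer_model`.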
173 changes: 173 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/sambanova.py
@@ -0,0 +1,173 @@
from __future__ import annotations as _annotations

import json
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Union

from httpx import AsyncClient as AsyncHTTPClient

from ..messages import (
    Message,
    ModelAnyResponse,
    ModelStructuredResponse,
    ModelTextResponse,
    ToolCall,
)
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    cached_async_http_client,
    check_allow_model_requests,
)
from .openai import OpenAIAgentModel

try:
    from openai import AsyncOpenAI
    from openai.types import chat
except ImportError as e:
    raise ImportError(
        'Please install `openai` to use the OpenAI client, '
        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
    ) from e


SambaNovaModelNames = Literal[
    'Meta-Llama-3.1-8B-Instruct',
    'Meta-Llama-3.1-70B-Instruct',
    'Meta-Llama-3.1-405B-Instruct',
    'Meta-Llama-3.2-1B-Instruct',
    'Meta-Llama-3.2-3B-Instruct',
    'Llama-3.2-11B-Vision-Instruct',
    'Llama-3.2-90B-Vision-Instruct',
    'Meta-Llama-Guard-3-8B',
    'Qwen2.5-Coder-32B-Instruct',
    'Qwen2.5-72B-Instruct',
]

SambaNovaModelName = Union[SambaNovaModelNames, str]


@dataclass(init=False)
class SambaNovaModel(Model):
    """A model that uses SambaNova models through the OpenAI client.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact
    with the SambaNova APIs.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncOpenAI = field(repr=False)

    def __init__(
        self,
        model_name: SambaNovaModelName,
        *,
        base_url: str | None = 'https://api.sambanova.ai/v1',
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a SambaNovaModel.

        SambaNova models have built-in compatibility with the OpenAI chat completions API, so we use the
        OpenAI client.

        Args:
            model_name: The name of the SambaNova model to use. The list of available model names is
                [here](https://cloud.sambanova.ai).
            base_url: The base URL for SambaNova requests. Defaults to the SambaNovaCloud URL.
            api_key: The API key to use for authentication. If not provided, the `SAMBANOVA_API_KEY`
                environment variable will be used if available.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name: SambaNovaModelName = model_name
        if openai_client is not None:
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self.client = openai_client
        else:
            # only fall back to the environment when no ready-made client was given
            if api_key is None:
                api_key = os.environ.get('SAMBANOVA_API_KEY')
            self.client = AsyncOpenAI(
                base_url=base_url,
                api_key=api_key,
                http_client=http_client or cached_async_http_client(),
            )

    async def agent_model(
        self,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return SambaNovaAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'sambanova:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }


@dataclass
class SambaNovaAgentModel(OpenAIAgentModel):
    """Implementation of `AgentModel` for SambaNova models.

    SambaNova models have built-in compatibility with the OpenAI chat completions API,
    so we inherit from [`OpenAIAgentModel`][pydantic_ai.models.openai.OpenAIAgentModel] here.
    """

    @staticmethod
    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        if choice.message.tool_calls is not None:
            calls = []
            for tool_call in choice.message.tool_calls:
                if isinstance(tool_call.function.arguments, dict):
                    # some SambaNova responses return tool arguments as an already-parsed dict
                    args_json = json.dumps(tool_call.function.arguments)
                else:
                    args_json = tool_call.function.arguments
                calls.append(ToolCall.from_json(tool_call.function.name, args_json, tool_call.id))
            return ModelStructuredResponse(calls, timestamp=timestamp)
        else:
            assert choice.message.content is not None, choice
            return ModelTextResponse(choice.message.content, timestamp=timestamp)

    async def _completions_create(
        self, messages: list[Message], stream: bool
    ) -> chat.ChatCompletion:
        if stream and self.tools:
            raise NotImplementedError('tool calling when streaming not supported')
        return await super()._completions_create(messages, stream)
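The dict-vs-string branch in `_process_response` can be isolated into a small stdlib-only helper: SambaNova responses sometimes carry tool-call arguments as an already-parsed dict, while the downstream `ToolCall.from_json` expects a JSON string. A minimal sketch (`normalize_tool_arguments` is a hypothetical name for illustration, not part of this PR):

```python
import json
from typing import Any


def normalize_tool_arguments(arguments: 'dict[str, Any] | str') -> str:
    """Return tool-call arguments as a JSON string, whichever form they arrived in."""
    if isinstance(arguments, dict):
        # re-serialize a pre-parsed dict so downstream code always sees JSON text
        return json.dumps(arguments)
    return arguments
```

Normalizing at this boundary keeps the rest of the response handling identical to the OpenAI code path.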
61 changes: 61 additions & 0 deletions tests/models/test_sambanova.py
@@ -0,0 +1,61 @@
from __future__ import annotations as _annotations

from datetime import datetime, timezone

import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent
from pydantic_ai.messages import (
    ModelTextResponse,
    UserPrompt,
)
from pydantic_ai.result import Cost

from ..conftest import IsNow, try_import

with try_import() as imports_successful:
    from openai.types.chat.chat_completion_message import ChatCompletionMessage

    from pydantic_ai.models.sambanova import SambaNovaModel

    from .test_openai import MockOpenAI, completion_message

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.anyio,
]


def test_init():
    m = SambaNovaModel('Meta-Llama-3.1-8B-Instruct', api_key='foobar')
    assert m.client.api_key == 'foobar'
    # the OpenAI client normalizes base_url to include a trailing slash
    assert str(m.client.base_url) == 'https://api.sambanova.ai/v1/'
    assert m.name() == 'sambanova:Meta-Llama-3.1-8B-Instruct'


async def test_request_simple_success(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    m = SambaNovaModel('Meta-Llama-3.1-8B-Instruct', openai_client=mock_client)
    agent = Agent(m)

    result = await agent.run('hello')
    assert result.data == 'world'
    assert result.cost() == snapshot(Cost())

    # reset the index so we get the same response again
    mock_client.index = 0  # type: ignore

    result = await agent.run('hello', message_history=result.new_messages())
    assert result.data == 'world'
    assert result.cost() == snapshot(Cost())
    assert result.all_messages() == snapshot(
        [
            UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
            ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
            UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
            ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )