Sambanova models support #215
Closed
Changes from 3 commits (of 18):

- aaefff7 add sambanova model (jhpiedrahitao)
- 7f74589 add base tests sambanova models (jhpiedrahitao)
- 3de0afe docs sambanova model (jhpiedrahitao)
- 346feff fmt (jhpiedrahitao)
- cdf6ad1 minor change (jhpiedrahitao)
- c8a91f5 fmt (jhpiedrahitao)
- fc73273 minor change in test sambanova (jhpiedrahitao)
- 43ac47f minor change in test sambanova (jhpiedrahitao)
- 25539c7 fmt (jhpiedrahitao)
- 5109ac3 minor change in test sambanova (jhpiedrahitao)
- f45b8e3 lint (jhpiedrahitao)
- 4e3970b lint (jhpiedrahitao)
- 9dce38a fmt (jhpiedrahitao)
- 320c430 fmt (jhpiedrahitao)
- 58daab4 minor change (jhpiedrahitao)
- 09646f2 Merge branch 'main' into add-sambanova-models (jhpiedrahitao)
- 971de83 update sambanova available model list (jhpiedrahitao)
- ba6be28 update sambanova available model list (jhpiedrahitao)
@@ -0,0 +1,7 @@
# `pydantic_ai.models.sambanova`

## Setup

For details on how to set up authentication with this model, see [model configuration for SambaNova](../../install.md#sambanova).

::: pydantic_ai.models.sambanova
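The setup docs above point at the `SAMBANOVA_API_KEY` environment variable. A minimal stdlib sketch of the fallback the model constructor implements (the `resolve_api_key` helper here is ours, not part of the library):

```python
import os


def resolve_api_key(api_key=None):
    """An explicit key wins; otherwise fall back to SAMBANOVA_API_KEY."""
    if api_key is None:
        api_key = os.environ.get('SAMBANOVA_API_KEY')
    return api_key


os.environ['SAMBANOVA_API_KEY'] = 'sk-example'  # illustrative value only
assert resolve_api_key() == 'sk-example'
assert resolve_api_key('explicit') == 'explicit'
```

If neither source yields a key, the OpenAI client itself will raise at construction time, which is why the docs treat the environment variable as the baseline setup step.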
@@ -0,0 +1,173 @@
from __future__ import annotations as _annotations

import json
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Union

from httpx import AsyncClient as AsyncHTTPClient

from ..messages import (
    Message,
    ModelAnyResponse,
    ModelStructuredResponse,
    ModelTextResponse,
    ToolCall,
)
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    cached_async_http_client,
    check_allow_model_requests,
)
from .openai import OpenAIAgentModel

try:
    from openai import AsyncOpenAI
    from openai.types import chat
except ImportError as e:
    raise ImportError(
        'Please install `openai` to use the OpenAI client, '
        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
    ) from e


SambaNovaModelNames = Literal[
    'Meta-Llama-3.1-8B-Instruct',
    'Meta-Llama-3.1-70B-Instruct',
    'Meta-Llama-3.1-405B-Instruct',
    'Meta-Llama-3.2-1B-Instruct',
    'Meta-Llama-3.2-3B-Instruct',
    'Llama-3.2-11B-Vision-Instruct',
    'Llama-3.2-90B-Vision-Instruct',
    'Meta-Llama-Guard-3-8B',
    'Qwen2.5-Coder-32B-Instruct',
    'Qwen2.5-72B-Instruct',
]

SambaNovaModelName = Union[SambaNovaModelNames, str]


@dataclass(init=False)
class SambaNovaModel(Model):
    """A model that uses SambaNova models through the OpenAI client.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact
    with the SambaNova API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncOpenAI = field(repr=False)

    def __init__(
        self,
        model_name: SambaNovaModelName,
        *,
        base_url: str | None = 'https://api.sambanova.ai/v1',
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a SambaNovaModel.

        SambaNova models have built-in compatibility with the OpenAI chat completions API, so we use the
        OpenAI client.

        Args:
            model_name: The name of the SambaNova model to use. The list of available model names can be
                found [here](https://cloud.sambanova.ai).
            base_url: The base URL for SambaNova requests. The default value is the SambaNova Cloud URL.
            api_key: The API key to use for authentication; if not provided, the `SAMBANOVA_API_KEY`
                environment variable will be used if available.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use; if provided, `base_url`, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name: SambaNovaModelName = model_name
        if api_key is None:
            api_key = os.environ.get('SAMBANOVA_API_KEY')
        if openai_client is not None:
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self.client = openai_client
        elif http_client is not None:
            assert base_url is None, 'Cannot provide both `http_client` and `base_url`'
            self.client = AsyncOpenAI(api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncOpenAI(
                base_url=base_url,
                api_key=api_key,
                http_client=cached_async_http_client(),
            )

    async def agent_model(
        self,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return SambaNovaAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'sambanova:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
        return {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }


@dataclass
class SambaNovaAgentModel(OpenAIAgentModel):
    """Implementation of `AgentModel` for SambaNova models.

    SambaNova models have built-in compatibility with the OpenAI chat completions API,
    so we inherit from [`OpenAIAgentModel`][pydantic_ai.models.openai.OpenAIAgentModel] here.
    """

    @staticmethod
    def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
        """Process a non-streamed response, and prepare a message to return."""
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        if choice.message.tool_calls is not None:
            calls = []
            for tool_call in choice.message.tool_calls:
                if isinstance(tool_call.function.arguments, dict):
                    # SambaNova may return arguments already parsed as a dict;
                    # re-serialize so ToolCall.from_json always receives a JSON string.
                    calls.append(
                        ToolCall.from_json(
                            tool_call.function.name, json.dumps(tool_call.function.arguments), tool_call.id
                        )
                    )
                else:
                    calls.append(
                        ToolCall.from_json(tool_call.function.name, tool_call.function.arguments, tool_call.id)
                    )
            return ModelStructuredResponse(calls, timestamp=timestamp)
        else:
            assert choice.message.content is not None, choice
            return ModelTextResponse(choice.message.content, timestamp=timestamp)

    async def _completions_create(self, messages: list[Message], stream: bool) -> chat.ChatCompletion:
        if stream:
            if self.tools:
                raise NotImplementedError('tool calling when streaming not supported')
        return await super()._completions_create(messages, stream)
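One detail worth calling out in `_process_response`: SambaNova can hand back tool-call arguments either as a JSON string (as the OpenAI types declare) or as an already-parsed dict, and the dict branch re-serializes with `json.dumps`. A minimal stdlib sketch of that normalization (the `normalize_arguments` helper name is ours):

```python
import json


def normalize_arguments(arguments):
    # Mirror the branch in _process_response: dicts are re-serialized,
    # JSON strings are passed through untouched.
    if isinstance(arguments, dict):
        return json.dumps(arguments)
    return arguments


assert normalize_arguments({'city': 'Paris'}) == '{"city": "Paris"}'
assert normalize_arguments('{"city": "Paris"}') == '{"city": "Paris"}'
```

Normalizing to a string at this boundary means the rest of the tool-call pipeline never has to care which shape the provider produced.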
@@ -0,0 +1,61 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from datetime import datetime, timezone

import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent
from pydantic_ai.messages import (
    ModelTextResponse,
    UserPrompt,
)
from pydantic_ai.result import Cost

from ..conftest import IsNow, try_import

with try_import() as imports_successful:
    from openai.types.chat.chat_completion_message import ChatCompletionMessage

    from pydantic_ai.models.sambanova import SambaNovaModel

    from .test_openai import MockOpenAI, completion_message

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.anyio,
]


def test_init():
    m = SambaNovaModel('Meta-Llama-3.1-8B-Instruct', api_key='foobar')
    assert m.client.api_key == 'foobar'
    # the OpenAI client normalizes base_url to end with a trailing slash
    assert str(m.client.base_url) == 'https://api.sambanova.ai/v1/'
    assert m.name() == 'sambanova:Meta-Llama-3.1-8B-Instruct'


async def test_request_simple_success(allow_model_requests: None):
    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    m = SambaNovaModel('Meta-Llama-3.1-8B-Instruct', openai_client=mock_client)
    agent = Agent(m)

    result = await agent.run('hello')
    assert result.data == 'world'
    assert result.cost() == snapshot(Cost())

    # reset the index so we get the same response again
    mock_client.index = 0  # type: ignore

    result = await agent.run('hello', message_history=result.new_messages())
    assert result.data == 'world'
    assert result.cost() == snapshot(Cost())
    assert result.all_messages() == snapshot(
        [
            UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
            ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
            UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
            ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )
You may also need to add Sambanova LLM API to the list at the top of this document: