Skip to content

Commit 1341160

Browse files
committed
Add OpenRouterModel
1 parent 45d0ff2 commit 1341160

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Any, cast
2+
3+
from openai.types import chat
4+
5+
from ..messages import ModelResponse
6+
from .openai import OpenAIModel
7+
8+
9+
class OpenRouterChatCompletion(chat.ChatCompletion):
10+
"""Extends ChatCompletion with OpenRouter-specific attributes.
11+
12+
This class extends the base ChatCompletion model to include additional
13+
fields returned specifically by the OpenRouter API.
14+
15+
Attributes:
16+
provider: The name of the upstream LLM provider (e.g., "Anthropic",
17+
"OpenAI", etc.) that processed the request through OpenRouter.
18+
"""
19+
20+
provider: str
21+
22+
23+
class OpenRouterModel(OpenAIModel):
24+
"""Extends OpenAIModel to capture extra metadata for Openrouter."""
25+
26+
def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
27+
response = cast(OpenRouterChatCompletion, response)
28+
model_response = super()._process_response(response=response)
29+
openrouter_provider: str | None = response.provider if hasattr(response, 'provider') else None
30+
if openrouter_provider:
31+
vendor_details: dict[str, Any] = getattr(model_response, 'vendor_details') or {}
32+
vendor_details['provider'] = openrouter_provider
33+
model_response.vendor_details = vendor_details
34+
return model_response

tests/providers/test_openrouter.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
from pydantic_ai.agent import Agent
88
from pydantic_ai.exceptions import UserError
9+
from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
10+
from pydantic_ai.usage import Usage
911

10-
from ..conftest import TestEnv, try_import
12+
from ..conftest import IsDatetime, IsStr, TestEnv, try_import
1113

1214
with try_import() as imports_successful:
1315
import openai
1416

15-
from pydantic_ai.models.openai import OpenAIModel
17+
from pydantic_ai.models.openrouter import OpenRouterModel
1618
from pydantic_ai.providers.openrouter import OpenRouterProvider
1719

1820

@@ -57,11 +59,41 @@ def test_openrouter_pass_openai_client() -> None:
5759

5860
async def test_openrouter_with_google_model(allow_model_requests: None, openrouter_api_key: str) -> None:
5961
provider = OpenRouterProvider(api_key=openrouter_api_key)
60-
model = OpenAIModel('google/gemini-2.0-flash-exp:free', provider=provider)
61-
agent = Agent(model, instructions='Be helpful.')
62+
model = OpenRouterModel('google/gemini-2.0-flash-exp:free', provider=provider)
63+
agent = Agent(model, instructions='Be helpful.', retries=1)
6264
response = await agent.run('Tell me a joke.')
6365
assert response.output == snapshot("""\
6466
Why don't scientists trust atoms? \n\
6567
6668
Because they make up everything!
6769
""")
70+
71+
assert response.all_messages() == snapshot(
72+
[
73+
ModelRequest(
74+
parts=[
75+
UserPromptPart(
76+
content='Tell me a joke.',
77+
timestamp=IsDatetime(iso_string=True),
78+
)
79+
],
80+
instructions='Be helpful.',
81+
),
82+
ModelResponse(
83+
parts=[
84+
TextPart(
85+
content="""\
86+
Why don't scientists trust atoms? \n\
87+
88+
Because they make up everything!
89+
"""
90+
)
91+
],
92+
usage=Usage(requests=1, request_tokens=8, response_tokens=17, total_tokens=25, details={}),
93+
model_name='google/gemini-2.0-flash-exp:free',
94+
timestamp=IsDatetime(iso_string=True),
95+
vendor_details={'provider': 'Google'},
96+
vendor_id=IsStr(),
97+
),
98+
]
99+
)

0 commit comments

Comments
 (0)