From 91b5b89b27690f78829f77433bac5973ae15813b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 24 Nov 2024 12:10:57 +0000 Subject: [PATCH 1/3] support vertexai in infer_model --- pydantic_ai/models/__init__.py | 7 +++++++ pydantic_ai_examples/pydantic_model.py | 1 + tests/models/test_model.py | 22 ++++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index fc0e901b5..5e0931b1a 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -46,6 +46,8 @@ 'groq:gemma-7b-it', 'gemini-1.5-flash', 'gemini-1.5-pro', + 'vertexai:gemini-1.5-flash', + 'vertexai:gemini-1.5-pro', 'test', ] """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent]. @@ -245,6 +247,11 @@ def infer_model(model: Model | KnownModelName) -> Model: from .groq import GroqModel return GroqModel(model[5:]) # pyright: ignore[reportArgumentType] + elif model.startswith('vertexai:'): + from .vertexai import VertexAIModel + + # noinspection PyTypeChecker + return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType] else: from ..exceptions import UserError diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index 83804f228..54f53df2b 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -24,6 +24,7 @@ class MyModel(BaseModel): model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o')) +print(f'Using model: {model}') agent = Agent(model, result_type=MyModel) if __name__ == '__main__': diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 30bfb0ecf..4701468c0 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import pytest from pydantic_ai import UserError @@ -6,6 +8,19 @@ from pydantic_ai.models.openai import OpenAIModel from tests.conftest import TestEnv +if TYPE_CHECKING: + from pydantic_ai.models.vertexai import VertexAIModel + + google_auth_installed = True + +else: + try: + from pydantic_ai.models.vertexai import VertexAIModel + except ImportError: + google_auth_installed = False + else: + google_auth_installed = True + def test_infer_str_openai(env: TestEnv): env.set('OPENAI_API_KEY', 'via-env-var') @@ -24,6 +39,13 @@ def test_infer_str_gemini(env: TestEnv): assert m.name() == 'gemini-1.5-flash' +@pytest.mark.skipif(not google_auth_installed, reason='google-auth not installed') +def test_infer_vertexai(env: TestEnv): + m = infer_model('vertexai:gemini-1.5-flash') + assert isinstance(m, VertexAIModel) + assert m.name() == 'vertexai:gemini-1.5-flash' + + def test_infer_str_unknown(): with pytest.raises(UserError, match='Unknown model: foobar'): infer_model('foobar') # pyright: ignore[reportArgumentType] From 58c89177a4167472d99f017692b79a5aa5ad790b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 24 Nov 2024 12:11:39 +0000 Subject: [PATCH 2/3] tweak --- pydantic_ai/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 5e0931b1a..359c620ad 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -250,7 +250,6 @@ def infer_model(model: Model | KnownModelName) -> Model: elif model.startswith('vertexai:'): from .vertexai import VertexAIModel - # noinspection PyTypeChecker return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType] else: from ..exceptions import UserError From ffb72f05fd01f0d5f1c8bcab958b21d2ad59fbcd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 24 Nov 2024 12:12:34 +0000 Subject: [PATCH 3/3] global import --- pydantic_ai/models/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index 359c620ad..c56a5d19d 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -15,6 +15,7 @@ import httpx +from ..exceptions import UserError from ..messages import Message, ModelAnyResponse, ModelStructuredResponse if TYPE_CHECKING: @@ -252,8 +253,6 @@ def infer_model(model: Model | KnownModelName) -> Model: return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType] else: - from ..exceptions import UserError - raise UserError(f'Unknown model: {model}')