Skip to content

Commit ae45a9b

Browse files
committed
Add genai.errors handling in GoogleModel
- Add error handling helper method `_handle_google_error` - Convert Google API errors to ModelHTTPError with proper status codes - Map specific function-related errors (400-level) appropriately - Keep original error details in response body - Add test cases for API error handling Resolves: #3088
1 parent 78fb707 commit ae45a9b

File tree

2 files changed

+127
-4
lines changed

2 files changed

+127
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._output import OutputObjectDefinition
1515
from .._run_context import RunContext
1616
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool
17-
from ..exceptions import UserError
17+
from ..exceptions import ModelHTTPError, UserError
1818
from ..messages import (
1919
BinaryContent,
2020
BuiltinToolCallPart,
@@ -50,7 +50,7 @@
5050
)
5151

5252
try:
53-
from google.genai import Client
53+
from google.genai import Client, errors
5454
from google.genai.types import (
5555
BlobDict,
5656
CodeExecutionResult,
@@ -357,6 +357,49 @@ def _get_tool_config(
357357
else:
358358
return None
359359

360+
def _handle_google_error(self, error: Exception) -> ModelHTTPError:
361+
"""Helper method to convert Google API errors to ModelHTTPError."""
362+
if isinstance(error, errors.APIError):
363+
return ModelHTTPError(
364+
status_code=getattr(error, 'code', 500),
365+
model_name=self._model_name,
366+
body=error.details,
367+
)
368+
369+
error_mappings = {
370+
errors.UnknownFunctionCallArgumentError: (
371+
400,
372+
'the function call argument cannot be converted to the parameter annotation.',
373+
'BAD_REQUEST',
374+
),
375+
errors.UnsupportedFunctionError: (404, 'the function is not supported.', 'NOT_FOUND'),
376+
errors.FunctionInvocationError: (
377+
400,
378+
'the function cannot be invoked with the given arguments.',
379+
'BAD_REQUEST',
380+
),
381+
errors.UnknownApiResponseError: (
382+
422,
383+
'the response from the API cannot be parsed as JSON.',
384+
'UNPROCESSABLE_CONTENT',
385+
),
386+
}
387+
388+
if error.__class__ in error_mappings:
389+
code, message, status = error_mappings[error.__class__]
390+
return ModelHTTPError(
391+
status_code=code,
392+
model_name=self._model_name,
393+
body={'error': {'code': code, 'message': message, 'status': status}},
394+
)
395+
396+
# Handle unknown errors as 500 Internal Server Error
397+
return ModelHTTPError(
398+
status_code=500,
399+
model_name=self._model_name,
400+
body={'error': {'code': 500, 'message': str(error), 'status': 'INTERNAL_ERROR'}},
401+
)
402+
360403
@overload
361404
async def _generate_content(
362405
self,
@@ -384,7 +427,16 @@ async def _generate_content(
384427
) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
385428
contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
386429
func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
387-
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
430+
try:
431+
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
432+
except (
433+
errors.APIError,
434+
errors.UnknownFunctionCallArgumentError,
435+
errors.UnsupportedFunctionError,
436+
errors.FunctionInvocationError,
437+
errors.UnknownApiResponseError,
438+
) as e:
439+
raise self._handle_google_error(e) from e
388440

389441
async def _build_content_and_config(
390442
self,

tests/models/test_google.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from httpx import Timeout
1010
from inline_snapshot import Is, snapshot
1111
from pydantic import BaseModel
12+
from pytest_mock import MockerFixture
1213
from typing_extensions import TypedDict
1314

1415
from pydantic_ai import (
@@ -41,7 +42,7 @@
4142
)
4243
from pydantic_ai.agent import Agent
4344
from pydantic_ai.builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool
44-
from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError
45+
from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
4546
from pydantic_ai.messages import (
4647
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
4748
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
@@ -54,6 +55,7 @@
5455
from ..parts_from_messages import part_types_from_messages
5556

5657
with try_import() as imports_successful:
58+
from google.genai import errors
5759
from google.genai.types import (
5860
GenerateContentResponse,
5961
GenerateContentResponseUsageMetadata,
@@ -2929,3 +2931,72 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert
29292931
identifier='f3edd8',
29302932
)
29312933
)
2934+
2935+
2936+
# API 에러 테스트 데이터
2937+
@pytest.mark.parametrize(
2938+
'error_class,error_response,expected_status',
2939+
[
2940+
(
2941+
errors.ServerError,
2942+
{'error': {'code': 503, 'message': 'The service is currently unavailable.', 'status': 'UNAVAILABLE'}},
2943+
503,
2944+
),
2945+
(
2946+
errors.ClientError,
2947+
{'error': {'code': 400, 'message': 'Invalid request parameters', 'status': 'INVALID_ARGUMENT'}},
2948+
400,
2949+
),
2950+
(
2951+
errors.ClientError,
2952+
{'error': {'code': 429, 'message': 'Rate limit exceeded', 'status': 'RESOURCE_EXHAUSTED'}},
2953+
429,
2954+
),
2955+
],
2956+
)
2957+
async def test_google_api_errors_are_handled(
2958+
allow_model_requests: None,
2959+
google_provider: GoogleProvider,
2960+
mocker: MockerFixture,
2961+
error_class: type[errors.APIError],
2962+
error_response: dict[str, Any],
2963+
expected_status: int,
2964+
):
2965+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
2966+
mocked_error = error_class(expected_status, error_response)
2967+
mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error)
2968+
2969+
agent = Agent(model=model)
2970+
2971+
with pytest.raises(ModelHTTPError) as exc_info:
2972+
await agent.run('This prompt will trigger the mocked error.')
2973+
2974+
assert exc_info.value.status_code == expected_status
2975+
assert error_response['error']['message'] in str(exc_info.value.body)
2976+
2977+
2978+
@pytest.mark.parametrize(
2979+
'error_class,expected_status',
2980+
[
2981+
(errors.UnknownFunctionCallArgumentError, 400),
2982+
(errors.UnsupportedFunctionError, 404),
2983+
(errors.FunctionInvocationError, 400),
2984+
(errors.UnknownApiResponseError, 422),
2985+
],
2986+
)
2987+
async def test_google_specific_errors_are_handled(
2988+
allow_model_requests: None,
2989+
google_provider: GoogleProvider,
2990+
mocker: MockerFixture,
2991+
error_class: type[errors.APIError],
2992+
expected_status: int,
2993+
):
2994+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
2995+
mocked_error = error_class
2996+
mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error)
2997+
2998+
agent = Agent(model=model)
2999+
3000+
with pytest.raises(ModelHTTPError) as exc_info:
3001+
await agent.run('This prompt will trigger the mocked error.')
3002+
assert exc_info.value.status_code == expected_status

0 commit comments

Comments
 (0)