|
9 | 9 | from httpx import Timeout |
10 | 10 | from inline_snapshot import Is, snapshot |
11 | 11 | from pydantic import BaseModel |
| 12 | +from pytest_mock import MockerFixture |
12 | 13 | from typing_extensions import TypedDict |
13 | 14 |
|
14 | 15 | from pydantic_ai import ( |
|
41 | 42 | ) |
42 | 43 | from pydantic_ai.agent import Agent |
43 | 44 | from pydantic_ai.builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool |
44 | | -from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError |
| 45 | +from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError |
45 | 46 | from pydantic_ai.messages import ( |
46 | 47 | BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] |
47 | 48 | BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] |
|
54 | 55 | from ..parts_from_messages import part_types_from_messages |
55 | 56 |
|
56 | 57 | with try_import() as imports_successful: |
| 58 | + from google.genai import errors |
57 | 59 | from google.genai.types import ( |
58 | 60 | GenerateContentResponse, |
59 | 61 | GenerateContentResponseUsageMetadata, |
@@ -2929,3 +2931,72 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert |
2929 | 2931 | identifier='f3edd8', |
2930 | 2932 | ) |
2931 | 2933 | ) |
| 2934 | + |
| 2935 | + |
| 2936 | +# API 에러 테스트 데이터 |
| 2937 | +@pytest.mark.parametrize( |
| 2938 | + 'error_class,error_response,expected_status', |
| 2939 | + [ |
| 2940 | + ( |
| 2941 | + errors.ServerError, |
| 2942 | + {'error': {'code': 503, 'message': 'The service is currently unavailable.', 'status': 'UNAVAILABLE'}}, |
| 2943 | + 503, |
| 2944 | + ), |
| 2945 | + ( |
| 2946 | + errors.ClientError, |
| 2947 | + {'error': {'code': 400, 'message': 'Invalid request parameters', 'status': 'INVALID_ARGUMENT'}}, |
| 2948 | + 400, |
| 2949 | + ), |
| 2950 | + ( |
| 2951 | + errors.ClientError, |
| 2952 | + {'error': {'code': 429, 'message': 'Rate limit exceeded', 'status': 'RESOURCE_EXHAUSTED'}}, |
| 2953 | + 429, |
| 2954 | + ), |
| 2955 | + ], |
| 2956 | +) |
| 2957 | +async def test_google_api_errors_are_handled( |
| 2958 | + allow_model_requests: None, |
| 2959 | + google_provider: GoogleProvider, |
| 2960 | + mocker: MockerFixture, |
| 2961 | + error_class: type[errors.APIError], |
| 2962 | + error_response: dict[str, Any], |
| 2963 | + expected_status: int, |
| 2964 | +): |
| 2965 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 2966 | + mocked_error = error_class(expected_status, error_response) |
| 2967 | + mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error) |
| 2968 | + |
| 2969 | + agent = Agent(model=model) |
| 2970 | + |
| 2971 | + with pytest.raises(ModelHTTPError) as exc_info: |
| 2972 | + await agent.run('This prompt will trigger the mocked error.') |
| 2973 | + |
| 2974 | + assert exc_info.value.status_code == expected_status |
| 2975 | + assert error_response['error']['message'] in str(exc_info.value.body) |
| 2976 | + |
| 2977 | + |
| 2978 | +@pytest.mark.parametrize( |
| 2979 | + 'error_class,expected_status', |
| 2980 | + [ |
| 2981 | + (errors.UnknownFunctionCallArgumentError, 400), |
| 2982 | + (errors.UnsupportedFunctionError, 404), |
| 2983 | + (errors.FunctionInvocationError, 400), |
| 2984 | + (errors.UnknownApiResponseError, 422), |
| 2985 | + ], |
| 2986 | +) |
| 2987 | +async def test_google_specific_errors_are_handled( |
| 2988 | + allow_model_requests: None, |
| 2989 | + google_provider: GoogleProvider, |
| 2990 | + mocker: MockerFixture, |
| 2991 | + error_class: type[errors.APIError], |
| 2992 | + expected_status: int, |
| 2993 | +): |
| 2994 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 2995 | + mocked_error = error_class |
| 2996 | + mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error) |
| 2997 | + |
| 2998 | + agent = Agent(model=model) |
| 2999 | + |
| 3000 | + with pytest.raises(ModelHTTPError) as exc_info: |
| 3001 | + await agent.run('This prompt will trigger the mocked error.') |
| 3002 | + assert exc_info.value.status_code == expected_status |
0 commit comments