feat(parsing): add support for pydantic dataclasses #1655

Merged · 1 commit · Aug 20, 2024
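For context, a minimal usage sketch of what this change enables: passing a Pydantic dataclass (instead of a `BaseModel` subclass) as `response_format` to `client.beta.chat.completions.parse()`. The `CalendarEvent` dataclass mirrors the one used in the new test below; the snippet is illustrative rather than part of the PR, and assumes Pydantic v2 plus an API key in the environment.

# Usage sketch: parse a chat completion into a Pydantic dataclass (Pydantic v2 required).
from typing import List

from pydantic.dataclasses import dataclass

from openai import OpenAI


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


client = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)
event = completion.choices[0].message.parsed  # a CalendarEvent instance, not a dict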
src/openai/lib/_parsing/_completions.py (19 additions, 9 deletions)

@@ -9,9 +9,9 @@
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,
@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False
 
 
-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))
 
+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")
 
 
@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)
 
-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")
 
     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
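The dataclass branch above relies on two Pydantic v2 `TypeAdapter` calls: `json_schema()` when `type_to_response_format_param()` builds the request's `response_format`, and `validate_json()` when `_parse_content()` turns the model's JSON output back into the annotated type. A rough standalone sketch of those two calls, using a hypothetical `WeatherReport` type that is not part of the PR:

# Sketch of the two TypeAdapter operations used for dataclasses (Pydantic v2 only).
from typing import List

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class WeatherReport:  # hypothetical example type
    city: str
    temperatures: List[float]


adapter = pydantic.TypeAdapter(WeatherReport)

# Roughly what type_to_response_format_param() feeds into to_strict_json_schema():
schema = adapter.json_schema()
print(schema["properties"].keys())  # e.g. dict_keys(['city', 'temperatures'])

# Roughly what _parse_content() does with the model's JSON output:
report = adapter.validate_json('{"city": "Paris", "temperatures": [21.5, 23.0]}')
assert isinstance(report, WeatherReport)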
src/openai/lib/_pydantic.py (22 additions, 4 deletions)

@@ -1,17 +1,26 @@
 from __future__ import annotations
 
-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard
 
 import pydantic
 
 from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema
 
+_T = TypeVar("_T")
+
 
-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
+
     return _ensure_strict_json_schema(schema, path=(), root=schema)
@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved
 
 
+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
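As a quick illustration of the `__pydantic_config__` heuristic above: a class decorated with `@pydantic.dataclasses.dataclass` carries that attribute, while a plain stdlib dataclass normally does not. The `PydanticPoint` and `PlainPoint` types below are hypothetical and only for demonstration (Pydantic v2):

# Demonstrates the is_dataclass_like_type() heuristic (Pydantic v2).
import dataclasses

from pydantic.dataclasses import dataclass as pydantic_dataclass


@pydantic_dataclass
class PydanticPoint:  # hypothetical example type
    x: int
    y: int


@dataclasses.dataclass
class PlainPoint:  # hypothetical example type
    x: int
    y: int


print(hasattr(PydanticPoint, "__pydantic_config__"))  # True  -> treated as dataclass-like
print(hasattr(PlainPoint, "__pydantic_config__"))     # False -> rejected as an unsupported response_format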
tests/lib/chat/test_completions.py (58 additions, 1 deletion)

@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -317,6 +317,63 @@ class Location(BaseModel):
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(