Skip to content

Commit

Permalink
fix(json schema): unwrap allOfs with one entry
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Aug 12, 2024
1 parent 1a388a1 commit 53d964d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/openai/lib/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ def _ensure_strict_json_schema(
# intersections
all_of = json_schema.get("allOf")
if is_list(all_of):
json_schema["allOf"] = [
_ensure_strict_json_schema(entry, path=(*path, "anyOf", str(i))) for i, entry in enumerate(all_of)
]
if len(all_of) == 1:
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
json_schema.pop("allOf")
else:
json_schema["allOf"] = [
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
]

defs = json_schema.get("$defs")
if is_dict(defs):
Expand Down
63 changes: 63 additions & 0 deletions tests/lib/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from enum import Enum

from pydantic import Field, BaseModel
from inline_snapshot import snapshot

import openai
Expand Down Expand Up @@ -161,3 +164,63 @@ def test_most_types() -> None:
},
}
)


class Color(Enum):
RED = "red"
BLUE = "blue"
GREEN = "green"


class ColorDetection(BaseModel):
color: Color = Field(description="The detected color")
hex_color_code: str = Field(description="The hex color code of the detected color")


def test_enums() -> None:
if PYDANTIC_V2:
assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
{
"name": "ColorDetection",
"strict": True,
"parameters": {
"$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
"properties": {
"color": {"description": "The detected color", "$ref": "#/$defs/Color"},
"hex_color_code": {
"description": "The hex color code of the detected color",
"title": "Hex Color Code",
"type": "string",
},
},
"required": ["color", "hex_color_code"],
"title": "ColorDetection",
"type": "object",
"additionalProperties": False,
},
}
)
else:
assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
{
"name": "ColorDetection",
"strict": True,
"parameters": {
"properties": {
"color": {"description": "The detected color", "$ref": "#/definitions/Color"},
"hex_color_code": {
"description": "The hex color code of the detected color",
"title": "Hex Color Code",
"type": "string",
},
},
"required": ["color", "hex_color_code"],
"title": "ColorDetection",
"definitions": {
"Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]}
},
"type": "object",
"additionalProperties": False,
},
}
)

0 comments on commit 53d964d

Please sign in to comment.