Skip to content

Commit

Permalink
chore(internal): support serialising iterable types (openai#1127)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored and megamanics committed Aug 14, 2024
1 parent a392b3b commit c229675
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
parse_date as parse_date,
is_iterable as is_iterable,
is_sequence as is_sequence,
coerce_float as coerce_float,
is_mapping_t as is_mapping_t,
Expand All @@ -33,6 +34,7 @@
is_list_type as is_list_type,
is_union_type as is_union_type,
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
strip_annotated_type as strip_annotated_type,
Expand Down
9 changes: 8 additions & 1 deletion src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from ._utils import (
is_list,
is_mapping,
is_iterable,
)
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
is_iterable_type,
is_required_type,
is_annotated_type,
strip_annotated_type,
Expand Down Expand Up @@ -157,7 +159,12 @@ def _transform_recursive(
if is_typeddict(stripped_type) and is_mapping(data):
return _transform_typeddict(data, stripped_type)

if is_list_type(stripped_type) and is_list(data):
if (
# List[T]
(is_list_type(stripped_type) and is_list(data))
# Iterable[T]
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
):
inner_type = extract_type_arg(stripped_type, 0)
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

Expand Down
9 changes: 8 additions & 1 deletion src/openai/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, TypeVar, cast
from typing import Any, TypeVar, Iterable, cast
from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin

from .._types import InheritsGeneric
Expand All @@ -15,6 +16,12 @@ def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list


def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
return origin == Iterable or origin == _c_abc.Iterable


def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))

Expand Down
4 changes: 4 additions & 0 deletions src/openai/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)


def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
return isinstance(obj, Iterable)


def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
Expand Down
34 changes: 33 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, List, Union, Optional
from typing import Any, List, Union, Iterable, Optional, cast
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict

Expand Down Expand Up @@ -265,3 +265,35 @@ def test_pydantic_default_field() -> None:
assert model.with_none_default == "bar"
assert model.with_str_default == "baz"
assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"}


class TypedDictIterableUnion(TypedDict):
foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")]


class Bar8(TypedDict):
foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]


class Baz8(TypedDict):
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]


def test_iterable_of_dictionaries() -> None:
assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]}
assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]}

def my_iter() -> Iterable[Baz8]:
yield {"foo_baz": "hello"}
yield {"foo_baz": "world"}

assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}


class TypedDictIterableUnionStr(TypedDict):
foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]


def test_iterable_union_str() -> None:
assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"}
assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]

0 comments on commit c229675

Please sign in to comment.