Skip to content

Commit

Permalink
chore(internal): support more input types (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Mar 4, 2024
1 parent 7853a83 commit d0e4baa
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/openai/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
FileContent,
RequestFiles,
HttpxFileTypes,
Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t


def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)


def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
Expand Down
2 changes: 2 additions & 0 deletions src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
# file (or bytes)
Expand Down
39 changes: 38 additions & 1 deletion src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from __future__ import annotations

import io
import base64
import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints

import anyio
import pydantic

from ._utils import (
is_list,
is_mapping,
is_iterable,
)
from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
Expand All @@ -29,7 +34,7 @@
# TODO: ensure works correctly with forward references in all cases


PropertyFormat = Literal["iso8601", "custom"]
PropertyFormat = Literal["iso8601", "base64", "custom"]


class PropertyInfo:
Expand Down Expand Up @@ -201,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = data.read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

return base64.b64encode(binary).decode("ascii")

return data


Expand Down Expand Up @@ -323,6 +344,22 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)

if format_ == "base64" and is_base64_file_input(data):
binary: str | bytes | None = None

if isinstance(data, pathlib.Path):
binary = await anyio.Path(data).read_bytes()
elif isinstance(data, io.IOBase):
binary = data.read()

if isinstance(binary, str): # type: ignore[unreachable]
binary = binary.encode()

if not isinstance(binary, bytes):
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

return base64.b64encode(binary).decode("ascii")

return data


Expand Down
1 change: 1 addition & 0 deletions tests/sample_file.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello, world!
29 changes: 29 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import io
import pathlib
from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict

import pytest

from openai._types import Base64FileInput
from openai._utils import (
PropertyInfo,
transform as _transform,
Expand All @@ -17,6 +20,8 @@

_T = TypeVar("_T")

SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")


async def transform(
data: _T,
Expand Down Expand Up @@ -377,3 +382,27 @@ async def test_iterable_union_str(use_async: bool) -> None:
assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [
{"fooBaz": "bar"}
]


class TypedDictBase64Input(TypedDict):
foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]


@parametrize
@pytest.mark.asyncio
async def test_base64_file_input(use_async: bool) -> None:
# strings are left as-is
assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"}

# pathlib.Path is automatically converted to base64
assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == {
"foo": "SGVsbG8sIHdvcmxkIQo="
} # type: ignore[comparison-overlap]

# io instances are automatically converted to base64
assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == {
"foo": "SGVsbG8sIHdvcmxkIQ=="
} # type: ignore[comparison-overlap]
assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
"foo": "SGVsbG8sIHdvcmxkIQ=="
} # type: ignore[comparison-overlap]

0 comments on commit d0e4baa

Please sign in to comment.