Skip to content

Introduce experimental FileOutput interface for models that output File and Path types #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ ignore = [
"ANN003", # Missing type annotation for `**kwargs`
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name}
"W191", # Indentation contains tabs
"UP037", # Remove quotes from type annotation
]
Expand All @@ -86,3 +87,7 @@ ignore = [
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
]

[tool.pyright]
venvPath = "."
venv = ".venv"
6 changes: 4 additions & 2 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,27 @@ def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output.
"""

return run(self, ref, input, **params)
return run(self, ref, input, use_file_output, **params)

async def async_run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output asynchronously.
"""

return await async_run(self, ref, input, **params)
return await async_run(self, ref, input, use_file_output, **params)

def stream(
self,
Expand Down
2 changes: 1 addition & 1 deletion replicate/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing_extensions import Unpack, deprecated

from replicate.account import Account
from replicate.json import async_encode_json, encode_json
from replicate.helpers import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down
82 changes: 81 additions & 1 deletion replicate/json.py → replicate/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import base64
import io
import mimetypes
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional

import httpx

if TYPE_CHECKING:
from replicate.client import Client
Expand Down Expand Up @@ -108,3 +111,80 @@ def base64_encode_file(file: io.IOBase) -> str:
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
)
return f"data:{mime_type};base64,{encoded_body}"


class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream):
    """
    An object that can be used to read the contents of an output file
    created by running a Replicate model.

    The file is identified by `url`. A `data:` URL is decoded locally; any
    other URL is fetched with a GET request through the HTTP clients held by
    the `Client` passed to the constructor. `read()` / `aread()` return the
    whole content as one `bytes` value, while iterating the object
    (`for` / `async for`) yields it in chunks.
    """

    url: str
    """
    The file URL.
    """

    _client: "Client"

    def __init__(self, url: str, client: "Client") -> None:
        self.url = url
        self._client = client

    def _decode_data_url(self) -> bytes:
        """
        Decode the payload of a `data:` URL (RFC 2397) without network I/O.

        Returns:
            The base64-decoded payload when the header ends with `;base64`,
            otherwise the percent-decoded payload.

        Raises:
            ValueError: If the URL has no "," separating header and payload,
                or if a base64 payload is malformed.
        """
        # Local import keeps this helper independent of the module's
        # top-level import block.
        from urllib.parse import unquote_to_bytes

        header, separator, payload = self.url.partition(",")
        if not separator:
            raise ValueError("Invalid data URL: missing ',' before the payload")

        # Percent-escapes are legal in both encodings, so undo them first.
        raw = unquote_to_bytes(payload)
        if header.lower().endswith(";base64"):
            return base64.b64decode(raw)
        return raw

    def read(self) -> bytes:
        """
        Return the entire file content as bytes.

        Raises:
            httpx.HTTPStatusError: If the download gets an error status.
            ValueError: If `url` is a malformed `data:` URL.
        """
        if self.url.startswith("data:"):
            return self._decode_data_url()

        with self._client._client.stream("GET", self.url) as response:
            response.raise_for_status()
            return response.read()

    def __iter__(self) -> Iterator[bytes]:
        """
        Yield the file content in chunks (a single chunk for `data:` URLs).
        """
        if self.url.startswith("data:"):
            yield self._decode_data_url()
            return

        with self._client._client.stream("GET", self.url) as response:
            response.raise_for_status()
            yield from response.iter_bytes()

    async def aread(self) -> bytes:
        """
        Asynchronous counterpart of `read()`.

        Raises:
            httpx.HTTPStatusError: If the download gets an error status.
            ValueError: If `url` is a malformed `data:` URL.
        """
        if self.url.startswith("data:"):
            return self._decode_data_url()

        async with self._client._async_client.stream("GET", self.url) as response:
            response.raise_for_status()
            return await response.aread()

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """
        Asynchronously yield the file content in chunks (a single chunk for
        `data:` URLs).
        """
        if self.url.startswith("data:"):
            yield self._decode_data_url()
            return

        async with self._client._async_client.stream("GET", self.url) as response:
            response.raise_for_status()
            async for chunk in response.aiter_bytes():
                yield chunk

    def __str__(self) -> str:
        """Return the file URL, so the object can stand in for a URL string."""
        return self.url

    def __repr__(self) -> str:
        """Return a debugging representation that includes the URL."""
        return f"{type(self).__name__}({self.url!r})"


def transform_output(value: Any, client: "Client") -> Any:
    """
    Transform the output of a prediction to a `FileOutput` object if it's a URL.

    Mappings and sequences are walked recursively and rebuilt as plain `dict`
    and `list` objects. Strings that start with "https:" or "data:" are
    replaced by `FileOutput(string, client)`. Every other value, including
    `bytes` and `bytearray`, is returned unchanged.

    Args:
        value: The prediction output to transform.
        client: The client handed to each `FileOutput` that is created.

    Returns:
        The transformed value.
    """

    def transform(obj: Any) -> Any:
        if isinstance(obj, Mapping):
            return {key: transform(item) for key, item in obj.items()}

        # str, bytes and bytearray are Sequences too, but they are leaf
        # values: recursing into them would explode binary data into a list
        # of ints (or recurse forever on single-character strings).
        if isinstance(obj, (str, bytes, bytearray)):
            if isinstance(obj, str) and obj.startswith(("https:", "data:")):
                return FileOutput(obj, client)
            return obj

        if isinstance(obj, Sequence):
            return [transform(item) for item in obj]

        return obj

    return transform(value)
2 changes: 1 addition & 1 deletion replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated

from replicate.exceptions import ReplicateException
from replicate.helpers import async_encode_json, encode_json
from replicate.identifier import ModelVersionIdentifier
from replicate.json import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down
2 changes: 1 addition & 1 deletion replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from replicate.exceptions import ModelError, ReplicateError
from replicate.file import FileEncodingStrategy
from replicate.json import async_encode_json, encode_json
from replicate.helpers import async_encode_json, encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.stream import EventSource
Expand Down
9 changes: 9 additions & 0 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from replicate import identifier
from replicate.exceptions import ModelError
from replicate.helpers import transform_output
from replicate.model import Model
from replicate.prediction import Prediction
from replicate.schema import make_schema_backwards_compatible
Expand All @@ -28,6 +29,7 @@ def run(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Expand Down Expand Up @@ -60,13 +62,17 @@ def run(
if prediction.status == "failed":
raise ModelError(prediction)

if use_file_output:
return transform_output(prediction.output, client)

return prediction.output


async def async_run(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Expand Down Expand Up @@ -99,6 +105,9 @@ async def async_run(
if prediction.status == "failed":
raise ModelError(prediction)

if use_file_output:
return transform_output(prediction.output, client)

return prediction.output


Expand Down
3 changes: 1 addition & 2 deletions replicate/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Union,
)

import httpx
from typing_extensions import Unpack

from replicate import identifier
Expand All @@ -22,8 +23,6 @@


if TYPE_CHECKING:
import httpx

from replicate.client import Client
from replicate.identifier import ModelVersionIdentifier
from replicate.model import Model
Expand Down
2 changes: 1 addition & 1 deletion replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

from typing_extensions import NotRequired, Unpack

from replicate.helpers import async_encode_json, encode_json
from replicate.identifier import ModelVersionIdentifier
from replicate.json import async_encode_json, encode_json
from replicate.model import Model
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

-e file:.
annotated-types==0.6.0
Expand Down
2 changes: 2 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# features: []
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: false

-e file:.
annotated-types==0.6.0
Expand Down
2 changes: 1 addition & 1 deletion tests/test_json.py → tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from replicate.json import base64_encode_file
from replicate.helpers import base64_encode_file


@pytest.mark.parametrize(
Expand Down
Loading