Add opt-in FileOutput wrapping for prediction outputs.

Summary of the change set:
* Rename replicate/json.py to replicate/helpers.py (and tests/test_json.py to
  tests/test_helpers.py) and update every importer.
* Add helpers.FileOutput: read() / aread() return the whole body, iteration
  yields chunks; "data:" URLs are decoded locally as base64, anything else is
  fetched with a GET through the client's httpx clients.
* Add helpers.transform_output: recursively walks mappings and non-string
  sequences and wraps strings starting with "https:" or "data:" in FileOutput.
* Add a use_file_output parameter (default None) to Client.run,
  Client.async_run, run and async_run; when truthy, the prediction output is
  passed through transform_output before being returned.
* Import httpx at module level in replicate/stream.py, ignore ruff ANN401,
  add [tool.pyright] venv settings, refresh the lock-file headers.

Review amendments relative to the originally submitted patch:
* tests/test_run.py: dropped the hunk that rewrote the test_run_concurrently
  assertion from len(result) to len(results); it left the loop variable
  unused and made the per-result check vacuous.
* tests/test_run.py: the new tests use typing.List / Optional / Union instead
  of "str | list[str] | None" and "list[FileOutput]", which are evaluated at
  runtime and raise TypeError on interpreters older than Python 3.10 / 3.9.
  The "index" line of that file still records the pre-amendment blob id.

diff --git a/pyproject.toml b/pyproject.toml
index e95ca1ea..765c0210 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,7 @@ ignore = [
     "ANN003", # Missing type annotation for `**kwargs`
     "ANN101", # Missing type annotation for self in method
     "ANN102", # Missing type annotation for cls in classmethod
+    "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name}
     "W191", # Indentation contains tabs
     "UP037", # Remove quotes from type annotation
 ]
@@ -86,3 +87,7 @@ ignore = [
     "ANN201", # Missing return type annotation for public function
     "ANN202", # Missing return type annotation for private function
 ]
+
+[tool.pyright]
+venvPath = "."
+venv = ".venv"
diff --git a/replicate/client.py b/replicate/client.py
index 5cee7a6f..3da3cc15 100644
--- a/replicate/client.py
+++ b/replicate/client.py
@@ -164,25 +164,27 @@ def run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
+        use_file_output: Optional[bool] = None,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output.
         """
 
-        return run(self, ref, input, **params)
+        return run(self, ref, input, use_file_output, **params)
 
     async def async_run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
+        use_file_output: Optional[bool] = None,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output asynchronously.
         """
 
-        return await async_run(self, ref, input, **params)
+        return await async_run(self, ref, input, use_file_output, **params)
 
     def stream(
         self,
diff --git a/replicate/deployment.py b/replicate/deployment.py
index 8d9836b0..e17edcbc 100644
--- a/replicate/deployment.py
+++ b/replicate/deployment.py
@@ -3,7 +3,7 @@
 from typing_extensions import Unpack, deprecated
 
 from replicate.account import Account
-from replicate.json import async_encode_json, encode_json
+from replicate.helpers import async_encode_json, encode_json
 from replicate.pagination import Page
 from replicate.prediction import (
     Prediction,
diff --git a/replicate/json.py b/replicate/helpers.py
similarity index 57%
rename from replicate/json.py
rename to replicate/helpers.py
index 90154a84..e0bada5d 100644
--- a/replicate/json.py
+++ b/replicate/helpers.py
@@ -1,9 +1,12 @@
 import base64
 import io
 import mimetypes
+from collections.abc import Mapping, Sequence
 from pathlib import Path
 from types import GeneratorType
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional
+
+import httpx
 
 if TYPE_CHECKING:
     from replicate.client import Client
@@ -108,3 +111,80 @@ def base64_encode_file(file: io.IOBase) -> str:
         mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
     )
     return f"data:{mime_type};base64,{encoded_body}"
+
+
+class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream):
+    """
+    An object that can be used to read the contents of an output file
+    created by running a Replicate model.
+    """
+
+    url: str
+    """
+    The file URL.
+    """
+
+    _client: "Client"
+
+    def __init__(self, url: str, client: "Client") -> None:
+        self.url = url
+        self._client = client
+
+    def read(self) -> bytes:
+        if self.url.startswith("data:"):
+            _, encoded = self.url.split(",", 1)
+            return base64.b64decode(encoded)
+
+        with self._client._client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            return response.read()
+
+    def __iter__(self) -> Iterator[bytes]:
+        if self.url.startswith("data:"):
+            yield self.read()
+            return
+
+        with self._client._client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            yield from response.iter_bytes()
+
+    async def aread(self) -> bytes:
+        if self.url.startswith("data:"):
+            _, encoded = self.url.split(",", 1)
+            return base64.b64decode(encoded)
+
+        async with self._client._async_client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            return await response.aread()
+
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        if self.url.startswith("data:"):
+            yield await self.aread()
+            return
+
+        async with self._client._async_client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            async for chunk in response.aiter_bytes():
+                yield chunk
+
+    def __str__(self) -> str:
+        return self.url
+
+
+def transform_output(value: Any, client: "Client") -> Any:
+    """
+    Transform the output of a prediction to a `FileOutput` object if it's a URL.
+    """
+
+    def transform(obj: Any) -> Any:
+        if isinstance(obj, Mapping):
+            return {k: transform(v) for k, v in obj.items()}
+        elif isinstance(obj, Sequence) and not isinstance(obj, str):
+            return [transform(item) for item in obj]
+        elif isinstance(obj, str) and (
+            obj.startswith("https:") or obj.startswith("data:")
+        ):
+            return FileOutput(obj, client)
+        return obj
+
+    return transform(value)
diff --git a/replicate/model.py b/replicate/model.py
index ccae9cd0..ba5e1113 100644
--- a/replicate/model.py
+++ b/replicate/model.py
@@ -3,8 +3,8 @@
 from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
 
 from replicate.exceptions import ReplicateException
+from replicate.helpers import async_encode_json, encode_json
 from replicate.identifier import ModelVersionIdentifier
-from replicate.json import async_encode_json, encode_json
 from replicate.pagination import Page
 from replicate.prediction import (
     Prediction,
diff --git a/replicate/prediction.py b/replicate/prediction.py
index 7028a712..9770029b 100644
--- a/replicate/prediction.py
+++ b/replicate/prediction.py
@@ -20,7 +20,7 @@
 
 from replicate.exceptions import ModelError, ReplicateError
 from replicate.file import FileEncodingStrategy
-from replicate.json import async_encode_json, encode_json
+from replicate.helpers import async_encode_json, encode_json
 from replicate.pagination import Page
 from replicate.resource import Namespace, Resource
 from replicate.stream import EventSource
diff --git a/replicate/run.py b/replicate/run.py
index ae1ca7e5..fd1accfb 100644
--- a/replicate/run.py
+++ b/replicate/run.py
@@ -13,6 +13,7 @@
 
 from replicate import identifier
 from replicate.exceptions import ModelError
+from replicate.helpers import transform_output
 from replicate.model import Model
 from replicate.prediction import Prediction
 from replicate.schema import make_schema_backwards_compatible
@@ -28,6 +29,7 @@ def run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
+    use_file_output: Optional[bool] = None,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
     """
@@ -60,6 +62,9 @@ def run(
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    if use_file_output:
+        return transform_output(prediction.output, client)
+
     return prediction.output
 
 
@@ -67,6 +72,7 @@ async def async_run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
+    use_file_output: Optional[bool] = None,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
     """
@@ -99,6 +105,9 @@ async def async_run(
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    if use_file_output:
+        return transform_output(prediction.output, client)
+
     return prediction.output
 
 
diff --git a/replicate/stream.py b/replicate/stream.py
index 844973d4..3472799e 100644
--- a/replicate/stream.py
+++ b/replicate/stream.py
@@ -10,6 +10,7 @@
     Union,
 )
 
+import httpx
 from typing_extensions import Unpack
 
 from replicate import identifier
@@ -22,8 +23,6 @@
 
 
 if TYPE_CHECKING:
-    import httpx
-
     from replicate.client import Client
     from replicate.identifier import ModelVersionIdentifier
     from replicate.model import Model
diff --git a/replicate/training.py b/replicate/training.py
index ba3554df..28e28b4a 100644
--- a/replicate/training.py
+++ b/replicate/training.py
@@ -13,8 +13,8 @@
 
 from typing_extensions import NotRequired, Unpack
 
+from replicate.helpers import async_encode_json, encode_json
 from replicate.identifier import ModelVersionIdentifier
-from replicate.json import async_encode_json, encode_json
 from replicate.model import Model
 from replicate.pagination import Page
 from replicate.resource import Namespace, Resource
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 3eae4db9..90d7aeda 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -6,6 +6,8 @@
 #   features: []
 #   all-features: false
 #   with-sources: false
+#   generate-hashes: false
+#   universal: false
 
 -e file:.
 annotated-types==0.6.0
diff --git a/requirements.lock b/requirements.lock
index 53ab3f58..b1e20e40 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -6,6 +6,8 @@
 #   features: []
 #   all-features: false
 #   with-sources: false
+#   generate-hashes: false
+#   universal: false
 
 -e file:.
 annotated-types==0.6.0
diff --git a/tests/test_json.py b/tests/test_helpers.py
similarity index 95%
rename from tests/test_json.py
rename to tests/test_helpers.py
index b8d76569..0c41cab7 100644
--- a/tests/test_json.py
+++ b/tests/test_helpers.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from replicate.json import base64_encode_file
+from replicate.helpers import base64_encode_file
 
 
 @pytest.mark.parametrize(
diff --git a/tests/test_run.py b/tests/test_run.py
index d117eb32..11fde976 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,5 +1,6 @@
 import asyncio
 import sys
+from typing import List, Optional, Union, cast
 
 import httpx
 import pytest
@@ -8,6 +9,7 @@
 import replicate
 from replicate.client import Client
 from replicate.exceptions import ModelError, ReplicateError
+from replicate.helpers import FileOutput
 
 
 @pytest.mark.vcr("run.yaml")
@@ -253,3 +255,255 @@ def prediction_with_status(status: str) -> dict:
     assert str(excinfo.value) == "OOM"
     assert excinfo.value.prediction.error == "OOM"
     assert excinfo.value.prediction.status == "failed"
+
+
+@pytest.mark.asyncio
+async def test_run_with_file_output(mock_replicate_api_token):
+    def prediction_with_status(
+        status: str, output: Optional[Union[str, List[str]]] = None
+    ) -> dict:
+        return {
+            "id": "p1",
+            "model": "test/example",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2023-10-05T12:00:00.000000Z",
+            "source": "api",
+            "status": status,
+            "input": {"text": "world"},
+            "output": output,
+            "error": "OOM" if status == "failed" else None,
+            "logs": "",
+        }
+
+    router = respx.Router(base_url="https://api.replicate.com/v1")
+    router.route(method="POST", path="/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=prediction_with_status("processing"),
+        )
+    )
+    router.route(method="GET", path="/predictions/p1").mock(
+        return_value=httpx.Response(
+            200,
+            json=prediction_with_status(
+                "succeeded", "https://api.replicate.com/v1/assets/output.txt"
+            ),
+        )
+    )
+    router.route(
+        method="GET",
+        path="/models/test/example/versions/v1",
+    ).mock(
+        return_value=httpx.Response(
+            201,
+            json={
+                "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
+                "created_at": "2024-07-18T00:35:56.210272Z",
+                "cog_version": "0.9.10",
+                "openapi_schema": {
+                    "openapi": "3.0.2",
+                },
+            },
+        )
+    )
+    router.route(method="GET", path="/assets/output.txt").mock(
+        return_value=httpx.Response(200, content=b"Hello, world!")
+    )
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+    client.poll_interval = 0.001
+
+    output = cast(
+        FileOutput,
+        client.run(
+            "test/example:v1",
+            input={
+                "text": "Hello, world!",
+            },
+            use_file_output=True,
+        ),
+    )
+
+    assert output.url == "https://api.replicate.com/v1/assets/output.txt"
+
+    assert output.read() == b"Hello, world!"
+    for chunk in output:
+        assert chunk == b"Hello, world!"
+
+    assert await output.aread() == b"Hello, world!"
+    async for chunk in output:
+        assert chunk == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_run_with_file_output_array(mock_replicate_api_token):
+    def prediction_with_status(
+        status: str, output: Optional[Union[str, List[str]]] = None
+    ) -> dict:
+        return {
+            "id": "p1",
+            "model": "test/example",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2023-10-05T12:00:00.000000Z",
+            "source": "api",
+            "status": status,
+            "input": {"text": "world"},
+            "output": output,
+            "error": "OOM" if status == "failed" else None,
+            "logs": "",
+        }
+
+    router = respx.Router(base_url="https://api.replicate.com/v1")
+    router.route(method="POST", path="/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=prediction_with_status("processing"),
+        )
+    )
+    router.route(method="GET", path="/predictions/p1").mock(
+        return_value=httpx.Response(
+            200,
+            json=prediction_with_status(
+                "succeeded",
+                [
+                    "https://api.replicate.com/v1/assets/hello.txt",
+                    "https://api.replicate.com/v1/assets/world.txt",
+                ],
+            ),
+        )
+    )
+    router.route(
+        method="GET",
+        path="/models/test/example/versions/v1",
+    ).mock(
+        return_value=httpx.Response(
+            201,
+            json={
+                "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
+                "created_at": "2024-07-18T00:35:56.210272Z",
+                "cog_version": "0.9.10",
+                "openapi_schema": {
+                    "openapi": "3.0.2",
+                },
+            },
+        )
+    )
+    router.route(method="GET", path="/assets/hello.txt").mock(
+        return_value=httpx.Response(200, content=b"Hello,")
+    )
+    router.route(method="GET", path="/assets/world.txt").mock(
+        return_value=httpx.Response(200, content=b" world!")
+    )
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+    client.poll_interval = 0.001
+
+    [output1, output2] = cast(
+        List[FileOutput],
+        client.run(
+            "test/example:v1",
+            input={
+                "text": "Hello, world!",
+            },
+            use_file_output=True,
+        ),
+    )
+
+    assert output1.url == "https://api.replicate.com/v1/assets/hello.txt"
+    assert output2.url == "https://api.replicate.com/v1/assets/world.txt"
+
+    assert output1.read() == b"Hello,"
+    assert output2.read() == b" world!"
+
+
+@pytest.mark.asyncio
+async def test_run_with_file_output_data_uri(mock_replicate_api_token):
+    def prediction_with_status(
+        status: str, output: Optional[Union[str, List[str]]] = None
+    ) -> dict:
+        return {
+            "id": "p1",
+            "model": "test/example",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2023-10-05T12:00:00.000000Z",
+            "source": "api",
+            "status": status,
+            "input": {"text": "world"},
+            "output": output,
+            "error": "OOM" if status == "failed" else None,
+            "logs": "",
+        }
+
+    router = respx.Router(base_url="https://api.replicate.com/v1")
+    router.route(method="POST", path="/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=prediction_with_status("processing"),
+        )
+    )
+    router.route(method="GET", path="/predictions/p1").mock(
+        return_value=httpx.Response(
+            200,
+            json=prediction_with_status(
+                "succeeded",
+                "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==",
+            ),
+        )
+    )
+    router.route(
+        method="GET",
+        path="/models/test/example/versions/v1",
+    ).mock(
+        return_value=httpx.Response(
+            201,
+            json={
+                "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
+                "created_at": "2024-07-18T00:35:56.210272Z",
+                "cog_version": "0.9.10",
+                "openapi_schema": {
+                    "openapi": "3.0.2",
+                },
+            },
+        )
+    )
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+    client.poll_interval = 0.001
+
+    output = cast(
+        FileOutput,
+        client.run(
+            "test/example:v1",
+            input={
+                "text": "Hello, world!",
+            },
+            use_file_output=True,
+        ),
+    )
+
+    assert output.url == "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ=="
+    assert output.read() == b"Hello, world!"
+    for chunk in output:
+        assert chunk == b"Hello, world!"
+
+    assert await output.aread() == b"Hello, world!"
+    async for chunk in output:
+        assert chunk == b"Hello, world!"