Skip to content

Add use_file_output to streaming methods #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,25 +190,27 @@ def stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator["ServerSentEvent"]:
"""
Stream a model's output.
"""

return stream(self, ref, input, **params)
return stream(self, ref, input, use_file_output, **params)

async def async_stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream a model's output asynchronously.
"""

return async_stream(self, ref, input, **params)
return async_stream(self, ref, input, use_file_output, **params)


# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
Expand Down
18 changes: 14 additions & 4 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ async def async_wait(self) -> None:
await asyncio.sleep(self._client.poll_interval)
await self.async_reload()

def stream(self) -> Iterator["ServerSentEvent"]:
def stream(
self,
use_file_output: Optional[bool] = None,
) -> Iterator["ServerSentEvent"]:
"""
Stream the prediction output.

Expand All @@ -170,9 +173,14 @@ def stream(self) -> Iterator["ServerSentEvent"]:
headers["Cache-Control"] = "no-store"

with self._client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)
yield from EventSource(
self._client, response, use_file_output=use_file_output
)

async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
async def async_stream(
self,
use_file_output: Optional[bool] = None,
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream the prediction output asynchronously.

Expand All @@ -194,7 +202,9 @@ async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
async with self._client._async_client.stream(
"GET", url, headers=headers
) as response:
async for event in EventSource(response):
async for event in EventSource(
self._client, response, use_file_output=use_file_output
):
yield event

def cancel(self) -> None:
Expand Down
34 changes: 30 additions & 4 deletions replicate/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from replicate import identifier
from replicate.exceptions import ReplicateError
from replicate.helpers import transform_output

try:
from pydantic import v1 as pydantic # type: ignore
Expand Down Expand Up @@ -62,10 +63,19 @@ class EventSource:
A server-sent event source.
"""

client: "Client"
response: "httpx.Response"

def __init__(self, response: "httpx.Response") -> None:
use_file_output: bool

def __init__(
self,
client: "Client",
response: "httpx.Response",
use_file_output: Optional[bool] = None,
) -> None:
self.client = client
self.response = response
self.use_file_output = use_file_output or False
content_type, _, _ = response.headers["content-type"].partition(";")
if content_type != "text/event-stream":
raise ValueError(
Expand Down Expand Up @@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
if sse.event == ServerSentEvent.EventType.ERROR:
raise RuntimeError(sse.data)

if (
self.use_file_output
and sse.event == ServerSentEvent.EventType.OUTPUT
):
sse.data = transform_output(sse.data, client=self.client)

yield sse

if sse.event == ServerSentEvent.EventType.DONE:
Expand All @@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
if sse.event == ServerSentEvent.EventType.ERROR:
raise RuntimeError(sse.data)

if (
self.use_file_output
and sse.event == ServerSentEvent.EventType.OUTPUT
):
sse.data = transform_output(sse.data, client=self.client)

yield sse

if sse.event == ServerSentEvent.EventType.DONE:
Expand All @@ -171,6 +193,7 @@ def stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
"""
Expand Down Expand Up @@ -204,13 +227,14 @@ def stream(
headers["Cache-Control"] = "no-store"

with client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)
yield from EventSource(client, response, use_file_output=use_file_output)


async def async_stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
"""
Expand Down Expand Up @@ -244,7 +268,9 @@ async def async_stream(
headers["Cache-Control"] = "no-store"

async with client._async_client.stream("GET", url, headers=headers) as response:
async for event in EventSource(response):
async for event in EventSource(
client, response, use_file_output=use_file_output
):
yield event


Expand Down