Fix iterator support for replicate.run() #383

Merged
6 commits merged on Oct 28, 2024
Changes from 1 commit
77 changes: 63 additions & 14 deletions replicate/run.py
@@ -59,15 +59,24 @@ def run(
     if not version and (owner and name and version_id):
         version = Versions(client, model=(owner, name)).get(version_id)
 
-    if version and (iterator := _make_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If request exceeds the wait time, the
+    # prediction will be in a "starting" state.
     if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and (iterator := _make_output_iterator(client, version, prediction)):
+            return iterator
+
         prediction.wait()
 
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    # Return an iterator for the completed prediction when needed.
+    if version and (iterator := _make_output_iterator(client, version, prediction)):
+        return iterator
+
     if use_file_output:
         return transform_output(prediction.output, client)
 
@@ -108,12 +117,25 @@ async def async_run(
     if not version and (owner and name and version_id):
         version = await Versions(client, model=(owner, name)).async_get(version_id)
 
-    if version and (iterator := _make_async_output_iterator(version, prediction)):
-        return iterator
-
+    # Currently the "Prefer: wait" interface will return a prediction with a status
+    # of "processing" rather than a terminal state because it returns before the
+    # prediction has been fully processed. If request exceeds the wait time, the
+    # prediction will be in a "starting" state.
     if not (is_blocking and prediction.status != "starting"):
+        # Return a "polling" iterator if the model has an output iterator array type.
+        if version and (
+            iterator := _make_async_output_iterator(client, version, prediction)
+        ):
+            return iterator
+
         await prediction.async_wait()
 
+    # Return an iterator for completed output if the model has an output iterator array type.
+    if version and (
+        iterator := _make_async_output_iterator(client, version, prediction)
+    ):
+        return iterator
+
     if prediction.status == "failed":
         raise ModelError(prediction)
 
@@ -134,21 +156,48 @@ def _has_output_iterator_array_type(version: Version) -> bool:
 
 
 def _make_output_iterator(
-    version: Version, prediction: Prediction
+    client: "Client", version: Version, prediction: Prediction
 ) -> Optional[Iterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.output_iterator()
-
-    return None
+    if not _has_output_iterator_array_type(version):
+        return None
+
+    if prediction.status == "starting":
+        iterator = prediction.output_iterator()
+    elif prediction.output is not None:
+        iterator = iter(prediction.output)
+    else:
+        return None
+
+    def _iterate(iter: Iterator[Any]) -> Iterator[Any]:
+        for chunk in iter:
+            yield transform_output(chunk, client)
+
+    return _iterate(iterator)


 def _make_async_output_iterator(
-    version: Version, prediction: Prediction
+    client: "Client", version: Version, prediction: Prediction
 ) -> Optional[AsyncIterator[Any]]:
-    if _has_output_iterator_array_type(version):
-        return prediction.async_output_iterator()
-
-    return None
+    if not _has_output_iterator_array_type(version):
+        return None
+
+    if prediction.status == "starting":
+        iterator = prediction.async_output_iterator()
+    elif prediction.output is not None:
+
+        async def _list_to_aiter(lst: list) -> AsyncIterator:
+            for item in lst:
+                yield item
+
+        iterator = _list_to_aiter(prediction.output)
+    else:
+        return None
+
+    async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator:
+        async for chunk in iter:
+            yield transform_output(chunk, client)
+
+    return _transform(iterator)
 
 
 __all__: List = []
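
For reference, a minimal usage sketch of the behavior this change fixes: consuming streamed output from a model whose output schema is an iterator array type. The model reference and prompt below are placeholders, and the snippet assumes the module-level replicate.run() / replicate.async_run() helpers exposed by the client; it is an illustration, not part of the diff.

import asyncio

import replicate  # assumes the replicate client is installed and REPLICATE_API_TOKEN is set


def stream_sync() -> None:
    # For iterator-output models, run() returns an iterator that yields
    # transformed chunks, whether the prediction is still "starting" or was
    # already completed by the blocking "Prefer: wait" path.
    for chunk in replicate.run(
        "owner/streaming-model:version-id",  # placeholder model reference
        input={"prompt": "Hello"},
    ):
        print(chunk, end="")


async def stream_async() -> None:
    # async_run() mirrors this with an AsyncIterator.
    output = await replicate.async_run(
        "owner/streaming-model:version-id",  # placeholder model reference
        input={"prompt": "Hello"},
    )
    async for chunk in output:
        print(chunk, end="")


if __name__ == "__main__":
    stream_sync()
    asyncio.run(stream_async())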