diff --git a/replicate/deployment.py b/replicate/deployment.py index e17edcbc..1f0fdaba 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -8,6 +8,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, + _create_prediction_headers, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -425,12 +426,14 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( "POST", f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) @@ -451,12 +454,14 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( "POST", f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/model.py b/replicate/model.py index ba5e1113..31f625af 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -9,6 +9,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, + _create_prediction_headers, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -400,12 +401,14 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( "POST", url, json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) @@ -429,12 +432,14 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( "POST", url, json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/prediction.py b/replicate/prediction.py index 9770029b..d09ef504 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -383,6 +383,15 @@ class CreatePredictionParams(TypedDict): stream: NotRequired[bool] """Enable streaming of prediction output.""" + wait: NotRequired[Union[int, bool]] + """ + Wait until the prediction is completed before returning. + + If `True`, wait a predetermined number of seconds until the prediction + is completed before returning. + If an `int`, wait for the specified number of seconds. + """ + file_encoding_strategy: NotRequired[FileEncodingStrategy] """The strategy to use for encoding files in the prediction input.""" @@ -463,6 +472,7 @@ def create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body( version, input, @@ -472,6 +482,7 @@ def create( # type: ignore resp = self._client._request( "POST", "/v1/predictions", + headers=headers, json=body, ) @@ -554,6 +565,7 @@ async def async_create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body( version, input, @@ -563,6 +575,7 @@ async def async_create( # type: ignore resp = await self._client._async_request( "POST", "/v1/predictions", + headers=headers, json=body, ) @@ -603,6 +616,20 @@ async def async_cancel(self, id: str) -> Prediction: return _json_to_prediction(self._client, resp.json()) +def _create_prediction_headers( + *, + wait: Optional[Union[int, bool]] = None, +) -> Dict[str, Any]: + headers = {} + + if wait: + if isinstance(wait, bool): + headers["Prefer"] = "wait" + elif isinstance(wait, int): + headers["Prefer"] = f"wait={wait}" + return headers + + def _create_prediction_body( # pylint: disable=too-many-arguments version: Optional[Union[Version, str]], input: Optional[Dict[str, Any]],