diff --git a/replicate/deployment.py b/replicate/deployment.py
index 8dfb6e7..21fc990 100644
--- a/replicate/deployment.py
+++ b/replicate/deployment.py
@@ -8,7 +8,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
-    _create_prediction_headers,
+    _create_prediction_request_params,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -421,21 +421,25 @@ def create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
 
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = self._client._request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -449,6 +453,7 @@ def async_create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
         if input is not None:
             input = await async_encode_json(
@@ -456,14 +461,16 @@
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = await self._client._async_request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -484,24 +491,20 @@ def create(
         Create a new prediction with the deployment.
         """
 
-        url = _create_prediction_url_from_deployment(deployment)
-
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
+        url = _create_prediction_url_from_deployment(deployment)
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
 
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
-        resp = self._client._request(
-            "POST",
-            url,
-            json=body,
-            headers=headers,
-        )
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(wait=wait)
+        resp = self._client._request("POST", url, json=body, **extras)
 
         return _json_to_prediction(self._client, resp.json())
@@ -515,9 +518,10 @@ async def async_create(
         Create a new prediction with the deployment.
""" - url = _create_prediction_url_from_deployment(deployment) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + url = _create_prediction_url_from_deployment(deployment) if input is not None: input = await async_encode_json( input, @@ -525,15 +529,9 @@ async def async_create( file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) - - resp = await self._client._async_request( - "POST", - url, - json=body, - headers=headers, - ) + extras = _create_prediction_request_params(wait=wait) + resp = await self._client._async_request("POST", url, json=body, **extras) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/model.py b/replicate/model.py index 1cf144a..a52459e 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -9,7 +9,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, - _create_prediction_headers, + _create_prediction_request_params, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -389,24 +389,20 @@ def create( Create a new prediction with the deployment. """ - url = _create_prediction_url_from_model(model) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + path = _create_prediction_path_from_model(model) if input is not None: input = encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) - body = _create_prediction_body(version=None, input=input, **params) - resp = self._client._request( - "POST", - url, - json=body, - headers=headers, - ) + body = _create_prediction_body(version=None, input=input, **params) + extras = _create_prediction_request_params(wait=wait) + resp = self._client._request("POST", path, json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -420,24 +416,21 @@ async def async_create( Create a new prediction with the deployment. 
""" - url = _create_prediction_url_from_model(model) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + path = _create_prediction_path_from_model(model) + if input is not None: input = await async_encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) - body = _create_prediction_body(version=None, input=input, **params) - resp = await self._client._async_request( - "POST", - url, - json=body, - headers=headers, - ) + body = _create_prediction_body(version=None, input=input, **params) + extras = _create_prediction_request_params(wait=wait) + resp = await self._client._async_request("POST", path, json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -522,7 +515,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: return model -def _create_prediction_url_from_model( +def _create_prediction_path_from_model( model: Union[str, Tuple[str, str], "Model"], ) -> str: owner, name = None, None diff --git a/replicate/prediction.py b/replicate/prediction.py index aa3e45c..b4ff047 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -16,6 +16,7 @@ overload, ) +import httpx from typing_extensions import NotRequired, TypedDict, Unpack from replicate.exceptions import ModelError, ReplicateError @@ -446,6 +447,9 @@ def create( # type: ignore Create a new prediction for the specified model, version, or deployment. """ + wait = params.pop("wait", None) + file_encoding_strategy = params.pop("file_encoding_strategy", None) + if args: version = args[0] if len(args) > 0 else None input = args[1] if len(args) > 1 else input @@ -477,26 +481,20 @@ def create( # type: ignore **params, ) - file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: input = encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) + body = _create_prediction_body( version, input, **params, ) - - resp = self._client._request( - "POST", - "/v1/predictions", - headers=headers, - json=body, - ) + extras = _create_prediction_request_params(wait=wait) + resp = self._client._request("POST", "/v1/predictions", json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -538,6 +536,8 @@ async def async_create( # type: ignore """ Create a new prediction for the specified model, version, or deployment. 
""" + wait = params.pop("wait", None) + file_encoding_strategy = params.pop("file_encoding_strategy", None) if args: version = args[0] if len(args) > 0 else None @@ -570,25 +570,21 @@ async def async_create( # type: ignore **params, ) - file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: input = await async_encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) + body = _create_prediction_body( version, input, **params, ) - + extras = _create_prediction_request_params(wait=wait) resp = await self._client._async_request( - "POST", - "/v1/predictions", - headers=headers, - json=body, + "POST", "/v1/predictions", json=body, **extras ) return _json_to_prediction(self._client, resp.json()) @@ -628,6 +624,40 @@ async def async_cancel(self, id: str) -> Prediction: return _json_to_prediction(self._client, resp.json()) +class CreatePredictionRequestParams(TypedDict): + headers: NotRequired[Optional[dict]] + timeout: NotRequired[Optional[httpx.Timeout]] + + +def _create_prediction_request_params( + wait: Optional[Union[int, bool]], +) -> CreatePredictionRequestParams: + timeout = _create_prediction_timeout(wait=wait) + headers = _create_prediction_headers(wait=wait) + + return { + "headers": headers, + "timeout": timeout, + } + + +def _create_prediction_timeout( + *, wait: Optional[Union[int, bool]] = None +) -> Union[httpx.Timeout, None]: + """ + Returns an `httpx.Timeout` instances appropriate for the optional + `Prefer: wait=x` header that can be provided with the request. This + will ensure that we give the server enough time to respond with + a partial prediction in the event that the request times out. + """ + + if not wait: + return None + + read_timeout = 60.0 if isinstance(wait, bool) else wait + return httpx.Timeout(5.0, read=read_timeout + 0.5) + + def _create_prediction_headers( *, wait: Optional[Union[int, bool]] = None,