diff --git a/postgrest/_async/client.py b/postgrest/_async/client.py index 32384da9..4c9a102f 100644 --- a/postgrest/_async/client.py +++ b/postgrest/_async/client.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Union, cast +from typing import Any, Dict, Optional, Union, cast from deprecation import deprecated from httpx import Headers, QueryParams, Timeout @@ -11,6 +11,7 @@ DEFAULT_POSTGREST_CLIENT_HEADERS, DEFAULT_POSTGREST_CLIENT_TIMEOUT, ) +from ..types import CountMethod from ..utils import AsyncClient from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder @@ -78,12 +79,22 @@ def from_table(self, table: str) -> AsyncRequestBuilder: """Alias to :meth:`from_`.""" return self.from_(table) - def rpc(self, func: str, params: dict) -> AsyncRPCFilterRequestBuilder[Any]: + def rpc( + self, + func: str, + params: dict, + count: Optional[CountMethod] = None, + head: bool = False, + get: bool = False, + ) -> AsyncRPCFilterRequestBuilder[Any]: """Perform a stored procedure call. Args: func: The name of the remote procedure to run. params: The parameters to be passed to the remote procedure. + count: The method to use to get the count of rows returned. + head: When set to `true`, `data` will not be returned. Useful if you only need the count. + get: When set to `true`, the function will be called with read-only access mode. Returns: :class:`AsyncRPCFilterRequestBuilder` Example: @@ -97,7 +108,11 @@ def rpc(self, func: str, params: dict) -> AsyncRPCFilterRequestBuilder[Any]: This method now returns a :class:`AsyncFilterRequestBuilder` which allows you to filter on the RPC's resultset. """ + method = "HEAD" if head else "GET" if get else "POST" + + headers = Headers({"Prefer": f"count={count}"}) if count else Headers() + # the params here are params to be sent to the RPC and not the queryparams! return AsyncRPCFilterRequestBuilder[Any]( - self.session, f"/rpc/{func}", "POST", Headers(), QueryParams(), json=params + self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params ) diff --git a/postgrest/_sync/client.py b/postgrest/_sync/client.py index 226e9994..f3ca5e7e 100644 --- a/postgrest/_sync/client.py +++ b/postgrest/_sync/client.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Union, cast +from typing import Any, Dict, Optional, Union, cast from deprecation import deprecated from httpx import Headers, QueryParams, Timeout @@ -11,6 +11,7 @@ DEFAULT_POSTGREST_CLIENT_HEADERS, DEFAULT_POSTGREST_CLIENT_TIMEOUT, ) +from ..types import CountMethod from ..utils import SyncClient from .request_builder import SyncRequestBuilder, SyncRPCFilterRequestBuilder @@ -78,12 +79,22 @@ def from_table(self, table: str) -> SyncRequestBuilder: """Alias to :meth:`from_`.""" return self.from_(table) - def rpc(self, func: str, params: dict) -> SyncRPCFilterRequestBuilder[Any]: + def rpc( + self, + func: str, + params: dict, + count: Optional[CountMethod] = None, + head: bool = False, + get: bool = False, + ) -> SyncRPCFilterRequestBuilder[Any]: """Perform a stored procedure call. Args: func: The name of the remote procedure to run. params: The parameters to be passed to the remote procedure. + count: The method to use to get the count of rows returned. + head: When set to `true`, `data` will not be returned. Useful if you only need the count. + get: When set to `true`, the function will be called with read-only access mode. Returns: :class:`AsyncRPCFilterRequestBuilder` Example: @@ -97,7 +108,11 @@ def rpc(self, func: str, params: dict) -> SyncRPCFilterRequestBuilder[Any]: This method now returns a :class:`AsyncFilterRequestBuilder` which allows you to filter on the RPC's resultset. """ + method = "HEAD" if head else "GET" if get else "POST" + + headers = Headers({"Prefer": f"count={count}"}) if count else Headers() + # the params here are params to be sent to the RPC and not the queryparams! return SyncRPCFilterRequestBuilder[Any]( - self.session, f"/rpc/{func}", "POST", Headers(), QueryParams(), json=params + self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params )