diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index 13df7715..c13cf200 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -282,6 +282,7 @@ def select( self, *columns: str, count: Optional[CountMethod] = None, + head: Optional[bool] = None, ) -> AsyncSelectRequestBuilder[_ReturnT]: """Run a SELECT query. @@ -291,7 +292,7 @@ def select( Returns: :class:`AsyncSelectRequestBuilder` """ - method, params, headers, json = pre_select(*columns, count=count) + method, params, headers, json = pre_select(*columns, count=count, head=head) return AsyncSelectRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json ) diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index f713e7d2..742db9a2 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -282,6 +282,7 @@ def select( self, *columns: str, count: Optional[CountMethod] = None, + head: Optional[bool] = None, ) -> SyncSelectRequestBuilder[_ReturnT]: """Run a SELECT query. @@ -291,7 +292,7 @@ def select( Returns: :class:`SyncSelectRequestBuilder` """ - method, params, headers, json = pre_select(*columns, count=count) + method, params, headers, json = pre_select(*columns, count=count, head=head) return SyncSelectRequestBuilder[_ReturnT]( self.session, self.path, method, headers, params, json ) diff --git a/postgrest/base_request_builder.py b/postgrest/base_request_builder.py index f8c5a324..eb7721e2 100644 --- a/postgrest/base_request_builder.py +++ b/postgrest/base_request_builder.py @@ -52,16 +52,29 @@ def _unique_columns(json: List[Dict]): return columns +def _cleaned_columns(columns: str) -> str: + quoted = False + result = [] + + for c in columns: + if c.isspace() and not quoted: + continue + if c == '"': + quoted = not quoted + result.append(c) + + return ",".join(result) + + def pre_select( *columns: str, count: Optional[CountMethod] = None, + head: Optional[bool] = None, ) -> QueryArgs: - if columns: - method = RequestMethod.GET - params = QueryParams({"select": ",".join(columns)}) - else: - method = RequestMethod.HEAD - params = QueryParams() + method = RequestMethod.HEAD if head else RequestMethod.GET + cleaned_columns = _cleaned_columns(columns or "*") + params = QueryParams({"select": cleaned_columns}) + headers = Headers({"Prefer": f"count={count}"}) if count else Headers() return QueryArgs(method, params, headers, {}) diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index bd55e140..feb98032 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -31,8 +31,16 @@ def test_select(self, request_builder: AsyncRequestBuilder): def test_select_with_count(self, request_builder: AsyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) - assert builder.params.get("select") is None + assert builder.params["select"] == "*" assert builder.headers["prefer"] == "count=exact" + assert builder.http_method == "GET" + assert builder.json == {} + + def test_select_with_head(self, request_builder: AsyncRequestBuilder): + builder = request_builder.select("col1", "col2", head=True) + + assert builder.params.get("select") == "col1,col2" + assert builder.headers.get("prefer") is None assert builder.http_method == "HEAD" assert builder.json == {} @@ -193,7 +201,7 @@ def test_explain_options(self, request_builder: AsyncRequestBuilder): class TestOrder: def test_order(self, request_builder: AsyncRequestBuilder): builder = request_builder.select().order("country_name", desc=True) - assert str(builder.params) == "order=country_name.desc" + assert str(builder.params) == "select=%2A&order=country_name.desc" def test_multiple_orders(self, request_builder: AsyncRequestBuilder): builder = ( @@ -201,7 +209,7 @@ def test_multiple_orders(self, request_builder: AsyncRequestBuilder): .order("country_name", desc=True) .order("iso", desc=True) ) - assert str(builder.params) == "order=country_name.desc%2Ciso.desc" + assert str(builder.params) == "select=%2A&order=country_name.desc%2Ciso.desc" def test_multiple_orders_on_foreign_table(self, request_builder: AsyncRequestBuilder): foreign_table = "cities" @@ -212,7 +220,7 @@ def test_multiple_orders_on_foreign_table(self, request_builder: AsyncRequestBui ) assert ( str(builder.params) - == "order=cities%28city_name%29.desc%2Ccities%28id%29.desc" + == "select=%2A&order=cities%28city_name%29.desc%2Ccities%28id%29.desc" ) diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index 45657ed9..8d8a1939 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -31,8 +31,16 @@ def test_select(self, request_builder: SyncRequestBuilder): def test_select_with_count(self, request_builder: SyncRequestBuilder): builder = request_builder.select(count=CountMethod.exact) - assert builder.params.get("select") is None + assert builder.params["select"] == "*" assert builder.headers["prefer"] == "count=exact" + assert builder.http_method == "GET" + assert builder.json == {} + + def test_select_with_head(self, request_builder: SyncRequestBuilder): + builder = request_builder.select("col1", "col2", head=True) + + assert builder.params.get("select") == "col1,col2" + assert builder.headers.get("prefer") is None assert builder.http_method == "HEAD" assert builder.json == {} @@ -193,7 +201,7 @@ def test_explain_options(self, request_builder: SyncRequestBuilder): class TestOrder: def test_order(self, request_builder: SyncRequestBuilder): builder = request_builder.select().order("country_name", desc=True) - assert str(builder.params) == "order=country_name.desc" + assert str(builder.params) == "select=%2A&order=country_name.desc" def test_multiple_orders(self, request_builder: SyncRequestBuilder): builder = ( @@ -201,7 +209,7 @@ def test_multiple_orders(self, request_builder: SyncRequestBuilder): .order("country_name", desc=True) .order("iso", desc=True) ) - assert str(builder.params) == "order=country_name.desc%2Ciso.desc" + assert str(builder.params) == "select=%2A&order=country_name.desc%2Ciso.desc" def test_multiple_orders_on_foreign_table(self, request_builder: SyncRequestBuilder): foreign_table = "cities" @@ -212,7 +220,7 @@ def test_multiple_orders_on_foreign_table(self, request_builder: SyncRequestBuil ) assert ( str(builder.params) - == "order=cities%28city_name%29.desc%2Ccities%28id%29.desc" + == "select=%2A&order=cities%28city_name%29.desc%2Ccities%28id%29.desc" )