From 886c8e8430fa9741267266acebd4a13bd418c925 Mon Sep 17 00:00:00 2001 From: Jim Morrison Date: Fri, 23 Feb 2024 23:48:53 +0000 Subject: [PATCH] feat: Pass through small IN queries to the server. --- google/cloud/ndb/_datastore_query.py | 1 + google/cloud/ndb/query.py | 14 ++++++---- tests/system/test_query.py | 38 ++++++++++++++++++++++++++-- tests/unit/test__gql.py | 14 ++++++++++ tests/unit/test_model.py | 11 ++++++++ tests/unit/test_query.py | 10 ++++++++ 6 files changed, 81 insertions(+), 7 deletions(-) diff --git a/google/cloud/ndb/_datastore_query.py b/google/cloud/ndb/_datastore_query.py index 90c32ba1..b5cd2aa9 100644 --- a/google/cloud/ndb/_datastore_query.py +++ b/google/cloud/ndb/_datastore_query.py @@ -56,6 +56,7 @@ "<=": query_pb2.PropertyFilter.Operator.LESS_THAN_OR_EQUAL, ">": query_pb2.PropertyFilter.Operator.GREATER_THAN, ">=": query_pb2.PropertyFilter.Operator.GREATER_THAN_OR_EQUAL, + "in": query_pb2.PropertyFilter.Operator.IN, } _KEY_NOT_IN_CACHE = object() diff --git a/google/cloud/ndb/query.py b/google/cloud/ndb/query.py index 65b8f140..2770d742 100644 --- a/google/cloud/ndb/query.py +++ b/google/cloud/ndb/query.py @@ -176,6 +176,9 @@ def ranked(cls, rank): _GT_OP = ">" _OPS = frozenset([_EQ_OP, _NE_OP, _LT_OP, "<=", _GT_OP, ">=", _IN_OP]) +# Limit from https://cloud.google.com/datastore/docs/concepts/queries#in +_SERVER_IN_LIMIT = 30 + _log = logging.getLogger(__name__) @@ -655,7 +658,8 @@ def __new__(cls, name, opsymbol, value): return FalseNode() if len(nodes) == 1: return nodes[0] - return DisjunctionNode(*nodes) + if len(nodes) > _SERVER_IN_LIMIT: + return DisjunctionNode(*nodes) instance = super(FilterNode, cls).__new__(cls) instance._name = name @@ -704,17 +708,17 @@ def _to_filter(self, post=False): representation of the filter. Raises: - NotImplementedError: If the ``opsymbol`` is ``!=`` or ``in``, since - they should correspond to a composite filter. This should + NotImplementedError: If the ``opsymbol`` is ``!=``, since + it should correspond to a composite filter. This should never occur since the constructor will create ``OR`` nodes for - ``!=`` and ``in`` + ``!=``. """ # Avoid circular import in Python 2.7 from google.cloud.ndb import _datastore_query if post: return None - if self._opsymbol in (_NE_OP, _IN_OP): + if self._opsymbol in (_NE_OP,): raise NotImplementedError( "Inequality filters are not single filter " "expressions and therefore cannot be converted " diff --git a/tests/system/test_query.py b/tests/system/test_query.py index df00a6b6..5c2357d1 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -865,6 +865,40 @@ def make_entities(): assert not more +@pytest.mark.usefixtures("client_context") +def test_fetch_page_in_query(dispose_of): + page_size = 5 + n_entities = page_size * 2 + + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + + @ndb.toplevel + def make_entities(): + entities = [SomeKind(foo=n_entities) for i in range(n_entities)] + keys = yield [entity.put_async() for entity in entities] + raise ndb.Return(keys) + + for key in make_entities(): + dispose_of(key._key) + + query = SomeKind.query().filter(SomeKind.foo.IN([1, 2, n_entities])) + eventually(query.fetch, length_equals(n_entities)) + + results, cursor, more = query.fetch_page(page_size) + assert len(results) == page_size + assert more + + safe_cursor = cursor.urlsafe() + next_cursor = ndb.Cursor(urlsafe=safe_cursor) + results, cursor, more = query.fetch_page(page_size, start_cursor=next_cursor) + assert len(results) == page_size + + results, cursor, more = query.fetch_page(page_size, start_cursor=cursor) + assert not results + assert not more + + @pytest.mark.usefixtures("client_context") def test_polymodel_query(ds_entity): class Animal(ndb.PolyModel): @@ -1819,13 +1853,13 @@ class SomeKind(ndb.Model): eventually(SomeKind.query().fetch, length_equals(5)) - query = SomeKind.gql("where foo in (2, 3)").order(SomeKind.foo) + query = SomeKind.gql("where foo in (2, 3)") results = query.fetch() assert len(results) == 2 assert results[0].foo == 2 assert results[1].foo == 3 - query = SomeKind.gql("where foo in :1", [2, 3]).order(SomeKind.foo) + query = SomeKind.gql("where foo in :1", [2, 3]) results = query.fetch() assert len(results) == 2 assert results[0].foo == 2 diff --git a/tests/unit/test__gql.py b/tests/unit/test__gql.py index ee9371c8..834b7dbe 100644 --- a/tests/unit/test__gql.py +++ b/tests/unit/test__gql.py @@ -317,6 +317,20 @@ class SomeKind(model.Model): @staticmethod @pytest.mark.usefixtures("in_context") def test_get_query_in(): + query_module._SERVER_IN_LIMIT = 5 + + class SomeKind(model.Model): + prop1 = model.IntegerProperty() + + gql = gql_module.GQL("SELECT prop1 FROM SomeKind WHERE prop1 IN (1, 2, 3)") + query = gql.get_query() + assert query.filters == query_module.FilterNode("prop1", "in", [1, 2, 3]) + + @staticmethod + @pytest.mark.usefixtures("in_context") + def test_get_query_in_large(): + query_module._SERVER_IN_LIMIT = 2 + class SomeKind(model.Model): prop1 = model.IntegerProperty() diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 6cb0ac90..a9aa9aa6 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -553,6 +553,17 @@ def test__IN_wrong_container(): @staticmethod def test__IN(): + query_module._SERVER_IN_LIMIT = 8 + prop = model.Property("name", indexed=True) + in_node = prop._IN(["a", None, "xy"]) + expected = query_module.FilterNode("name", "in", ["a", None, "xy"]) + assert in_node == expected + # Also verify the alias + assert in_node == prop.IN(["a", None, "xy"]) + + @staticmethod + def test__IN_large(): + query_module._SERVER_IN_LIMIT = 2 prop = model.Property("name", indexed=True) or_node = prop._IN(["a", None, "xy"]) expected = query_module.DisjunctionNode( diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index df7df55a..fd4e11d3 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -627,6 +627,16 @@ def test_constructor_with_key(): @staticmethod def test_constructor_in(): + query_module._SERVER_IN_LIMIT = 8 + in_node = query_module.FilterNode("a", "in", ("x", "y", "z")) + assert not isinstance(in_node, query_module.DisjunctionNode) + assert in_node._name == "a" + assert in_node._opsymbol == "in" + assert in_node._value == ("x", "y", "z") + + @staticmethod + def test_constructor_in_large(): + query_module._SERVER_IN_LIMIT = 2 or_node = query_module.FilterNode("a", "in", ("x", "y", "z")) filter_node1 = query_module.FilterNode("a", "=", "x")