Skip to content

Commit a9368b3

Browse files
committed
added stages unit tests
1 parent 28fd42d commit a9368b3

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

tests/unit/v1/test_pipeline_expressions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,29 @@ def test_ctor(self):
4141
with pytest.raises(TypeError):
4242
expr.Expr()
4343

44+
@pytest.mark.parametrize("method,args,result_cls", [
45+
("eq", (None,), expr.Eq),
46+
("neq", (None,), expr.Neq),
47+
("lt", (None,), expr.Lt),
48+
("lte", (None,), expr.Lte),
49+
("gt", (None,), expr.Gt),
50+
("gte", (None,), expr.Gte),
51+
("in_any", ([None],), expr.In),
52+
("not_in_any", ([None],),expr.Not),
53+
("array_contains", (None,), expr.ArrayContains),
54+
("array_contains_any", ([None],), expr.ArrayContainsAny),
55+
("is_nan", (), expr.IsNaN),
56+
("exists", (), expr.Exists),
57+
])
58+
def test_methods(self, method, args, result_cls):
59+
"""
60+
base expr should have methods for certain stages
61+
"""
62+
method_ptr = getattr(expr.Expr, method)
63+
result = method_ptr(mock.Mock(), *args)
64+
assert isinstance(result, result_cls)
65+
66+
4467

4568
class TestConstant:
4669
@pytest.mark.parametrize(

tests/unit/v1/test_pipeline_stages.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
import pytest
1616

1717
import google.cloud.firestore_v1.pipeline_stages as stages
18-
from google.cloud.firestore_v1.pipeline_expressions import Constant
18+
from google.cloud.firestore_v1.pipeline_expressions import (
19+
Constant,
20+
Field,
21+
Ordering,
22+
)
1923
from google.cloud.firestore_v1.types.document import Value
2024
from google.cloud.firestore_v1._helpers import GeoPoint
2125

@@ -119,3 +123,107 @@ def test_to_pb(self):
119123
assert result.args[0].boolean_value is True
120124
assert result.args[1].string_value == "test"
121125
assert len(result.options) == 0
126+
127+
128+
class TestLimit:
129+
def _make_one(self, *args, **kwargs):
130+
return stages.Limit(*args, **kwargs)
131+
132+
def test_repr(self):
133+
instance = self._make_one(10)
134+
repr_str = repr(instance)
135+
assert repr_str == "Limit(limit=10)"
136+
137+
def test_to_pb(self):
138+
instance = self._make_one(5)
139+
result = instance._to_pb()
140+
assert result.name == "limit"
141+
assert len(result.args) == 1
142+
assert result.args[0].integer_value == 5
143+
assert len(result.options) == 0
144+
145+
146+
class TestOffset:
147+
def _make_one(self, *args, **kwargs):
148+
return stages.Offset(*args, **kwargs)
149+
150+
def test_repr(self):
151+
instance = self._make_one(20)
152+
repr_str = repr(instance)
153+
assert repr_str == "Offset(offset=20)"
154+
155+
def test_to_pb(self):
156+
instance = self._make_one(3)
157+
result = instance._to_pb()
158+
assert result.name == "offset"
159+
assert len(result.args) == 1
160+
assert result.args[0].integer_value == 3
161+
assert len(result.options) == 0
162+
163+
164+
class TestSelect:
165+
def _make_one(self, *args, **kwargs):
166+
return stages.Select(*args, **kwargs)
167+
168+
def test_repr(self):
169+
instance = self._make_one("field1", Field.of("field2"))
170+
repr_str = repr(instance)
171+
assert repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])"
172+
173+
def test_to_pb(self):
174+
instance = self._make_one("field1", "field2.subfield", Field.of("field3"))
175+
result = instance._to_pb()
176+
assert result.name == "select"
177+
assert len(result.args) == 1
178+
got_map = result.args[0].map_value.fields
179+
assert got_map.get("field1").field_reference_value == "field1"
180+
assert got_map.get("field2.subfield").field_reference_value == "field2.subfield"
181+
assert got_map.get("field3").field_reference_value == "field3"
182+
assert len(result.options) == 0
183+
184+
185+
class TestSort:
186+
def _make_one(self, *args, **kwargs):
187+
return stages.Sort(*args, **kwargs)
188+
189+
def test_repr(self):
190+
order1 = Ordering(Field.of("field1"), "ASCENDING")
191+
instance = self._make_one(order1)
192+
repr_str = repr(instance)
193+
assert repr_str == "Sort(orders=[Field.of('field1').ascending()])"
194+
195+
def test_to_pb(self):
196+
order1 = Ordering(Field.of("name"), "ASCENDING")
197+
order2 = Ordering(Field.of("age"), "DESCENDING")
198+
instance = self._make_one(order1, order2)
199+
result = instance._to_pb()
200+
assert result.name == "sort"
201+
assert len(result.args) == 2
202+
got_map = result.args[0].map_value.fields
203+
assert got_map.get("expression").field_reference_value == "name"
204+
assert got_map.get("direction").string_value == "ascending"
205+
assert len(result.options) == 0
206+
207+
208+
class TestWhere:
209+
def _make_one(self, *args, **kwargs):
210+
return stages.Where(*args, **kwargs)
211+
212+
def test_repr(self):
213+
condition = Field.of("age").gt(30)
214+
instance = self._make_one(condition)
215+
repr_str = repr(instance)
216+
assert repr_str == "Where(condition=Gt(Field.of('age'), Constant.of(30)))"
217+
218+
def test_to_pb(self):
219+
condition = Field.of("city").eq("SF")
220+
instance = self._make_one(condition)
221+
result = instance._to_pb()
222+
assert result.name == "where"
223+
assert len(result.args) == 1
224+
got_fn = result.args[0].function_value
225+
assert got_fn.name == "eq"
226+
assert len(got_fn.args) == 2
227+
assert got_fn.args[0].field_reference_value == "city"
228+
assert got_fn.args[1].string_value == "SF"
229+
assert len(result.options) == 0

0 commit comments

Comments
 (0)