1515import mock
1616import pytest
1717
18+ from google .cloud .firestore_v1 import pipeline_stages as stages
19+
1820
1921def _make_async_pipeline (* args , client = mock .Mock ()):
2022 from google .cloud .firestore_v1 .async_pipeline import AsyncPipeline
@@ -54,11 +56,9 @@ def test_async_pipeline_repr_single_stage():
5456
5557
5658def test_async_pipeline_repr_multiple_stage ():
57- from google .cloud .firestore_v1 .pipeline_stages import GenericStage , Collection
58-
59- stage_1 = Collection ("path" )
60- stage_2 = GenericStage ("second" , 2 )
61- stage_3 = GenericStage ("third" , 3 )
59+ stage_1 = stages .Collection ("path" )
60+ stage_2 = stages .GenericStage ("second" , 2 )
61+ stage_3 = stages .GenericStage ("third" , 3 )
6262 ppl = _make_async_pipeline (stage_1 , stage_2 , stage_3 )
6363 repr_str = repr (ppl )
6464 assert repr_str == (
@@ -71,10 +71,8 @@ def test_async_pipeline_repr_multiple_stage():
7171
7272
7373def test_async_pipeline_repr_long ():
74- from google .cloud .firestore_v1 .pipeline_stages import GenericStage
75-
7674 num_stages = 100
77- stage_list = [GenericStage ("custom" , i ) for i in range (num_stages )]
75+ stage_list = [stages . GenericStage ("custom" , i ) for i in range (num_stages )]
7876 ppl = _make_async_pipeline (* stage_list )
7977 repr_str = repr (ppl )
8078 assert repr_str .count ("GenericStage" ) == num_stages
@@ -83,10 +81,9 @@ def test_async_pipeline_repr_long():
8381
8482def test_async_pipeline__to_pb ():
8583 from google .cloud .firestore_v1 .types .pipeline import StructuredPipeline
86- from google .cloud .firestore_v1 .pipeline_stages import GenericStage
8784
88- stage_1 = GenericStage ("first" )
89- stage_2 = GenericStage ("second" )
85+ stage_1 = stages . GenericStage ("first" )
86+ stage_2 = stages . GenericStage ("second" )
9087 ppl = _make_async_pipeline (stage_1 , stage_2 )
9188 pb = ppl ._to_pb ()
9289 assert isinstance (pb , StructuredPipeline )
@@ -96,11 +93,9 @@ def test_async_pipeline__to_pb():
9693
9794def test_async_pipeline_append ():
9895 """append should create a new pipeline with the additional stage"""
99- from google .cloud .firestore_v1 .pipeline_stages import GenericStage
100-
101- stage_1 = GenericStage ("first" )
96+ stage_1 = stages .GenericStage ("first" )
10297 ppl_1 = _make_async_pipeline (stage_1 , client = object ())
103- stage_2 = GenericStage ("second" )
98+ stage_2 = stages . GenericStage ("second" )
10499 ppl_2 = ppl_1 ._append (stage_2 )
105100 assert ppl_1 != ppl_2
106101 assert len (ppl_1 .stages ) == 1
@@ -118,15 +113,14 @@ async def test_async_pipeline_execute_empty():
118113 """
119114 from google .cloud .firestore_v1 .types import ExecutePipelineResponse
120115 from google .cloud .firestore_v1 .types import ExecutePipelineRequest
121- from google .cloud .firestore_v1 .pipeline_stages import GenericStage
122116
123117 client = mock .Mock ()
124118 client .project = "A"
125119 client ._database = "B"
126120 mock_rpc = mock .AsyncMock ()
127121 client ._firestore_api .execute_pipeline = mock_rpc
128122 mock_rpc .return_value = _async_it ([ExecutePipelineResponse ()])
129- ppl_1 = _make_async_pipeline (GenericStage ("s" ), client = client )
123+ ppl_1 = _make_async_pipeline (stages . GenericStage ("s" ), client = client )
130124
131125 results = [r async for r in ppl_1 .execute ()]
132126 assert results == []
@@ -145,7 +139,6 @@ async def test_async_pipeline_execute_no_doc_ref():
145139 from google .cloud .firestore_v1 .types import Document
146140 from google .cloud .firestore_v1 .types import ExecutePipelineResponse
147141 from google .cloud .firestore_v1 .types import ExecutePipelineRequest
148- from google .cloud .firestore_v1 .pipeline_stages import GenericStage
149142 from google .cloud .firestore_v1 .pipeline_result import PipelineResult
150143
151144 client = mock .Mock ()
@@ -156,7 +149,7 @@ async def test_async_pipeline_execute_no_doc_ref():
156149 mock_rpc .return_value = _async_it (
157150 [ExecutePipelineResponse (results = [Document ()], execution_time = {"seconds" : 9 })]
158151 )
159- ppl_1 = _make_async_pipeline (GenericStage ("s" ), client = client )
152+ ppl_1 = _make_async_pipeline (stages . GenericStage ("s" ), client = client )
160153
161154 results = [r async for r in ppl_1 .execute ()]
162155 assert len (results ) == 1
@@ -315,3 +308,20 @@ async def test_async_pipeline_execute_with_transaction():
315308 assert request .structured_pipeline == ppl_1 ._to_pb ()
316309 assert request .database == "projects/A/databases/B"
317310 assert request .transaction == b"123"
311+
312+ @pytest .mark .parametrize ("method,args,result_cls" , [
313+ ("select" , (), stages .Select ),
314+ ("where" , (mock .Mock (),), stages .Where ),
315+ ("sort" , (), stages .Sort ),
316+ ("offset" , (1 ,), stages .Offset ),
317+ ("limit" , (1 ,), stages .Limit ),
318+
319+ ])
320+ def test_async_pipeline_methods (method , args , result_cls ):
321+ start_ppl = _make_async_pipeline ()
322+ method_ptr = getattr (start_ppl , method )
323+ result_ppl = method_ptr (* args )
324+ assert result_ppl != start_ppl
325+ assert len (start_ppl .stages ) == 0
326+ assert len (result_ppl .stages ) == 1
327+ assert isinstance (result_ppl .stages [0 ], result_cls )
0 commit comments