 # limitations under the License.

 from __future__ import annotations
-from typing import Optional, Sequence
-from typing_extensions import Self
-from typing import TYPE_CHECKING
+from typing import Iterable, Sequence, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
-from google.cloud.firestore_v1.document import DocumentReference
 from google.cloud.firestore_v1.types.pipeline import (
     StructuredPipeline as StructuredPipeline_pb,
 )
 from google.cloud.firestore_v1.vector import Vector
 from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
+from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
 from google.cloud.firestore_v1.pipeline_result import PipelineResult
-from google.cloud.firestore_v1 import _helpers, document
 from google.cloud.firestore_v1.pipeline_expressions import (
     Accumulator,
     Expr,
     ExprWithAlias,
     Field,
     FilterCondition,
     Selectable,
     SampleOptions,
 )
+from google.cloud.firestore_v1 import _helpers

 if TYPE_CHECKING:
     from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
+    from google.cloud.firestore_v1.transaction import BaseTransaction


 class _BasePipeline:
@@ -62,13 +62,14 @@ def __init__(self, client: Client | AsyncClient, *stages: stages.Stage):
         self.stages = tuple(stages)

     def __repr__(self):
+        cls_str = type(self).__name__
         if not self.stages:
-            return "Pipeline()"
+            return f"{cls_str}()"
         elif len(self.stages) == 1:
-            return f"Pipeline({self.stages[0]!r})"
+            return f"{cls_str}({self.stages[0]!r})"
         else:
             stages_str = ",\n".join([repr(s) for s in self.stages])
-            return f"Pipeline(\n{stages_str}\n)"
+            return f"{cls_str}(\n{stages_str}\n)"

     def _to_pb(self) -> StructuredPipeline_pb:
         return StructuredPipeline_pb(
@@ -81,20 +82,45 @@ def _append(self, new_stage):
         """
         return self.__class__(self._client, *self.stages, new_stage)

-    @staticmethod
-    def _parse_response(response_pb, client):
-        for doc in response_pb.results:
-            data = _helpers.decode_dict(doc.fields, client)
-            yield document.DocumentSnapshot(
-                None,
-                data,
-                exists=True,
-                read_time=response_pb._pb.execution_time,
-                create_time=doc.create_time,
-                update_time=doc.update_time,
+    def _prep_execute_request(
+        self, transaction: BaseTransaction | None
+    ) -> ExecutePipelineRequest:
+        """
+        Shared logic for creating an ExecutePipelineRequest.
+        """
+        database_name = (
+            f"projects/{self._client.project}/databases/{self._client._database}"
+        )
+        transaction_id = (
+            _helpers.get_transaction_id(transaction)
+            if transaction is not None
+            else None
+        )
+        request = ExecutePipelineRequest(
+            database=database_name,
+            transaction=transaction_id,
+            structured_pipeline=self._to_pb(),
+        )
+        return request
+
+    def _execute_response_helper(
+        self, response: ExecutePipelineResponse
+    ) -> Iterable[PipelineResult]:
+        """
+        Shared logic for unpacking an ExecutePipelineResponse into PipelineResults.
+        """
+        for doc in response.results:
+            ref = self._client.document(doc.name) if doc.name else None
+            yield PipelineResult(
+                self._client,
+                doc.fields,
+                ref,
+                response._pb.execution_time,
+                doc._pb.create_time if doc.create_time else None,
+                doc._pb.update_time if doc.update_time else None,
             )

-    def add_fields(self, *fields: Selectable) -> Self:
+    def add_fields(self, *fields: Selectable) -> "_BasePipeline":
         """
         Adds new fields to outputs from previous stages.

@@ -124,7 +150,7 @@ def add_fields(self, *fields: Selectable) -> Self:
         """
         return self._append(stages.AddFields(*fields))

-    def remove_fields(self, *fields: Field | str) -> Self:
+    def remove_fields(self, *fields: Field | str) -> "_BasePipeline":
         """
         Removes fields from outputs of previous stages.

@@ -146,7 +172,7 @@ def remove_fields(self, *fields: Field | str) -> Self:
         """
         return self._append(stages.RemoveFields(*fields))

-    def select(self, *selections: str | Selectable) -> Self:
+    def select(self, *selections: str | Selectable) -> "_BasePipeline":
         """
         Selects or creates a set of fields from the outputs of previous stages.

@@ -179,7 +205,7 @@ def select(self, *selections: str | Selectable) -> Self:
         """
         return self._append(stages.Select(*selections))

-    def where(self, condition: FilterCondition) -> Self:
+    def where(self, condition: FilterCondition) -> "_BasePipeline":
         """
         Filters the documents from previous stages to only include those matching
         the specified `FilterCondition`.
@@ -223,8 +249,8 @@ def find_nearest(
         field: str | Expr,
         vector: Sequence[float] | "Vector",
         distance_measure: "DistanceMeasure",
-        options: Optional[stages.FindNearestOptions] = None,
-    ) -> Self:
+        options: stages.FindNearestOptions | None = None,
+    ) -> "_BasePipeline":
         """
         Performs vector distance (similarity) search with given parameters on the
         stage inputs.
@@ -274,7 +300,7 @@ def find_nearest(
             stages.FindNearest(field, vector, distance_measure, options)
         )

-    def sort(self, *orders: stages.Ordering) -> Self:
+    def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
         """
         Sorts the documents from previous stages based on one or more `Ordering` criteria.

@@ -302,7 +328,7 @@ def sort(self, *orders: stages.Ordering) -> Self:
         """
         return self._append(stages.Sort(*orders))

-    def sample(self, limit_or_options: int | SampleOptions) -> Self:
+    def sample(self, limit_or_options: int | SampleOptions) -> "_BasePipeline":
         """
         Performs a pseudo-random sampling of the documents from the previous stage.

@@ -331,7 +357,7 @@ def sample(self, limit_or_options: int | SampleOptions) -> Self:
         """
         return self._append(stages.Sample(limit_or_options))

-    def union(self, other: Self) -> Self:
+    def union(self, other: "_BasePipeline") -> "_BasePipeline":
335361 """
336362 Performs a union of all documents from this pipeline and another pipeline,
337363 including duplicates.
@@ -359,7 +385,7 @@ def unnest(
359385 field : str | Selectable ,
360386 alias : str | Field | None = None ,
361387 options : Optional [stages .UnnestOptions ] = None ,
362- ) -> Self :
388+ ) -> "_BasePipeline" :
363389 """
364390 Produces a document for each element in an array field from the previous stage document.
365391
@@ -417,7 +443,7 @@ def unnest(
417443 """
418444 return self ._append (stages .Unnest (field , alias , options ))
419445
420- def generic_stage (self , name : str , * params : Expr ) -> Self :
446+ def generic_stage (self , name : str , * params : Expr ) -> "_BasePipeline" :
421447 """
422448 Adds a generic, named stage to the pipeline with specified parameters.
423449
@@ -440,7 +466,7 @@ def generic_stage(self, name: str, *params: Expr) -> Self:
440466 """
441467 return self ._append (stages .GenericStage (name , * params ))
442468
443- def offset (self , offset : int ) -> Self :
469+ def offset (self , offset : int ) -> "_BasePipeline" :
444470 """
445471 Skips the first `offset` number of documents from the results of previous stages.
446472
@@ -464,7 +490,7 @@ def offset(self, offset: int) -> Self:
464490 """
465491 return self ._append (stages .Offset (offset ))
466492
467- def limit (self , limit : int ) -> Self :
493+ def limit (self , limit : int ) -> "_BasePipeline" :
468494 """
469495 Limits the maximum number of documents returned by previous stages to `limit`.
470496
@@ -492,7 +518,7 @@ def aggregate(
492518 self ,
493519 * accumulators : ExprWithAlias [Accumulator ],
494520 groups : Sequence [str | Selectable ] = (),
495- ) -> Self :
521+ ) -> "_BasePipeline" :
496522 """
497523 Performs aggregation operations on the documents from previous stages,
498524 optionally grouped by specified fields or expressions.
@@ -538,7 +564,7 @@ def aggregate(
538564 """
539565 return self ._append (stages .Aggregate (* accumulators , groups = groups ))
540566
541- def distinct (self , * fields : str | Selectable ) -> Self :
567+ def distinct (self , * fields : str | Selectable ) -> "_BasePipeline" :
542568 """
543569 Returns documents with distinct combinations of values for the specified
544570 fields or expressions.
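
For reference, the two new helpers are meant to be composed by concrete pipeline classes. Below is a minimal sketch of how a synchronous subclass might wire them together; the `execute` method name and the `_firestore_api.execute_pipeline` call are assumptions for illustration, not part of this diff.

```python
from typing import Iterator


class Pipeline(_BasePipeline):
    """Hypothetical synchronous pipeline built on the shared helpers."""

    def execute(self, transaction=None) -> Iterator[PipelineResult]:
        # Build the ExecutePipelineRequest (database path, optional
        # transaction id, and the serialized stages).
        request = self._prep_execute_request(transaction)
        # Assumed RPC: the generated client streams ExecutePipelineResponse
        # messages; each one is unpacked into PipelineResult objects.
        for response in self._client._firestore_api.execute_pipeline(request):
            yield from self._execute_response_helper(response)
```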