Commit 9cdb8c9

Merge branch 'pipeline_queries_2_query_parity' into pipeline_queries_3_stable_stages
2 parents: 77dca01 + 09d45cb

19 files changed (+1203 / -110 lines)

google/cloud/firestore_v1/_helpers.py
Lines changed: 3 additions & 0 deletions

@@ -120,6 +120,9 @@ def __ne__(self, other):
         else:
             return not equality_val
 
+    def __repr__(self):
+        return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})"
+
 
 def verify_path(path, is_collection) -> None:
     """Verifies that a ``path`` has the correct form.

google/cloud/firestore_v1/async_client.py
Lines changed: 1 addition & 1 deletion

@@ -417,4 +417,4 @@ def transaction(self, **kwargs) -> AsyncTransaction:
 
     @property
     def _pipeline_cls(self):
-        raise AsyncPipeline
+        return AsyncPipeline

google/cloud/firestore_v1/async_pipeline.py
Lines changed: 20 additions & 27 deletions

@@ -13,16 +13,14 @@
 # limitations under the License.
 
 from __future__ import annotations
-import datetime
 from typing import AsyncIterable, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
-from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
-from google.cloud.firestore_v1.async_document import AsyncDocumentReference
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
-from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
 if TYPE_CHECKING:
     from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.async_transaction import AsyncTransaction
 
 
 class AsyncPipeline(_BasePipeline):
@@ -58,29 +56,24 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage):
         """
         super().__init__(client, *stages)
 
-    async def execute(self) -> AsyncIterable[PipelineResult]:
-        database_name = (
-            f"projects/{self._client.project}/databases/{self._client._database}"
-        )
-        request = ExecutePipelineRequest(
-            database=database_name,
-            structured_pipeline=self._to_pb(),
-            read_time=datetime.datetime.now(),
-        )
+    async def execute(
+        self,
+        transaction: "AsyncTransaction" | None = None,
+    ) -> AsyncIterable[PipelineResult]:
+        """
+        Executes this pipeline, providing results through an AsyncIterable.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        request = self._prep_execute_request(transaction)
         async for response in await self._client._firestore_api.execute_pipeline(
             request
         ):
-            for doc in response.results:
-                doc_ref = (
-                    AsyncDocumentReference(doc.name, client=self._client)
-                    if doc.name
-                    else None
-                )
-                yield PipelineResult(
-                    self._client,
-                    doc.fields,
-                    doc_ref,
-                    response._pb.execution_time,
-                    doc.create_time,
-                    doc.update_tiem,
-                )
+            for result in self._execute_response_helper(response):
+                yield result
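The new `execute()` delegates request construction and response unpacking to the `_prep_execute_request` and `_execute_response_helper` helpers shown in base_pipeline.py below, and gains an optional transaction argument. A hedged usage sketch; `stages.Collection` is assumed as a source stage and is not part of this diff:

    import asyncio

    from google.cloud.firestore_v1 import pipeline_stages as stages
    from google.cloud.firestore_v1.async_client import AsyncClient
    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline

    async def main():
        client = AsyncClient()
        # AsyncPipeline takes (client, *stages); Collection is a hypothetical source stage.
        pipeline = AsyncPipeline(client, stages.Collection("users"))
        async for result in pipeline.execute():  # or pipeline.execute(transaction=txn)
            print(result)

    asyncio.run(main())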

google/cloud/firestore_v1/base_client.py
Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@
 from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
 from google.cloud.firestore_v1.field_path import render_field_path
 from google.cloud.firestore_v1.services.firestore import client as firestore_client
+from google.cloud.firestore_v1.pipeline_source import PipelineSource
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
 
 DEFAULT_DATABASE = "(default)"
 """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""

google/cloud/firestore_v1/base_pipeline.py
Lines changed: 60 additions & 34 deletions

@@ -13,18 +13,15 @@
 # limitations under the License.
 
 from __future__ import annotations
-from typing import Optional, Sequence
-from typing_extensions import Self
-from typing import TYPE_CHECKING
+from typing import Iterable, Sequence, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
-from google.cloud.firestore_v1.document import DocumentReference
 from google.cloud.firestore_v1.types.pipeline import (
     StructuredPipeline as StructuredPipeline_pb,
 )
 from google.cloud.firestore_v1.vector import Vector
 from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
+from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
 from google.cloud.firestore_v1.pipeline_result import PipelineResult
-from google.cloud.firestore_v1 import _helpers, document
 from google.cloud.firestore_v1.pipeline_expressions import (
     Accumulator,
     Expr,
@@ -34,10 +31,13 @@
     Selectable,
     SampleOptions,
 )
+from google.cloud.firestore_v1 import _helpers
 
 if TYPE_CHECKING:
     from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
+    from google.cloud.firestore_v1.transaction import BaseTransaction
 
 
 class _BasePipeline:
@@ -62,13 +62,14 @@ def __init__(self, client: Client | AsyncClient, *stages: stages.Stage):
         self.stages = tuple(stages)
 
     def __repr__(self):
+        cls_str = type(self).__name__
         if not self.stages:
-            return "Pipeline()"
+            return f"{cls_str}()"
         elif len(self.stages) == 1:
-            return f"Pipeline({self.stages[0]!r})"
+            return f"{cls_str}({self.stages[0]!r})"
         else:
             stages_str = ",\n ".join([repr(s) for s in self.stages])
-            return f"Pipeline(\n {stages_str}\n)"
+            return f"{cls_str}(\n {stages_str}\n)"
 
     def _to_pb(self) -> StructuredPipeline_pb:
         return StructuredPipeline_pb(
@@ -81,20 +82,45 @@ def _append(self, new_stage):
         """
         return self.__class__(self._client, *self.stages, new_stage)
 
-    @staticmethod
-    def _parse_response(response_pb, client):
-        for doc in response_pb.results:
-            data = _helpers.decode_dict(doc.fields, client)
-            yield document.DocumentSnapshot(
-                None,
-                data,
-                exists=True,
-                read_time=response_pb._pb.execution_time,
-                create_time=doc.create_time,
-                update_time=doc.update_time,
+    def _prep_execute_request(
+        self, transaction: BaseTransaction | None
+    ) -> ExecutePipelineRequest:
+        """
+        shared logic for creating an ExecutePipelineRequest
+        """
+        database_name = (
+            f"projects/{self._client.project}/databases/{self._client._database}"
+        )
+        transaction_id = (
+            _helpers.get_transaction_id(transaction)
+            if transaction is not None
+            else None
+        )
+        request = ExecutePipelineRequest(
+            database=database_name,
+            transaction=transaction_id,
+            structured_pipeline=self._to_pb(),
+        )
+        return request
+
+    def _execute_response_helper(
+        self, response: ExecutePipelineResponse
+    ) -> Iterable[PipelineResult]:
+        """
+        shared logic for unpacking an ExecutePipelineResponse into PipelineResults
+        """
+        for doc in response.results:
+            ref = self._client.document(doc.name) if doc.name else None
+            yield PipelineResult(
+                self._client,
+                doc.fields,
+                ref,
+                response._pb.execution_time,
+                doc._pb.create_time if doc.create_time else None,
+                doc._pb.update_time if doc.update_time else None,
             )
 
-    def add_fields(self, *fields: Selectable) -> Self:
+    def add_fields(self, *fields: Selectable) -> "_BasePipeline":
         """
         Adds new fields to outputs from previous stages.
 
@@ -124,7 +150,7 @@ def add_fields(self, *fields: Selectable) -> Self:
         """
         return self._append(stages.AddFields(*fields))
 
-    def remove_fields(self, *fields: Field | str) -> Self:
+    def remove_fields(self, *fields: Field | str) -> "_BasePipeline":
         """
         Removes fields from outputs of previous stages.
 
@@ -146,7 +172,7 @@ def remove_fields(self, *fields: Field | str) -> Self:
         """
         return self._append(stages.RemoveFields(*fields))
 
-    def select(self, *selections: str | Selectable) -> Self:
+    def select(self, *selections: str | Selectable) -> "_BasePipeline":
         """
         Selects or creates a set of fields from the outputs of previous stages.
 
@@ -179,7 +205,7 @@ def select(self, *selections: str | Selectable) -> Self:
         """
         return self._append(stages.Select(*selections))
 
-    def where(self, condition: FilterCondition) -> Self:
+    def where(self, condition: FilterCondition) -> "_BasePipeline":
         """
         Filters the documents from previous stages to only include those matching
         the specified `FilterCondition`.
@@ -223,8 +249,8 @@ def find_nearest(
         field: str | Expr,
         vector: Sequence[float] | "Vector",
        distance_measure: "DistanceMeasure",
-        options: Optional[stages.FindNearestOptions] = None,
-    ) -> Self:
+        options: stages.FindNearestOptions | None = None,
+    ) -> "_BasePipeline":
         """
         Performs vector distance (similarity) search with given parameters on the
         stage inputs.
@@ -274,7 +300,7 @@ def find_nearest(
             stages.FindNearest(field, vector, distance_measure, options)
         )
 
-    def sort(self, *orders: stages.Ordering) -> Self:
+    def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
         """
         Sorts the documents from previous stages based on one or more `Ordering` criteria.
 
@@ -302,7 +328,7 @@ def sort(self, *orders: stages.Ordering) -> Self:
         """
         return self._append(stages.Sort(*orders))
 
-    def sample(self, limit_or_options: int | SampleOptions) -> Self:
+    def sample(self, limit_or_options: int | SampleOptions) -> "_BasePipeline":
         """
         Performs a pseudo-random sampling of the documents from the previous stage.
 
@@ -331,7 +357,7 @@ def sample(self, limit_or_options: int | SampleOptions) -> Self:
         """
         return self._append(stages.Sample(limit_or_options))
 
-    def union(self, other: Self) -> Self:
+    def union(self, other: Self) -> "_BasePipeline":
         """
         Performs a union of all documents from this pipeline and another pipeline,
         including duplicates.
@@ -359,7 +385,7 @@ def unnest(
         field: str | Selectable,
         alias: str | Field | None = None,
         options: Optional[stages.UnnestOptions] = None,
-    ) -> Self:
+    ) -> "_BasePipeline":
         """
         Produces a document for each element in an array field from the previous stage document.
 
@@ -417,7 +443,7 @@ def unnest(
         """
         return self._append(stages.Unnest(field, alias, options))
 
-    def generic_stage(self, name: str, *params: Expr) -> Self:
+    def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline":
         """
         Adds a generic, named stage to the pipeline with specified parameters.
 
@@ -440,7 +466,7 @@ def generic_stage(self, name: str, *params: Expr) -> Self:
         """
         return self._append(stages.GenericStage(name, *params))
 
-    def offset(self, offset: int) -> Self:
+    def offset(self, offset: int) -> "_BasePipeline":
         """
         Skips the first `offset` number of documents from the results of previous stages.
 
@@ -464,7 +490,7 @@ def offset(self, offset: int) -> Self:
         """
         return self._append(stages.Offset(offset))
 
-    def limit(self, limit: int) -> Self:
+    def limit(self, limit: int) -> "_BasePipeline":
         """
         Limits the maximum number of documents returned by previous stages to `limit`.
 
@@ -492,7 +518,7 @@ def aggregate(
         self,
         *accumulators: ExprWithAlias[Accumulator],
         groups: Sequence[str | Selectable] = (),
-    ) -> Self:
+    ) -> "_BasePipeline":
         """
         Performs aggregation operations on the documents from previous stages,
         optionally grouped by specified fields or expressions.
@@ -538,7 +564,7 @@ def aggregate(
         """
         return self._append(stages.Aggregate(*accumulators, groups=groups))
 
-    def distinct(self, *fields: str | Selectable) -> Self:
+    def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
         """
         Returns documents with distinct combinations of values for the specified
         fields or expressions.
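Two things fall out of this refactor: the request/response plumbing moves into shared helpers both subclasses reuse, and every builder method now returns "_BasePipeline", so stages chain fluently while each call creates a new pipeline via _append. A hedged sketch using only methods shown above (`stages.Collection` is again an assumed source stage):

    from google.cloud.firestore_v1 import pipeline_stages as stages
    from google.cloud.firestore_v1.client import Client

    client = Client()
    # _pipeline_cls now returns the Pipeline class (see client.py below).
    base = client._pipeline_cls(client, stages.Collection("cities"))
    trimmed = base.offset(10).limit(5)  # each call returns a new pipeline; base is unchanged
    print(trimmed)  # the reworked __repr__ prints the concrete class name and its stages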

google/cloud/firestore_v1/client.py
Lines changed: 1 addition & 1 deletion

@@ -399,4 +399,4 @@ def transaction(self, **kwargs) -> Transaction:
 
     @property
     def _pipeline_cls(self):
-        raise Pipeline
+        return Pipeline
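This mirrors the async_client.py fix above: `raise Pipeline` would have thrown a TypeError as soon as the property was read (Pipeline is not an exception class), while `return Pipeline` lets shared base-client code use the property as a class hook. A minimal sketch:

    client = Client()
    pipeline_cls = client._pipeline_cls   # now returns the Pipeline class itself
    pipeline = pipeline_cls(client)       # an empty pipeline; stages are added by chaining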

google/cloud/firestore_v1/field_path.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 from __future__ import annotations
 import re
 from collections import abc
-from typing import Iterable, cast
+from typing import Any, Iterable, cast, MutableMapping
 
 _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data"
 _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}"
@@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]):
 get_field_path = render_field_path # backward-compatibility
 
 
-def get_nested_value(field_path: str, data: dict):
+def get_nested_value(field_path: str, data: MutableMapping[str, Any]):
     """Get a (potentially nested) value from a dictionary.
 
     If the data is nested, for example:
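The annotation change only widens the accepted mapping type; runtime behavior is unchanged. A usage sketch in the spirit of the docstring's nested-data example:

    from google.cloud.firestore_v1.field_path import get_nested_value

    data = {"top1": {"middle2": {"bottom3": 20}}, "top4": 123}
    print(get_nested_value("top1.middle2.bottom3", data))  # 20
    print(get_nested_value("top4", data))                  # 123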
