Skip to content

Commit 9462d10

Browse files
feat: pipelines preview (#1156)
This PR adds support for Pipeline Queries
1 parent f0ed940 commit 9462d10

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+13577
-258
lines changed

google/cloud/firestore_v1/_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def __ne__(self, other):
120120
else:
121121
return not equality_val
122122

123+
def __repr__(self):
124+
return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})"
125+
123126

124127
def verify_path(path, is_collection) -> None:
125128
"""Verifies that a ``path`` has the correct form.

google/cloud/firestore_v1/async_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
from google.cloud.firestore_v1.services.firestore.transports import (
5757
grpc_asyncio as firestore_grpc_transport,
5858
)
59+
from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
60+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
5961

6062
if TYPE_CHECKING: # pragma: NO COVER
6163
import datetime
@@ -438,3 +440,10 @@ def transaction(
438440
A transaction attached to this client.
439441
"""
440442
return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only)
443+
444+
@property
445+
def _pipeline_cls(self):
446+
return AsyncPipeline
447+
448+
def pipeline(self) -> PipelineSource:
449+
return PipelineSource(self)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
.. warning::
16+
**Preview API**: Firestore Pipelines is currently in preview and is
17+
subject to potential breaking changes in future releases
18+
"""
19+
20+
from __future__ import annotations
21+
from typing import TYPE_CHECKING
22+
from google.cloud.firestore_v1 import pipeline_stages as stages
23+
from google.cloud.firestore_v1.base_pipeline import _BasePipeline
24+
from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream
25+
from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot
26+
from google.cloud.firestore_v1.pipeline_result import PipelineResult
27+
28+
if TYPE_CHECKING: # pragma: NO COVER
29+
import datetime
30+
from google.cloud.firestore_v1.async_client import AsyncClient
31+
from google.cloud.firestore_v1.async_transaction import AsyncTransaction
32+
from google.cloud.firestore_v1.pipeline_expressions import Constant
33+
from google.cloud.firestore_v1.types.document import Value
34+
from google.cloud.firestore_v1.query_profile import PipelineExplainOptions
35+
36+
37+
class AsyncPipeline(_BasePipeline):
38+
"""
39+
Pipelines allow for complex data transformations and queries involving
40+
multiple stages like filtering, projection, aggregation, and vector search.
41+
42+
This class extends `_BasePipeline` and provides methods to execute the
43+
defined pipeline stages using an asynchronous `AsyncClient`.
44+
45+
Usage Example:
46+
>>> from google.cloud.firestore_v1.pipeline_expressions import Field
47+
>>>
48+
>>> async def run_pipeline():
49+
... client = AsyncClient(...)
50+
... pipeline = client.pipeline()
51+
... .collection("books")
52+
... .where(Field.of("published").gt(1980))
53+
... .select("title", "author")
54+
... async for result in pipeline.stream():
55+
... print(result)
56+
57+
Use `client.pipeline()` to create instances of this class.
58+
59+
.. warning::
60+
**Preview API**: Firestore Pipelines is currently in preview and is
61+
subject to potential breaking changes in future releases
62+
"""
63+
64+
def __init__(self, client: AsyncClient, *stages: stages.Stage):
65+
"""
66+
Initializes an asynchronous Pipeline.
67+
68+
Args:
69+
client: The asynchronous `AsyncClient` instance to use for execution.
70+
*stages: Initial stages for the pipeline.
71+
"""
72+
super().__init__(client, *stages)
73+
74+
async def execute(
75+
self,
76+
*,
77+
transaction: "AsyncTransaction" | None = None,
78+
read_time: datetime.datetime | None = None,
79+
explain_options: PipelineExplainOptions | None = None,
80+
additional_options: dict[str, Value | Constant] = {},
81+
) -> PipelineSnapshot[PipelineResult]:
82+
"""
83+
Executes this pipeline and returns results as a list
84+
85+
Args:
86+
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
87+
An existing transaction that this query will run in.
88+
If a ``transaction`` is used and it already has write operations
89+
added, this method cannot be used (i.e. read-after-write is not
90+
allowed).
91+
read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
92+
time. This must be a microsecond precision timestamp within the past one hour, or
93+
if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
94+
within the past 7 days. For the most accurate results, use UTC timezone.
95+
explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
96+
Options to enable query profiling for this query. When set,
97+
explain_metrics will be available on the returned list.
98+
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
99+
These options will take precedence over method arguments if there is a conflict (e.g. explain_options)
100+
"""
101+
kwargs = {k: v for k, v in locals().items() if k != "self"}
102+
stream = AsyncPipelineStream(PipelineResult, self, **kwargs)
103+
results = [result async for result in stream]
104+
return PipelineSnapshot(results, stream)
105+
106+
def stream(
107+
self,
108+
*,
109+
read_time: datetime.datetime | None = None,
110+
transaction: "AsyncTransaction" | None = None,
111+
explain_options: PipelineExplainOptions | None = None,
112+
additional_options: dict[str, Value | Constant] = {},
113+
) -> AsyncPipelineStream[PipelineResult]:
114+
"""
115+
Process this pipeline as a stream, providing results through an AsyncIterable
116+
117+
Args:
118+
transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
119+
An existing transaction that this query will run in.
120+
If a ``transaction`` is used and it already has write operations
121+
added, this method cannot be used (i.e. read-after-write is not
122+
allowed).
123+
read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
124+
time. This must be a microsecond precision timestamp within the past one hour, or
125+
if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
126+
within the past 7 days. For the most accurate results, use UTC timezone.
127+
explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
128+
Options to enable query profiling for this query. When set,
129+
explain_metrics will be available on the returned generator.
130+
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
131+
These options will take precedence over method arguments if there is a conflict (e.g. explain_options)
132+
"""
133+
kwargs = {k: v for k, v in locals().items() if k != "self"}
134+
return AsyncPipelineStream(PipelineResult, self, **kwargs)

google/cloud/firestore_v1/base_aggregation.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from __future__ import annotations
2222

2323
import abc
24+
import itertools
2425

2526
from abc import ABC
26-
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
27+
from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable
2728

2829
from google.api_core import gapic_v1
2930
from google.api_core import retry as retries
@@ -33,6 +34,10 @@
3334
from google.cloud.firestore_v1.types import (
3435
StructuredAggregationQuery,
3536
)
37+
from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction
38+
from google.cloud.firestore_v1.pipeline_expressions import Count
39+
from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression
40+
from google.cloud.firestore_v1.pipeline_expressions import Field
3641

3742
# Types needed only for Type Hints
3843
if TYPE_CHECKING: # pragma: NO COVER
@@ -43,6 +48,7 @@
4348
from google.cloud.firestore_v1.stream_generator import (
4449
StreamGenerator,
4550
)
51+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
4652

4753
import datetime
4854

@@ -66,6 +72,9 @@ def __init__(self, alias: str, value: float, read_time=None):
6672
def __repr__(self):
6773
return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"
6874

75+
def _to_dict(self):
76+
return {self.alias: self.value}
77+
6978

7079
class BaseAggregation(ABC):
7180
def __init__(self, alias: str | None = None):
@@ -75,6 +84,27 @@ def __init__(self, alias: str | None = None):
7584
def _to_protobuf(self):
7685
"""Convert this instance to the protobuf representation"""
7786

87+
@abc.abstractmethod
88+
def _to_pipeline_expr(
89+
self, autoindexer: Iterable[int]
90+
) -> AliasedExpression[AggregateFunction]:
91+
"""
92+
Convert this instance to a pipeline expression for use with pipeline.aggregate()
93+
94+
Args:
95+
autoindexer: If an alias isn't supplied, one should be created with the format "field_n"
96+
The autoindexer is an iterable that provides the `n` value to use for each expression
97+
"""
98+
99+
def _pipeline_alias(self, autoindexer):
100+
"""
101+
Helper to build the alias for the pipeline expression
102+
"""
103+
if self.alias is not None:
104+
return self.alias
105+
else:
106+
return f"field_{next(autoindexer)}"
107+
78108

79109
class CountAggregation(BaseAggregation):
80110
def __init__(self, alias: str | None = None):
@@ -88,6 +118,9 @@ def _to_protobuf(self):
88118
aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
89119
return aggregation_pb
90120

121+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
122+
return Count().as_(self._pipeline_alias(autoindexer))
123+
91124

92125
class SumAggregation(BaseAggregation):
93126
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -107,6 +140,9 @@ def _to_protobuf(self):
107140
aggregation_pb.sum.field.field_path = self.field_ref
108141
return aggregation_pb
109142

143+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
144+
return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer))
145+
110146

111147
class AvgAggregation(BaseAggregation):
112148
def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
@@ -126,6 +162,9 @@ def _to_protobuf(self):
126162
aggregation_pb.avg.field.field_path = self.field_ref
127163
return aggregation_pb
128164

165+
def _to_pipeline_expr(self, autoindexer: Iterable[int]):
166+
return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer))
167+
129168

130169
def _query_response_to_result(
131170
response_pb,
@@ -317,3 +356,21 @@ def stream(
317356
StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
318357
A generator of the query results.
319358
"""
359+
360+
def _build_pipeline(self, source: "PipelineSource"):
361+
"""
362+
Convert this query into a Pipeline
363+
364+
Queries containing a `cursor` or `limit_to_last` are not currently supported
365+
366+
Args:
367+
source: the PipelineSource to build the pipeline off of
368+
Raises:
369+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
370+
Returns:
371+
a Pipeline representing the query
372+
"""
373+
# use autoindexer to keep track of which field number to use for un-aliased fields
374+
autoindexer = itertools.count(start=1)
375+
exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
376+
return self._nested_query._build_pipeline(source).aggregate(*exprs)

google/cloud/firestore_v1/base_client.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Optional,
3838
Tuple,
3939
Union,
40+
Type,
4041
)
4142

4243
import google.api_core.client_options
@@ -61,6 +62,8 @@
6162
from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
6263
from google.cloud.firestore_v1.field_path import render_field_path
6364
from google.cloud.firestore_v1.services.firestore import client as firestore_client
65+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
66+
from google.cloud.firestore_v1.base_pipeline import _BasePipeline
6467

6568
DEFAULT_DATABASE = "(default)"
6669
"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""
@@ -502,6 +505,20 @@ def transaction(
502505
) -> BaseTransaction:
503506
raise NotImplementedError
504507

508+
def pipeline(self) -> PipelineSource:
509+
"""
510+
Start a pipeline with this client.
511+
512+
Returns:
513+
:class:`~google.cloud.firestore_v1.pipeline_source.PipelineSource`:
514+
A pipeline that uses this client
515+
"""
516+
raise NotImplementedError
517+
518+
@property
519+
def _pipeline_cls(self) -> Type["_BasePipeline"]:
520+
raise NotImplementedError
521+
505522

506523
def _reference_info(references: list) -> Tuple[list, dict]:
507524
"""Get information about document references.

google/cloud/firestore_v1/base_collection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from google.cloud.firestore_v1.async_document import AsyncDocumentReference
5050
from google.cloud.firestore_v1.document import DocumentReference
5151
from google.cloud.firestore_v1.field_path import FieldPath
52+
from google.cloud.firestore_v1.pipeline_source import PipelineSource
5253
from google.cloud.firestore_v1.query_profile import ExplainOptions
5354
from google.cloud.firestore_v1.query_results import QueryResultsList
5455
from google.cloud.firestore_v1.stream_generator import StreamGenerator
@@ -603,6 +604,21 @@ def find_nearest(
603604
distance_threshold=distance_threshold,
604605
)
605606

607+
def _build_pipeline(self, source: "PipelineSource"):
608+
"""
609+
Convert this query into a Pipeline
610+
611+
Queries containing a `cursor` or `limit_to_last` are not currently supported
612+
613+
Args:
614+
source: the PipelineSource to build the pipeline off of
615+
Raises:
616+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
617+
Returns:
618+
a Pipeline representing the query
619+
"""
620+
return self._query()._build_pipeline(source)
621+
606622

607623
def _auto_id() -> str:
608624
"""Generate a "random" automatically generated ID.

0 commit comments

Comments
 (0)