From c8dbebe720d8356f58041b2be0dfe4d5962cc1df Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Dec 2023 16:23:51 -0800 Subject: [PATCH 1/4] point mypy check at the right folder --- noxfile.py | 4 ++-- owlbot.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/noxfile.py b/noxfile.py index b4b720257..d76919b5c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -146,9 +146,9 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") + session.install("mypy", "types-setuptools", "types-protobuf", "types-requests") # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.run("mypy", "-p", "google.cloud.firestore_v1", "--no-incremental") @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/owlbot.py b/owlbot.py index 4384bb53a..75ee1bc2f 100644 --- a/owlbot.py +++ b/owlbot.py @@ -314,7 +314,7 @@ def mypy(session): session.install("-e", ".") session.install("mypy", "types-setuptools") # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.run("mypy", "-p", "google.cloud.firestore_v1", "--no-incremental") @nox.session(python=DEFAULT_PYTHON_VERSION) From 36a2eedaf13c71ad86cd91f74c428824ce2f226b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Dec 2023 16:25:16 -0800 Subject: [PATCH 2/4] strip types from all veneer files --- google/cloud/firestore_v1/__init__.py | 2 +- google/cloud/firestore_v1/_helpers.py | 164 ++++++-------- google/cloud/firestore_v1/aggregation.py | 18 +- .../cloud/firestore_v1/async_aggregation.py | 18 +- google/cloud/firestore_v1/async_batch.py | 8 +- google/cloud/firestore_v1/async_client.py | 52 ++--- google/cloud/firestore_v1/async_collection.py | 46 ++-- google/cloud/firestore_v1/async_document.py | 54 ++--- google/cloud/firestore_v1/async_query.py | 54 ++--- .../cloud/firestore_v1/async_transaction.py | 42 ++-- google/cloud/firestore_v1/base_aggregation.py | 53 ++--- google/cloud/firestore_v1/base_batch.py | 32 ++- google/cloud/firestore_v1/base_client.py | 101 ++++----- google/cloud/firestore_v1/base_collection.py | 116 +++++----- google/cloud/firestore_v1/base_document.py | 126 ++++++----- google/cloud/firestore_v1/base_query.py | 188 ++++++++-------- google/cloud/firestore_v1/base_transaction.py | 76 ++++--- google/cloud/firestore_v1/batch.py | 6 +- google/cloud/firestore_v1/bulk_batch.py | 10 +- google/cloud/firestore_v1/bulk_writer.py | 206 +++++++++--------- google/cloud/firestore_v1/client.py | 56 ++--- google/cloud/firestore_v1/collection.py | 44 ++-- google/cloud/firestore_v1/document.py | 56 ++--- google/cloud/firestore_v1/field_path.py | 14 +- google/cloud/firestore_v1/order.py | 22 +- google/cloud/firestore_v1/query.py | 54 ++--- google/cloud/firestore_v1/rate_limiter.py | 50 ++--- google/cloud/firestore_v1/transaction.py | 40 ++-- google/cloud/firestore_v1/transforms.py | 6 +- 29 files changed, 797 insertions(+), 917 deletions(-) diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index 1d143556f..a7322fc24 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -109,7 +109,7 @@ # from .types.write import Write # from .types.write import WriteResult -__all__: List[str] = [ +__all__ = [ "__version__", "And", "ArrayRemove", diff --git a/google/cloud/firestore_v1/_helpers.py 
b/google/cloud/firestore_v1/_helpers.py index a6b6616d3..63f92d5bc 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -47,8 +47,8 @@ Union, ) -_EmptyDict: transforms.Sentinel -_GRPC_ERROR_MAPPING: dict +# _EmptyDict: transforms.Sentinel +# _GRPC_ERROR_MAPPING: dict BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}." @@ -78,11 +78,11 @@ class GeoPoint(object): longitude (float): Longitude of a point. """ - def __init__(self, latitude, longitude) -> None: + def __init__(self, latitude, longitude): self.latitude = latitude self.longitude = longitude - def to_protobuf(self) -> latlng_pb2.LatLng: + def to_protobuf(self): """Convert the current object to protobuf. Returns: @@ -118,7 +118,7 @@ def __ne__(self, other): return not equality_val -def verify_path(path, is_collection) -> None: +def verify_path(path, is_collection): """Verifies that a ``path`` has the correct form. Checks that all of the elements in ``path`` are strings. @@ -155,7 +155,7 @@ def verify_path(path, is_collection) -> None: raise ValueError(msg) -def encode_value(value) -> types.document.Value: +def encode_value(value): """Converts a native Python value into a Firestore protobuf ``Value``. Args: @@ -219,7 +219,7 @@ def encode_value(value) -> types.document.Value: ) -def encode_dict(values_dict) -> dict: +def encode_dict(values_dict): """Encode a dictionary into protobuf ``Value``-s. Args: @@ -233,7 +233,7 @@ def encode_dict(values_dict) -> dict: return {key: encode_value(value) for key, value in values_dict.items()} -def document_snapshot_to_protobuf(snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot") -> Optional["google.cloud.firestore_v1.types.Document"]: # type: ignore +def document_snapshot_to_protobuf(snapshot): # type: ignore from google.cloud.firestore_v1.types import Document if not snapshot.exists: @@ -263,7 +263,7 @@ class DocumentReferenceValue: ValueError: If the supplied value cannot satisfy a complete path. """ - def __init__(self, reference_value: str): + def __init__(self, reference_value): self._reference_value = reference_value # The first 5 parts are @@ -279,13 +279,13 @@ def __init__(self, reference_value: str): self.document_id = "/".join(parts[6:]) @property - def full_key(self) -> str: + def full_key(self): """Computed property for a DocumentReference's collection_name and document Id""" return "/".join([self.collection_name, self.document_id]) @property - def full_path(self) -> str: + def full_path(self): return self._reference_value or "/".join( [ "projects", @@ -299,7 +299,7 @@ def full_path(self) -> str: ) -def reference_value_to_document(reference_value, client) -> Any: +def reference_value_to_document(reference_value, client): """Convert a reference value string to a document. 
Args: @@ -321,7 +321,7 @@ def reference_value_to_document(reference_value, client) -> Any: doc_ref_value = DocumentReferenceValue(reference_value) - document: BaseDocumentReference = client.document(doc_ref_value.full_key) + document = client.document(doc_ref_value.full_key) if document._document_path != reference_value: msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string) raise ValueError(msg) @@ -329,9 +329,7 @@ def reference_value_to_document(reference_value, client) -> Any: return document -def decode_value( - value, client -) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]: +def decode_value(value, client): """Converts a Firestore protobuf ``Value`` to a native Python value. Args: @@ -382,7 +380,7 @@ def decode_value( raise ValueError("Unknown ``value_type``", value_type) -def decode_dict(value_fields, client) -> dict: +def decode_dict(value_fields, client): """Converts a protobuf map of Firestore ``Value``-s. Args: @@ -401,7 +399,7 @@ def decode_dict(value_fields, client) -> dict: return {key: decode_value(value, client) for key, value in value_fields_pb.items()} -def get_doc_id(document_pb, expected_prefix) -> str: +def get_doc_id(document_pb, expected_prefix): """Parse a document ID from a document protobuf. Args: @@ -432,9 +430,7 @@ def get_doc_id(document_pb, expected_prefix) -> str: _EmptyDict = transforms.Sentinel("Marker for an empty dict value") -def extract_fields( - document_data, prefix_path: FieldPath, expand_dots=False -) -> Generator[Tuple[Any, Any], Any, None]: +def extract_fields(document_data, prefix_path, expand_dots=False): """Do depth-first walk of tree, yielding field_path, value""" if not document_data: yield prefix_path, _EmptyDict @@ -454,7 +450,7 @@ def extract_fields( yield field_path, value -def set_field_value(document_data, field_path, value) -> None: +def set_field_value(document_data, field_path, value): """Set a value into a document for a field_path""" current = document_data for element in field_path.parts[:-1]: @@ -464,7 +460,7 @@ def set_field_value(document_data, field_path, value) -> None: current[field_path.parts[-1]] = value -def get_field_value(document_data, field_path) -> Any: +def get_field_value(document_data, field_path): if not field_path.parts: raise ValueError("Empty path") @@ -485,7 +481,7 @@ class DocumentExtractor(object): a document. 
""" - def __init__(self, document_data) -> None: + def __init__(self, document_data): self.document_data = document_data self.field_paths = [] self.deleted_fields = [] @@ -530,9 +526,7 @@ def __init__(self, document_data) -> None: self.field_paths.append(field_path) set_field_value(self.set_fields, field_path, value) - def _get_document_iterator( - self, prefix_path: FieldPath - ) -> Generator[Tuple[Any, Any], Any, None]: + def _get_document_iterator(self, prefix_path): return extract_fields(self.document_data, prefix_path) @property @@ -557,12 +551,10 @@ def transform_paths(self): + list(self.minimums) ) - def _get_update_mask(self, allow_empty_mask=False) -> None: + def _get_update_mask(self, allow_empty_mask=False): return None - def get_update_pb( - self, document_path, exists=None, allow_empty_mask=False - ) -> types.write.Write: + def get_update_pb(self, document_path, exists=None, allow_empty_mask=False): if exists is not None: current_document = common.Precondition(exists=exists) else: @@ -578,9 +570,7 @@ def get_update_pb( return update_pb - def get_field_transform_pbs( - self, document_path - ) -> List[types.write.DocumentTransform.FieldTransform]: + def get_field_transform_pbs(self, document_path): def make_array_value(values): value_list = [encode_value(element) for element in values] return document.ArrayValue(values=value_list) @@ -646,7 +636,7 @@ def make_array_value(values): ) return [transform for path, transform in sorted(path_field_transforms)] - def get_transform_pb(self, document_path, exists=None) -> types.write.Write: + def get_transform_pb(self, document_path, exists=None): field_transforms = self.get_field_transform_pbs(document_path) transform_pb = write.Write( transform=write.DocumentTransform( @@ -661,7 +651,7 @@ def get_transform_pb(self, document_path, exists=None) -> types.write.Write: return transform_pb -def pbs_for_create(document_path, document_data) -> List[types.write.Write]: +def pbs_for_create(document_path, document_data): """Make ``Write`` protobufs for ``create()`` methods. Args: @@ -687,7 +677,7 @@ def pbs_for_create(document_path, document_data) -> List[types.write.Write]: return [create_pb] -def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]: +def pbs_for_set_no_merge(document_path, document_data): """Make ``Write`` protobufs for ``set()`` methods. Args: @@ -719,26 +709,26 @@ def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write class DocumentExtractorForMerge(DocumentExtractor): """Break document data up into actual data and transforms.""" - def __init__(self, document_data) -> None: + def __init__(self, document_data): super(DocumentExtractorForMerge, self).__init__(document_data) self.data_merge = [] self.transform_merge = [] self.merge = [] - def _apply_merge_all(self) -> None: + def _apply_merge_all(self): self.data_merge = sorted(self.field_paths + self.deleted_fields) # TODO: other transforms self.transform_merge = self.transform_paths self.merge = sorted(self.data_merge + self.transform_paths) - def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]: + def _construct_merge_paths(self, merge): for merge_field in merge: if isinstance(merge_field, FieldPath): yield merge_field else: yield FieldPath(*parse_field_path(merge_field)) - def _normalize_merge_paths(self, merge) -> list: + def _normalize_merge_paths(self, merge): merge_paths = sorted(self._construct_merge_paths(merge)) # Raise if any merge path is a parent of another. 
Leverage sorting @@ -758,7 +748,7 @@ def _normalize_merge_paths(self, merge) -> list: return merge_paths - def _apply_merge_paths(self, merge) -> None: + def _apply_merge_paths(self, merge): if self.empty_document: raise ValueError("Cannot merge specific fields with empty document.") @@ -820,15 +810,13 @@ def _apply_merge_paths(self, merge) -> None: if path in merged_transform_paths } - def apply_merge(self, merge) -> None: + def apply_merge(self, merge): if merge is True: # merge all fields self._apply_merge_all() else: self._apply_merge_paths(merge) - def _get_update_mask( - self, allow_empty_mask=False - ) -> Optional[types.common.DocumentMask]: + def _get_update_mask(self, allow_empty_mask=False): # Mask uses dotted / quoted paths. mask_paths = [ field_path.to_api_repr() @@ -839,9 +827,7 @@ def _get_update_mask( return common.DocumentMask(field_paths=mask_paths) -def pbs_for_set_with_merge( - document_path, document_data, merge -) -> List[types.write.Write]: +def pbs_for_set_with_merge(document_path, document_data, merge): """Make ``Write`` protobufs for ``set()`` methods. Args: @@ -870,7 +856,7 @@ def pbs_for_set_with_merge( class DocumentExtractorForUpdate(DocumentExtractor): """Break document data up into actual data and transforms.""" - def __init__(self, document_data) -> None: + def __init__(self, document_data): super(DocumentExtractorForUpdate, self).__init__(document_data) self.top_level_paths = sorted( [FieldPath.from_string(key) for key in document_data] @@ -891,12 +877,10 @@ def __init__(self, document_data) -> None: "Cannot update with nest delete: {}".format(field_path) ) - def _get_document_iterator( - self, prefix_path: FieldPath - ) -> Generator[Tuple[Any, Any], Any, None]: + def _get_document_iterator(self, prefix_path): return extract_fields(self.document_data, prefix_path, expand_dots=True) - def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask: + def _get_update_mask(self, allow_empty_mask=False): mask_paths = [] for field_path in self.top_level_paths: if field_path not in self.transform_paths: @@ -905,7 +889,7 @@ def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask: return common.DocumentMask(field_paths=mask_paths) -def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]: +def pbs_for_update(document_path, field_updates, option): """Make ``Write`` protobufs for ``update()`` methods. Args: @@ -938,7 +922,7 @@ def pbs_for_update(document_path, field_updates, option) -> List[types.write.Wri return [update_pb] -def pb_for_delete(document_path, option) -> types.write.Write: +def pb_for_delete(document_path, option): """Make a ``Write`` protobuf for ``delete()`` methods. Args: @@ -965,7 +949,7 @@ class ReadAfterWriteError(Exception): """ -def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]: +def get_transaction_id(transaction, read_operation=True): """Get the transaction ID from a ``Transaction`` object. Args: @@ -995,7 +979,7 @@ def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]: return transaction.id -def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]: +def metadata_with_prefix(prefix, **kw): """Create RPC metadata containing a prefix. 
Args: @@ -1010,7 +994,7 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]: class WriteOption(object): """Option used to assert a condition on a write operation.""" - def modify_write(self, write, no_create_msg=None) -> NoReturn: + def modify_write(self, write, no_create_msg=None): """Modify a ``Write`` protobuf based on the state of this write option. This is a virtual method intended to be implemented by subclasses. @@ -1042,7 +1026,7 @@ class LastUpdateOption(WriteOption): as part of a "write result" protobuf or directly. """ - def __init__(self, last_update_time) -> None: + def __init__(self, last_update_time): self._last_update_time = last_update_time def __eq__(self, other): @@ -1050,7 +1034,7 @@ def __eq__(self, other): return NotImplemented return self._last_update_time == other._last_update_time - def modify_write(self, write, **unused_kwargs) -> None: + def modify_write(self, write, **unused_kwargs): """Modify a ``Write`` protobuf based on the state of this write option. The ``last_update_time`` is added to ``write_pb`` as an "update time" @@ -1079,7 +1063,7 @@ class ExistsOption(WriteOption): should already exist. """ - def __init__(self, exists) -> None: + def __init__(self, exists): self._exists = exists def __eq__(self, other): @@ -1087,7 +1071,7 @@ def __eq__(self, other): return NotImplemented return self._exists == other._exists - def modify_write(self, write, **unused_kwargs) -> None: + def modify_write(self, write, **unused_kwargs): """Modify a ``Write`` protobuf based on the state of this write option. If: @@ -1106,7 +1090,7 @@ def modify_write(self, write, **unused_kwargs) -> None: write._pb.current_document.CopyFrom(current_doc._pb) -def make_retry_timeout_kwargs(retry, timeout) -> dict: +def make_retry_timeout_kwargs(retry, timeout): """Helper fo API methods which take optional 'retry' / 'timeout' args.""" kwargs = {} @@ -1119,9 +1103,7 @@ def make_retry_timeout_kwargs(retry, timeout) -> dict: return kwargs -def build_timestamp( - dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None -) -> Timestamp: +def build_timestamp(dt=None): """Returns the supplied datetime (or "now") as a Timestamp""" return _datetime_to_pb_timestamp( dt or DatetimeWithNanoseconds.now(tz=datetime.timezone.utc) @@ -1129,9 +1111,9 @@ def build_timestamp( def compare_timestamps( - ts1: Union[Timestamp, datetime.datetime], - ts2: Union[Timestamp, datetime.datetime], -) -> int: + ts1, + ts2, +): ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1 ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2 ts1_nanos = ts1.nanos + ts1.seconds * 1e9 @@ -1142,9 +1124,9 @@ def compare_timestamps( def deserialize_bundle( - serialized: Union[str, bytes], - client: "google.cloud.firestore_v1.client.BaseClient", # type: ignore -) -> "google.cloud.firestore_bundle.FirestoreBundle": # type: ignore + serialized, + client, # type: ignore +): # type: ignore """Inverse operation to a `FirestoreBundle` instance's `build()` method. Args: @@ -1172,26 +1154,26 @@ def deserialize_bundle( "documentMetadata": ["document"], "document": ["documentMetadata", "__end__"], } - allowed_next_element_types: List[str] = bundle_state_machine["__initial__"] + allowed_next_element_types = bundle_state_machine["__initial__"] # This must be saved and added last, since we cache it to preserve timestamps, # yet must flush it whenever a new document or query is added to a bundle. 
# The process of deserializing a bundle uses these methods which flush a # cached metadata element, and thus, it must be the last BundleElement # added during deserialization. - metadata_bundle_element: Optional[BundleElement] = None + metadata_bundle_element = None - bundle: Optional[FirestoreBundle] = None - data: Dict + bundle = None + # data: Dict for data in _parse_bundle_elements_data(serialized): # BundleElements are serialized as JSON containing one key outlining # the type, with all further data nested under that key - keys: List[str] = list(data.keys()) + keys = list(data.keys()) if len(keys) != 1: raise ValueError("Expected serialized BundleElement with one top-level key") - key: str = keys[0] + key = keys[0] if key not in allowed_next_element_types: raise ValueError( @@ -1200,9 +1182,9 @@ def deserialize_bundle( ) # Create and add our BundleElement - bundle_element: BundleElement + # bundle_element: BundleElement try: - bundle_element: BundleElement = BundleElement.from_json(json.dumps(data)) # type: ignore + bundle_element = BundleElement.from_json(json.dumps(data)) # type: ignore except AttributeError as e: # Some bad serialization formats cannot be universally deserialized. if e.args[0] == "'dict' object has no attribute 'find'": # pragma: NO COVER @@ -1237,7 +1219,7 @@ def deserialize_bundle( return bundle -def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict, None, None]: # type: ignore +def _parse_bundle_elements_data(serialized): # type: ignore """Reads through a serialized FirestoreBundle and yields JSON chunks that were created via `BundleElement.to_json(bundle_element)`. @@ -1250,18 +1232,18 @@ def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict ValueError: If a chunk of JSON ever starts without following a length prefix. 
""" - _serialized: Iterator[int] = iter( + _serialized = iter( serialized if isinstance(serialized, bytes) else serialized.encode("utf-8") ) - length_prefix: str = "" + length_prefix = "" while True: - byte: Optional[int] = next(_serialized, None) + byte = next(_serialized, None) if byte is None: return None - _str: str = chr(byte) + _str = chr(byte) if _str.isnumeric(): length_prefix += _str else: @@ -1279,12 +1261,10 @@ def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict yield json.loads(_bytes.decode("utf-8")) -def _get_documents_from_bundle( - bundle, *, query_name: Optional[str] = None -) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]: # type: ignore +def _get_documents_from_bundle(bundle, *, query_name=None): # type: ignore from google.cloud.firestore_bundle.bundle import _BundledDocument - bundled_doc: _BundledDocument + # bundled_doc: _BundledDocument for bundled_doc in bundle.documents.values(): if query_name and query_name not in bundled_doc.metadata.queries: continue @@ -1294,8 +1274,8 @@ def _get_documents_from_bundle( def _get_document_from_bundle( bundle, *, - document_id: str, -) -> Optional["google.cloud.firestore.DocumentSnapshot"]: # type: ignore + document_id, +): # type: ignore bundled_doc = bundle.documents.get(document_id) if bundled_doc: return bundled_doc.snapshot diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index 609f82f75..1767a6480 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -40,17 +40,15 @@ class AggregationQuery(BaseAggregationQuery): def __init__( self, nested_query, - ) -> None: + ): super(AggregationQuery, self).__init__(nested_query) def get( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> List[AggregationResult]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. @@ -102,11 +100,9 @@ def _retry_query_after_exception(self, exc, retry, transaction): def stream( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> Union[Generator[List[AggregationResult], Any, None]]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and then returns an iterator which diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index 194016cd2..16b17c408 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -39,17 +39,15 @@ class AsyncAggregationQuery(BaseAggregationQuery): def __init__( self, nested_query, - ) -> None: + ): super(AsyncAggregationQuery, self).__init__(nested_query) async def get( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> List[AggregationResult]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. 
@@ -79,11 +77,9 @@ async def get( async def stream( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> Union[AsyncGenerator[List[AggregationResult], None]]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and then returns an iterator which diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index e33d28f13..de812e139 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -33,14 +33,14 @@ class AsyncWriteBatch(BaseWriteBatch): The client that created this batch. """ - def __init__(self, client) -> None: + def __init__(self, client): super(AsyncWriteBatch, self).__init__(client=client) async def commit( self, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> list: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Commit the changes accumulated in this batch. Args: diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 10e1d2495..7bee6b12a 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -90,7 +90,7 @@ def __init__( database=None, client_info=_CLIENT_INFO, client_options=None, - ) -> None: + ): super(AsyncClient, self).__init__( project=project, credentials=credentials, @@ -135,7 +135,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreAsyncClient) - def collection(self, *collection_path: str) -> AsyncCollectionReference: + def collection(self, *collection_path): """Get a reference to a collection. For a top-level collection: @@ -166,7 +166,7 @@ def collection(self, *collection_path: str) -> AsyncCollectionReference: """ return AsyncCollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id: str) -> AsyncCollectionGroup: + def collection_group(self, collection_id): """ Creates and returns a new AsyncQuery that includes all documents in the database that are contained in a collection or subcollection with the @@ -188,7 +188,7 @@ def collection_group(self, collection_id: str) -> AsyncCollectionGroup: """ return AsyncCollectionGroup(self._get_collection_reference(collection_id)) - def document(self, *document_path: str) -> AsyncDocumentReference: + def document(self, *document_path): """Get a reference to a document in a collection. For a top-level document: @@ -225,12 +225,12 @@ def document(self, *document_path: str) -> AsyncDocumentReference: async def get_all( self, - references: List[AsyncDocumentReference], - field_paths: Iterable[str] = None, + references, + field_paths=None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[DocumentSnapshot, Any]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieve a batch of documents. .. note:: @@ -284,9 +284,9 @@ async def get_all( async def collections( self, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[AsyncCollectionReference, Any]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List top-level collections of the client's database. 
Args: @@ -311,10 +311,10 @@ async def collections( async def recursive_delete( self, - reference: Union[AsyncCollectionReference, AsyncDocumentReference], + reference, *, - bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, + bulk_writer=None, + chunk_size=5000, ): """Deletes documents and their subcollections, regardless of collection name. @@ -347,28 +347,28 @@ async def recursive_delete( async def _recursive_delete( self, - reference: Union[AsyncCollectionReference, AsyncDocumentReference], - bulk_writer: "BulkWriter", + reference, + bulk_writer, *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, - ) -> int: + chunk_size=5000, + depth=0, + ): """Recursion helper for `recursive_delete.""" - num_deleted: int = 0 + num_deleted = 0 if isinstance(reference, AsyncCollectionReference): - chunk: List[DocumentSnapshot] + # chunk: List[DocumentSnapshot] async for chunk in reference.recursive().select( [FieldPath.document_id()] )._chunkify(chunk_size): - doc_snap: DocumentSnapshot + # doc_snap: DocumentSnapshot for doc_snap in chunk: num_deleted += 1 bulk_writer.delete(doc_snap.reference) elif isinstance(reference, AsyncDocumentReference): - col_ref: AsyncCollectionReference + # col_ref: AsyncCollectionReference async for col_ref in reference.collections(): num_deleted += await self._recursive_delete( col_ref, @@ -389,7 +389,7 @@ async def _recursive_delete( return num_deleted - def batch(self) -> AsyncWriteBatch: + def batch(self): """Get a batch instance from this client. Returns: @@ -399,7 +399,7 @@ def batch(self) -> AsyncWriteBatch: """ return AsyncWriteBatch(self) - def transaction(self, **kwargs) -> AsyncTransaction: + def transaction(self, **kwargs): """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 293a1e0f5..259cf4023 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -58,10 +58,10 @@ class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): super(AsyncCollectionReference, self).__init__(*path, **kwargs) - def _query(self) -> async_query.AsyncQuery: + def _query(self): """Query factory. Returns: @@ -69,7 +69,7 @@ def _query(self) -> async_query.AsyncQuery: """ return async_query.AsyncQuery(self) - def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery: + def _aggregation_query(self): """AsyncAggregationQuery factory. Returns: @@ -77,17 +77,17 @@ def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery: """ return async_aggregation.AsyncAggregationQuery(self._query()) - async def _chunkify(self, chunk_size: int): + async def _chunkify(self, chunk_size): async for page in self._query()._chunkify(chunk_size): yield page async def add( self, - document_data: dict, - document_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Tuple[Any, Any]: + document_data, + document_id=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Create a document in the Firestore database with the provided data. 
Args: @@ -124,9 +124,7 @@ async def add( write_result = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref - def document( - self, document_id: str = None - ) -> async_document.AsyncDocumentReference: + def document(self, document_id=None): """Create a sub-document underneath the current collection. Args: @@ -143,10 +141,10 @@ def document( async def list_documents( self, - page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[DocumentReference, None]: + page_size=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List all subdocuments of the current collection. Args: @@ -176,10 +174,10 @@ async def list_documents( async def get( self, - transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> list: + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -207,10 +205,10 @@ async def get( async def stream( self, - transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncIterator[async_document.DocumentSnapshot]: + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 47cce42af..d8ed03faf 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -59,15 +59,15 @@ class AsyncDocumentReference(BaseDocumentReference): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): super(AsyncDocumentReference, self).__init__(*path, **kwargs) async def create( self, - document_data: dict, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + document_data, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Create the current document in the Firestore database. Args: @@ -93,11 +93,11 @@ async def create( async def set( self, - document_data: dict, - merge: bool = False, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + document_data, + merge=False, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Replace the current document in the Firestore database. A write ``option`` can be specified to indicate preconditions of @@ -133,11 +133,11 @@ async def set( async def update( self, - field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + field_updates, + option=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -291,10 +291,10 @@ async def update( async def delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Timestamp: + option=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Delete the current document in the Firestore database. 
Args: @@ -325,11 +325,11 @@ async def delete( async def get( self, - field_paths: Iterable[str] = None, + field_paths=None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> DocumentSnapshot: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for @@ -394,10 +394,10 @@ async def get( async def collections( self, - page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator: + page_size=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List subcollections of the current document. Args: diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index d03ab72b8..3a15647b6 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -116,7 +116,7 @@ def __init__( end_at=None, all_descendants=False, recursive=False, - ) -> None: + ): super(AsyncQuery, self).__init__( parent=parent, projection=projection, @@ -131,18 +131,16 @@ def __init__( recursive=recursive, ) - async def _chunkify( - self, chunk_size: int - ) -> AsyncGenerator[List[DocumentSnapshot], None]: - max_to_return: Optional[int] = self._limit - num_returned: int = 0 - original: AsyncQuery = self._copy() - last_document: Optional[DocumentSnapshot] = None + async def _chunkify(self, chunk_size): + max_to_return = self._limit + num_returned = 0 + original = self._copy() + last_document = None while True: # Optionally trim the `chunk_size` down to honor a previously # applied limit as set by `self.limit()` - _chunk_size: int = original._resolve_chunk_size(num_returned, chunk_size) + _chunk_size = original._resolve_chunk_size(num_returned, chunk_size) # Apply the optionally pruned limit and the cursor, if we are past # the first page. @@ -171,10 +169,10 @@ async def _chunkify( async def get( self, - transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> list: + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents @@ -217,9 +215,7 @@ async def get( return result - def count( - self, alias: str | None = None - ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + def count(self, alias=None): """Adds a count over the nested query. Args: @@ -232,9 +228,7 @@ def count( """ return AsyncAggregationQuery(self).count(alias=alias) - def sum( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + def sum(self, field_ref, alias=None): """Adds a sum over the nested query. Args: @@ -248,9 +242,7 @@ def sum( """ return AsyncAggregationQuery(self).sum(field_ref, alias=alias) - def avg( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.async_aggregation.AsyncAggregationQuery"]: + def avg(self, field_ref, alias=None): """Adds an avg over the nested query. Args: @@ -267,9 +259,9 @@ def avg( async def stream( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[async_document.DocumentSnapshot, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in the collection that match this query. 
This sends a ``RunQuery`` RPC and then returns an iterator which @@ -325,9 +317,7 @@ async def stream( yield snapshot @staticmethod - def _get_collection_reference_class() -> ( - Type["firestore_v1.async_collection.AsyncCollectionReference"] - ): + def _get_collection_reference_class(): from google.cloud.firestore_v1.async_collection import AsyncCollectionReference return AsyncCollectionReference @@ -358,7 +348,7 @@ def __init__( end_at=None, all_descendants=True, recursive=False, - ) -> None: + ): super(AsyncCollectionGroup, self).__init__( parent=parent, projection=projection, @@ -380,9 +370,9 @@ def _get_query_class(): async def get_partitions( self, partition_count, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[QueryPartition, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Partition a query for parallelization. Partitions a query by returning partition cursors that can be used to run the diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index b504bebad..0e9817717 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -63,11 +63,11 @@ class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): :data:`False`. """ - def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): super(AsyncTransaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs): """Add `Write`` protobufs to this transaction. Args: @@ -82,7 +82,7 @@ def _add_write_pbs(self, write_pbs: list) -> None: super(AsyncTransaction, self)._add_write_pbs(write_pbs) - async def _begin(self, retry_id: bytes = None) -> None: + async def _begin(self, retry_id=None): """Begin the transaction. Args: @@ -105,7 +105,7 @@ async def _begin(self, retry_id: bytes = None) -> None: ) self._id = transaction_response.transaction - async def _rollback(self) -> None: + async def _rollback(self): """Roll back the transaction. Raises: @@ -128,7 +128,7 @@ async def _rollback(self) -> None: # clean up, even if rollback fails self._clean_up() - async def _commit(self) -> list: + async def _commit(self): """Transactionally commit the changes accumulated. Returns: @@ -152,10 +152,10 @@ async def _commit(self) -> list: async def get_all( self, - references: list, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[DocumentSnapshot, Any]: + references, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieves multiple documents from Firestore. Args: @@ -176,9 +176,9 @@ async def get_all( async def get( self, ref_or_query, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> AsyncGenerator[DocumentSnapshot, Any]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """ Retrieve a document or a query result from the database. @@ -217,12 +217,10 @@ class _AsyncTransactional(_BaseTransactional): A coroutine that should be run (and retried) in a transaction. 
""" - def __init__(self, to_wrap) -> None: + def __init__(self, to_wrap): super(_AsyncTransactional, self).__init__(to_wrap) - async def _pre_commit( - self, transaction: AsyncTransaction, *args, **kwargs - ) -> Coroutine: + async def _pre_commit(self, transaction, *args, **kwargs): """Begin transaction and call the wrapped coroutine. Args: @@ -301,9 +299,7 @@ async def __call__(self, transaction, *args, **kwargs): raise -def async_transactional( - to_wrap: Callable[[AsyncTransaction], Any] -) -> _AsyncTransactional: +def async_transactional(to_wrap): """Decorate a callable so that it runs in a transaction. Args: @@ -319,9 +315,7 @@ def async_transactional( # TODO(crwilcox): this was 'coroutine' from pytype merge-pyi... -async def _commit_with_retry( - client: Client, write_pbs: list, transaction_id: bytes -) -> types.CommitResponse: +async def _commit_with_retry(client, write_pbs, transaction_id): """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -364,9 +358,7 @@ async def _commit_with_retry( current_sleep = await _sleep(current_sleep) -async def _sleep( - current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER -) -> float: +async def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): """Sleep and produce a new sleep time. .. _Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index d6097c136..3da011adb 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -50,7 +50,7 @@ class AggregationResult(object): :param value: The resulting read_time """ - def __init__(self, alias: str, value: int, read_time=None): + def __init__(self, alias, value, read_time=None): self.alias = alias self.value = value self.read_time = read_time @@ -60,7 +60,7 @@ def __repr__(self): class BaseAggregation(ABC): - def __init__(self, alias: str | None = None): + def __init__(self, alias=None): self.alias = alias @abc.abstractmethod @@ -69,7 +69,7 @@ def _to_protobuf(self): class CountAggregation(BaseAggregation): - def __init__(self, alias: str | None = None): + def __init__(self, alias=None): super(CountAggregation, self).__init__(alias=alias) def _to_protobuf(self): @@ -81,7 +81,7 @@ def _to_protobuf(self): class SumAggregation(BaseAggregation): - def __init__(self, field_ref: str | FieldPath, alias: str | None = None): + def __init__(self, field_ref, alias=None): if isinstance(field_ref, FieldPath): # convert field path to string field_ref = field_ref.to_api_repr() @@ -98,7 +98,7 @@ def _to_protobuf(self): class AvgAggregation(BaseAggregation): - def __init__(self, field_ref: str | FieldPath, alias: str | None = None): + def __init__(self, field_ref, alias=None): if isinstance(field_ref, FieldPath): # convert field path to string field_ref = field_ref.to_api_repr() @@ -115,8 +115,8 @@ def _to_protobuf(self): def _query_response_to_result( - response_pb: RunAggregationQueryResponse, -) -> List[AggregationResult]: + response_pb, +): results = [ AggregationResult( alias=key, @@ -133,17 +133,17 @@ def _query_response_to_result( class BaseAggregationQuery(ABC): """Represents an aggregation query to the Firestore API.""" - def __init__(self, nested_query, alias: str | None = None) -> None: + def __init__(self, nested_query, alias=None): self._nested_query = nested_query self._alias = alias self._collection_ref = 
nested_query._parent - self._aggregations: List[BaseAggregation] = [] + self._aggregations = [] @property def _client(self): return self._collection_ref._client - def count(self, alias: str | None = None): + def count(self, alias=None): """ Adds a count over the nested query """ @@ -151,7 +151,7 @@ def count(self, alias: str | None = None): self._aggregations.append(count_aggregation) return self - def sum(self, field_ref: str | FieldPath, alias: str | None = None): + def sum(self, field_ref, alias=None): """ Adds a sum over the nested query """ @@ -159,7 +159,7 @@ def sum(self, field_ref: str | FieldPath, alias: str | None = None): self._aggregations.append(sum_aggregation) return self - def avg(self, field_ref: str | FieldPath, alias: str | None = None): + def avg(self, field_ref, alias=None): """ Adds an avg over the nested query """ @@ -167,7 +167,7 @@ def avg(self, field_ref: str | FieldPath, alias: str | None = None): self._aggregations.append(avg_aggregation) return self - def add_aggregation(self, aggregation: BaseAggregation) -> None: + def add_aggregation(self, aggregation): """ Adds an aggregation operation to the nested query @@ -176,7 +176,7 @@ def add_aggregation(self, aggregation: BaseAggregation) -> None: """ self._aggregations.append(aggregation) - def add_aggregations(self, aggregations: List[BaseAggregation]) -> None: + def add_aggregations(self, aggregations): """ Adds a list of aggregations to the nested query @@ -185,7 +185,7 @@ def add_aggregations(self, aggregations: List[BaseAggregation]) -> None: """ self._aggregations.extend(aggregations) - def _to_protobuf(self) -> StructuredAggregationQuery: + def _to_protobuf(self): pb = StructuredAggregationQuery() pb.structured_query = self._nested_query._to_protobuf() @@ -197,9 +197,9 @@ def _to_protobuf(self) -> StructuredAggregationQuery: def _prep_stream( self, transaction=None, - retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None, - timeout: float | None = None, - ) -> Tuple[dict, dict]: + retry=None, + timeout=None, + ): parent_path, expected_prefix = self._collection_ref._parent_info() request = { "parent": parent_path, @@ -214,11 +214,9 @@ def _prep_stream( def get( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> List[AggregationResult] | Coroutine[Any, Any, List[AggregationResult]]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. @@ -244,13 +242,8 @@ def get( def stream( self, transaction=None, - retry: Union[ - retries.Retry, None, gapic_v1.method._MethodDefault - ] = gapic_v1.method.DEFAULT, - timeout: float | None = None, - ) -> ( - Generator[List[AggregationResult], Any, None] - | AsyncGenerator[List[AggregationResult], None] + retry=gapic_v1.method.DEFAULT, + timeout=None, ): """Runs the aggregation query. diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index ca3a66c89..655c90057 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -35,20 +35,20 @@ class BaseBatch(metaclass=abc.ABCMeta): The client that created this batch. 
""" - def __init__(self, client) -> None: + def __init__(self, client): self._client = client self._write_pbs = [] - self._document_references: Dict[str, BaseDocumentReference] = {} + self._document_references = {} self.write_results = None self.commit_time = None def __len__(self): return len(self._document_references) - def __contains__(self, reference: BaseDocumentReference): + def __contains__(self, reference): return reference._document_path in self._document_references - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs): """Add `Write`` protobufs to this transaction. This method intended to be over-ridden by subclasses. @@ -65,7 +65,7 @@ def commit(self): write depend on the implementing class.""" raise NotImplementedError() - def create(self, reference: BaseDocumentReference, document_data: dict) -> None: + def create(self, reference, document_data): """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -83,10 +83,10 @@ def create(self, reference: BaseDocumentReference, document_data: dict) -> None: def set( self, - reference: BaseDocumentReference, - document_data: dict, - merge: Union[bool, list] = False, - ) -> None: + reference, + document_data, + merge=False, + ): """Add a "change" to replace a document. See @@ -117,10 +117,10 @@ def set( def update( self, - reference: BaseDocumentReference, - field_updates: dict, - option: _helpers.WriteOption = None, - ) -> None: + reference, + field_updates, + option=None, + ): """Add a "change" to update a document. See @@ -144,9 +144,7 @@ def update( self._document_references[reference._document_path] = reference self._add_write_pbs(write_pbs) - def delete( - self, reference: BaseDocumentReference, option: _helpers.WriteOption = None - ) -> None: + def delete(self, reference, option=None): """Add a "change" to delete a document. See @@ -170,7 +168,7 @@ class BaseWriteBatch(BaseBatch): """Base class for a/sync implementations of the `commit` RPC. `commit` is useful for lower volumes or when the order of write operations is important.""" - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit(self, retry, timeout): """Shared setup for async/sync :meth:`commit`.""" request = { "database": self._client._database_string, diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 345e06142..01da83c07 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -66,13 +66,13 @@ _BAD_OPTION_ERR = ( "Exactly one of ``last_update_time`` or ``exists`` " "must be provided." ) -_BAD_DOC_TEMPLATE: str = ( +_BAD_DOC_TEMPLATE = ( "Document {!r} appeared in response but was not present among references" ) -_ACTIVE_TXN: str = "There is already an active transaction." -_INACTIVE_TXN: str = "There is no active transaction." -_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__) -_FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST" +_ACTIVE_TXN = "There is already an active transaction." +_INACTIVE_TXN = "There is no active transaction." 
+_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) +_FIRESTORE_EMULATOR_HOST = "FIRESTORE_EMULATOR_HOST" class BaseClient(ClientWithProject): @@ -120,7 +120,7 @@ def __init__( database=None, client_info=_CLIENT_INFO, client_options=None, - ) -> None: + ): database = database or DEFAULT_DATABASE # NOTE: This API has no use for the _http argument, but sending it # will have no impact since the _http() @property only lazily @@ -149,7 +149,7 @@ def __init__( self._database = database - def _firestore_api_helper(self, transport, client_class, client_module) -> Any: + def _firestore_api_helper(self, transport, client_class, client_module): """Lazy-loading getter GAPIC Firestore API. Returns: The GAPIC client with the credentials of the current client. @@ -202,7 +202,7 @@ def _emulator_channel(self, transport): else: return grpc.insecure_channel(self._emulator_host, options=options) - def _target_helper(self, client_class) -> str: + def _target_helper(self, client_class): """Return the target (where the API is). Eg. "firestore.googleapis.com" @@ -262,15 +262,13 @@ def _rpc_metadata(self): return self._rpc_metadata_internal - def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]: + def collection(self, *collection_path): raise NotImplementedError - def collection_group(self, collection_id: str) -> BaseQuery: + def collection_group(self, collection_id): raise NotImplementedError - def _get_collection_reference( - self, collection_id: str - ) -> BaseCollectionReference[BaseQuery]: + def _get_collection_reference(self, collection_id): """Checks validity of collection_id and then uses subclasses collection implementation. Args: @@ -291,10 +289,10 @@ def _get_collection_reference( return self.collection(collection_id) - def document(self, *document_path) -> BaseDocumentReference: + def document(self, *document_path): raise NotImplementedError - def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter: + def bulk_writer(self, options=None): """Get a BulkWriter instance from this client. Args: @@ -309,7 +307,7 @@ def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter """ return BulkWriter(client=self, options=options) - def _document_path_helper(self, *document_path) -> List[str]: + def _document_path_helper(self, *document_path): """Standardize the format of path to tuple of path segments and strip the database string from path if present. Args: @@ -327,13 +325,13 @@ def _document_path_helper(self, *document_path) -> List[str]: def recursive_delete( self, - reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference], - bulk_writer: Optional["BulkWriter"] = None, # type: ignore - ) -> int: + reference, + bulk_writer=None, # type: ignore + ): raise NotImplementedError @staticmethod - def field_path(*field_names: str) -> str: + def field_path(*field_names): """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -364,7 +362,7 @@ def field_path(*field_names: str) -> str: @staticmethod def write_option( **kwargs, - ) -> Union[_helpers.ExistsOption, _helpers.LastUpdateOption]: + ): """Create a write option for write operations. 
Write operations include :meth:`~google.cloud.DocumentReference.set`, @@ -414,12 +412,12 @@ def write_option( def _prep_get_all( self, - references: list, - field_paths: Iterable[str] = None, - transaction: BaseTransaction = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[dict, dict, dict]: + references, + field_paths=None, + transaction=None, + retry=None, + timeout=None, + ): """Shared setup for async/sync :meth:`get_all`.""" document_paths, reference_map = _reference_info(references) mask = _get_doc_mask(field_paths) @@ -435,21 +433,19 @@ def _prep_get_all( def get_all( self, - references: list, - field_paths: Iterable[str] = None, - transaction: BaseTransaction = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[ - AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] - ]: + references, + field_paths=None, + transaction=None, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_collections( self, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[dict, dict]: + retry=None, + timeout=None, + ): """Shared setup for async/sync :meth:`collections`.""" request = {"parent": "{}/documents".format(self._database_string)} kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -458,22 +454,19 @@ def _prep_collections( def collections( self, - retry: retries.Retry = None, - timeout: float = None, - ) -> Union[ - AsyncGenerator[BaseCollectionReference[BaseQuery], Any], - Generator[BaseCollectionReference[BaseQuery], Any, Any], - ]: + retry=None, + timeout=None, + ): raise NotImplementedError - def batch(self) -> BaseWriteBatch: + def batch(self): raise NotImplementedError - def transaction(self, **kwargs) -> BaseTransaction: + def transaction(self, **kwargs): raise NotImplementedError -def _reference_info(references: list) -> Tuple[list, dict]: +def _reference_info(references): """Get information about document references. Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`. @@ -500,7 +493,7 @@ def _reference_info(references: list) -> Tuple[list, dict]: return document_paths, reference_map -def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference: +def _get_reference(document_path, reference_map): """Get a document reference from a dictionary. This just wraps a simple dictionary look-up with a helpful error that is @@ -527,10 +520,10 @@ def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentRefer def _parse_batch_get( - get_doc_response: types.BatchGetDocumentsResponse, - reference_map: dict, - client: BaseClient, -) -> DocumentSnapshot: + get_doc_response, + reference_map, + client, +): """Parse a `BatchGetDocumentsResponse` protobuf. Args: @@ -580,7 +573,7 @@ def _parse_batch_get( return snapshot -def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentMask]: +def _get_doc_mask(field_paths): """Get a document mask if field paths are provided. Args: @@ -598,7 +591,7 @@ def _get_doc_mask(field_paths: Iterable[str]) -> Optional[types.common.DocumentM return types.DocumentMask(field_paths=field_paths) -def _path_helper(path: tuple) -> Tuple[str]: +def _path_helper(path): """Standardize path into a tuple of path segments. 
Args: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index a9d644c4b..437cfb49b 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -76,7 +76,7 @@ class BaseCollectionReference(Generic[QueryType]): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): _helpers.verify_path(path, is_collection=True) self._path = path self._client = kwargs.pop("client", None) @@ -114,13 +114,13 @@ def parent(self): parent_path = self._path[:-1] return self._client.document(*parent_path) - def _query(self) -> QueryType: + def _query(self): raise NotImplementedError - def _aggregation_query(self) -> BaseAggregationQuery: + def _aggregation_query(self): raise NotImplementedError - def document(self, document_id: Optional[str] = None) -> DocumentReference: + def document(self, document_id=None): """Create a sub-document underneath the current collection. Args: @@ -142,7 +142,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: child_path = self._path + (document_id,) if self._path[0] else (document_id,) return self._client.document(*child_path) - def _parent_info(self) -> Tuple[Any, str]: + def _parent_info(self): """Get fully-qualified parent path and prefix for this collection. Returns: @@ -166,11 +166,11 @@ def _parent_info(self) -> Tuple[Any, str]: def _prep_add( self, - document_data: dict, - document_id: Optional[str] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Tuple[DocumentReference, dict]: + document_data, + document_id=None, + retry=None, + timeout=None, + ): """Shared setup for async / sync :method:`add`""" if document_id is None: document_id = _auto_id() @@ -182,19 +182,19 @@ def _prep_add( def add( self, - document_data: dict, - document_id: Optional[str] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: + document_data, + document_id=None, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_list_documents( self, - page_size: Optional[int] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Tuple[dict, dict]: + page_size=None, + retry=None, + timeout=None, + ): """Shared setup for async / sync :method:`list_documents`""" parent, _ = self._parent_info() request = { @@ -213,18 +213,16 @@ def _prep_list_documents( def list_documents( self, - page_size: Optional[int] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Union[ - Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] - ]: + page_size=None, + retry=None, + timeout=None, + ): raise NotImplementedError - def recursive(self) -> QueryType: + def recursive(self): return self._query().recursive() - def select(self, field_paths: Iterable[str]) -> QueryType: + def select(self, field_paths): """Create a "select" query with this collection as parent. See @@ -245,12 +243,12 @@ def select(self, field_paths: Iterable[str]) -> QueryType: def where( self, - field_path: Optional[str] = None, - op_string: Optional[str] = None, + field_path=None, + op_string=None, value=None, *, filter=None, - ) -> QueryType: + ): """Create a "where" query with this collection as parent. 
See @@ -296,7 +294,7 @@ def where( else: return query.where(filter=filter) - def order_by(self, field_path: str, **kwargs) -> QueryType: + def order_by(self, field_path, **kwargs): """Create an "order by" query with this collection as parent. See @@ -318,7 +316,7 @@ def order_by(self, field_path: str, **kwargs) -> QueryType: query = self._query() return query.order_by(field_path, **kwargs) - def limit(self, count: int) -> QueryType: + def limit(self, count): """Create a limited query with this collection as parent. .. note:: @@ -340,7 +338,7 @@ def limit(self, count: int) -> QueryType: query = self._query() return query.limit(count) - def limit_to_last(self, count: int): + def limit_to_last(self, count): """Create a limited to last query with this collection as parent. .. note:: @@ -361,7 +359,7 @@ def limit_to_last(self, count: int): query = self._query() return query.limit_to_last(count) - def offset(self, num_to_skip: int) -> QueryType: + def offset(self, num_to_skip): """Skip to an offset in a query with this collection as parent. See @@ -379,9 +377,7 @@ def offset(self, num_to_skip: int) -> QueryType: query = self._query() return query.offset(num_to_skip) - def start_at( - self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> QueryType: + def start_at(self, document_fields): """Start query at a cursor with this collection as parent. See @@ -402,9 +398,7 @@ def start_at( query = self._query() return query.start_at(document_fields) - def start_after( - self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> QueryType: + def start_after(self, document_fields): """Start query after a cursor with this collection as parent. See @@ -425,9 +419,7 @@ def start_after( query = self._query() return query.start_after(document_fields) - def end_before( - self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> QueryType: + def end_before(self, document_fields): """End query before a cursor with this collection as parent. See @@ -448,9 +440,7 @@ def end_before( query = self._query() return query.end_before(document_fields) - def end_at( - self, document_fields: Union[DocumentSnapshot, dict, list, tuple] - ) -> QueryType: + def end_at(self, document_fields): """End query at a cursor with this collection as parent. 
See @@ -473,9 +463,9 @@ def end_at( def _prep_get_or_stream( self, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Tuple[Any, dict]: + retry=None, + timeout=None, + ): """Shared setup for async / sync :meth:`get` / :meth:`stream`""" query = self._query() kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -484,23 +474,21 @@ def _prep_get_or_stream( def get( self, - transaction: Optional[Transaction] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Union[ - Generator[DocumentSnapshot, Any, Any], AsyncGenerator[DocumentSnapshot, Any] - ]: + transaction=None, + retry=None, + timeout=None, + ): raise NotImplementedError def stream( self, - transaction: Optional[Transaction] = None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Union[Iterator[DocumentSnapshot], AsyncIterator[DocumentSnapshot]]: + transaction=None, + retry=None, + timeout=None, + ): raise NotImplementedError - def on_snapshot(self, callback) -> NoReturn: + def on_snapshot(self, callback): raise NotImplementedError def count(self, alias=None): @@ -512,7 +500,7 @@ def count(self, alias=None): """ return self._aggregation_query().count(alias=alias) - def sum(self, field_ref: str | FieldPath, alias=None): + def sum(self, field_ref, alias=None): """ Adds a sum over the nested query. @@ -526,7 +514,7 @@ def sum(self, field_ref: str | FieldPath, alias=None): """ return self._aggregation_query().sum(field_ref, alias=alias) - def avg(self, field_ref: str | FieldPath, alias=None): + def avg(self, field_ref, alias=None): """ Adds an avg over the nested query. @@ -540,7 +528,7 @@ def avg(self, field_ref: str | FieldPath, alias=None): return self._aggregation_query().avg(field_ref, alias=alias) -def _auto_id() -> str: +def _auto_id(): """Generate a "random" automatically generated ID. Returns: @@ -551,7 +539,7 @@ def _auto_id() -> str: return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) -def _item_to_document_ref(collection_reference, item) -> DocumentReference: +def _item_to_document_ref(collection_reference, item): """Convert Document resource to document ref. Args: diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 3997b5b4d..a8240cc53 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -56,7 +56,7 @@ class BaseDocumentReference(object): _document_path_internal = None - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): _helpers.verify_path(path, is_collection=False) self._path = path self._client = kwargs.pop("client", None) @@ -172,7 +172,7 @@ def parent(self): parent_path = self._path[:-1] return self._client.collection(*parent_path) - def collection(self, collection_id: str) -> Any: + def collection(self, collection_id): """Create a sub-collection underneath the current document. 
Args: @@ -188,10 +188,10 @@ def collection(self, collection_id: str) -> Any: def _prep_create( self, - document_data: dict, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[Any, dict]: + document_data, + retry=None, + timeout=None, + ): batch = self._client.batch() batch.create(self, document_data) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -200,19 +200,19 @@ def _prep_create( def create( self, - document_data: dict, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + document_data, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_set( self, - document_data: dict, - merge: bool = False, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[Any, dict]: + document_data, + merge=False, + retry=None, + timeout=None, + ): batch = self._client.batch() batch.set(self, document_data, merge=merge) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -221,20 +221,20 @@ def _prep_set( def set( self, - document_data: dict, - merge: bool = False, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + document_data, + merge=False, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_update( self, - field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[Any, dict]: + field_updates, + option=None, + retry=None, + timeout=None, + ): batch = self._client.batch() batch.update(self, field_updates, option=option) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -243,19 +243,19 @@ def _prep_update( def update( self, - field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + field_updates, + option=None, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[dict, dict]: + option=None, + retry=None, + timeout=None, + ): """Shared setup for async/sync :meth:`delete`.""" write_pb = _helpers.pb_for_delete(self._document_path, option) request = { @@ -269,19 +269,19 @@ def _prep_delete( def delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + option=None, + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_batch_get( self, - field_paths: Iterable[str] = None, + field_paths=None, transaction=None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[dict, dict]: + retry=None, + timeout=None, + ): """Shared setup for async/sync :meth:`get`.""" if isinstance(field_paths, str): raise ValueError("'field_paths' must be a sequence of paths, not a string.") @@ -303,19 +303,19 @@ def _prep_batch_get( def get( self, - field_paths: Iterable[str] = None, + field_paths=None, transaction=None, - retry: retries.Retry = None, - timeout: float = None, - ) -> "DocumentSnapshot": + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_collections( self, - page_size: int = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> Tuple[dict, dict]: + page_size=None, + retry=None, + timeout=None, + ): """Shared setup for async/sync :meth:`collections`.""" request = {"parent": self._document_path, "page_size": page_size} kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) @@ -324,13 +324,13 @@ def _prep_collections( def collections( self, - page_size: 
int = None, - retry: retries.Retry = None, - timeout: float = None, - ) -> None: + page_size=None, + retry=None, + timeout=None, + ): raise NotImplementedError - def on_snapshot(self, callback) -> None: + def on_snapshot(self, callback): raise NotImplementedError @@ -362,9 +362,7 @@ class DocumentSnapshot(object): The time that this document was last updated. """ - def __init__( - self, reference, data, exists, read_time, create_time, update_time - ) -> None: + def __init__(self, reference, data, exists, read_time, create_time, update_time): self._reference = reference # We want immutable data, so callers can't modify this value # out from under us. @@ -423,7 +421,7 @@ def reference(self): """ return self._reference - def get(self, field_path: str) -> Any: + def get(self, field_path): """Get a value from the snapshot data. If the data is nested, for example: @@ -487,7 +485,7 @@ def get(self, field_path: str) -> Any: nested_data = field_path_module.get_nested_value(field_path, self._data) return copy.deepcopy(nested_data) - def to_dict(self) -> Union[Dict[str, Any], None]: + def to_dict(self): """Retrieve the data contained in this snapshot. A copy is returned since the data may contain mutable values, @@ -502,11 +500,11 @@ def to_dict(self) -> Union[Dict[str, Any], None]: return None return copy.deepcopy(self._data) - def _to_protobuf(self) -> Optional[Document]: + def _to_protobuf(self): return _helpers.document_snapshot_to_protobuf(self) -def _get_document_path(client, path: Tuple[str]) -> str: +def _get_document_path(client, path): """Convert a path tuple into a full path string. Of the form: @@ -527,7 +525,7 @@ def _get_document_path(client, path: Tuple[str]) -> str: return _helpers.DOCUMENT_PATH_DELIMITER.join(parts) -def _consume_single_get(response_iterator) -> firestore.BatchGetDocumentsResponse: +def _consume_single_get(response_iterator): """Consume a gRPC stream that should contain a single response. The stream will correspond to a ``BatchGetDocuments`` request made @@ -558,7 +556,7 @@ def _consume_single_get(response_iterator) -> firestore.BatchGetDocumentsRespons return all_responses[0] -def _first_write_result(write_results: list) -> write.WriteResult: +def _first_write_result(write_results): """Get first write result from list. 
For cases where ``len(write_results) > 1``, this assumes the writes diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 81a220ef6..87af46e9f 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -58,17 +58,17 @@ if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.field_path import FieldPath -_BAD_DIR_STRING: str -_BAD_OP_NAN_NULL: str -_BAD_OP_STRING: str -_COMPARISON_OPERATORS: Dict[str, Any] -_EQ_OP: str -_INVALID_CURSOR_TRANSFORM: str -_INVALID_WHERE_TRANSFORM: str -_MISMATCH_CURSOR_W_ORDER_BY: str -_MISSING_ORDER_BY: str -_NO_ORDERS_FOR_CURSOR: str -_operator_enum: Any +# _BAD_DIR_STRING: str +# _BAD_OP_NAN_NULL: str +# _BAD_OP_STRING: str +# _COMPARISON_OPERATORS: Dict[str, Any] +# _EQ_OP: str +# _INVALID_CURSOR_TRANSFORM: str +# _INVALID_WHERE_TRANSFORM: str +# _MISMATCH_CURSOR_W_ORDER_BY: str +# _MISSING_ORDER_BY: str +# _NO_ORDERS_FOR_CURSOR: str +# _operator_enum: Any _EQ_OP = "==" @@ -297,7 +297,7 @@ def __init__( end_at=None, all_descendants=False, recursive=False, - ) -> None: + ): self._parent = parent self._projection = projection self._field_filters = field_filters @@ -336,7 +336,7 @@ def _client(self): """ return self._parent._client - def select(self: QueryType, field_paths: Iterable[str]) -> QueryType: + def select(self, field_paths): """Project documents matching query to a limited set of fields. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -371,19 +371,19 @@ def select(self: QueryType, field_paths: Iterable[str]) -> QueryType: return self._copy(projection=new_projection) def _copy( - self: QueryType, + self, *, - projection: Optional[query.StructuredQuery.Projection] = _not_passed, - field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed, - orders: Optional[Tuple[query.StructuredQuery.Order]] = _not_passed, - limit: Optional[int] = _not_passed, - limit_to_last: Optional[bool] = _not_passed, - offset: Optional[int] = _not_passed, - start_at: Optional[Tuple[dict, bool]] = _not_passed, - end_at: Optional[Tuple[dict, bool]] = _not_passed, - all_descendants: Optional[bool] = _not_passed, - recursive: Optional[bool] = _not_passed, - ) -> QueryType: + projection=_not_passed, + field_filters=_not_passed, + orders=_not_passed, + limit=_not_passed, + limit_to_last=_not_passed, + offset=_not_passed, + start_at=_not_passed, + end_at=_not_passed, + all_descendants=_not_passed, + recursive=_not_passed, + ): return self.__class__( self._parent, projection=self._evaluate_param(projection, self._projection), @@ -406,13 +406,13 @@ def _evaluate_param(self, value, fallback_value): return value if value is not _not_passed else fallback_value def where( - self: QueryType, - field_path: Optional[str] = None, - op_string: Optional[str] = None, + self, + field_path=None, + op_string=None, value=None, *, filter=None, - ) -> QueryType: + ): """Filter the query on a field. 
See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -502,16 +502,14 @@ def where( return self._copy(field_filters=new_filters) @staticmethod - def _make_order(field_path, direction) -> StructuredQuery.Order: + def _make_order(field_path, direction): """Helper for :meth:`order_by`.""" return query.StructuredQuery.Order( field=query.StructuredQuery.FieldReference(field_path=field_path), direction=_enum_from_direction(direction), ) - def order_by( - self: QueryType, field_path: str, direction: str = ASCENDING - ) -> QueryType: + def order_by(self, field_path, direction=ASCENDING): """Modify the query to add an order clause on a specific field. See :meth:`~google.cloud.firestore_v1.client.Client.field_path` for @@ -545,7 +543,7 @@ def order_by( new_orders = self._orders + (order_pb,) return self._copy(orders=new_orders) - def limit(self: QueryType, count: int) -> QueryType: + def limit(self, count): """Limit a query to return at most `count` matching results. If the current query already has a `limit` set, this will override it. @@ -564,7 +562,7 @@ def limit(self: QueryType, count: int) -> QueryType: """ return self._copy(limit=count, limit_to_last=False) - def limit_to_last(self: QueryType, count: int) -> QueryType: + def limit_to_last(self, count): """Limit a query to return the last `count` matching results. If the current query already has a `limit_to_last` set, this will override it. @@ -583,13 +581,13 @@ def limit_to_last(self: QueryType, count: int) -> QueryType: """ return self._copy(limit=count, limit_to_last=True) - def _resolve_chunk_size(self, num_loaded: int, chunk_size: int) -> int: + def _resolve_chunk_size(self, num_loaded, chunk_size): """Utility function for chunkify.""" if self._limit is not None and (num_loaded + chunk_size) > self._limit: return max(self._limit - num_loaded, 0) return chunk_size - def offset(self: QueryType, num_to_skip: int) -> QueryType: + def offset(self, num_to_skip): """Skip to an offset in a query. If the current query already has specified an offset, this will @@ -606,7 +604,7 @@ def offset(self: QueryType, num_to_skip: int) -> QueryType: """ return self._copy(offset=num_to_skip) - def _check_snapshot(self, document_snapshot) -> None: + def _check_snapshot(self, document_snapshot): """Validate local snapshots for non-collection-group queries. Raises: @@ -620,11 +618,11 @@ def _check_snapshot(self, document_snapshot) -> None: raise ValueError("Cannot use snapshot from another collection as a cursor.") def _cursor_helper( - self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - before: bool, - start: bool, - ) -> QueryType: + self, + document_fields_or_snapshot, + before, + start, + ): """Set values to be used for a ``start_at`` or ``end_at`` cursor. The values will later be used in a query protobuf. @@ -677,9 +675,9 @@ def _cursor_helper( return self._copy(**query_kwargs) def start_at( - self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - ) -> QueryType: + self, + document_fields_or_snapshot, + ): """Start query results at a particular document value. The result set will **include** the document specified by @@ -710,9 +708,9 @@ def start_at( return self._cursor_helper(document_fields_or_snapshot, before=True, start=True) def start_after( - self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - ) -> QueryType: + self, + document_fields_or_snapshot, + ): """Start query results after a particular document value. 
The result set will **exclude** the document specified by @@ -744,9 +742,9 @@ def start_after( ) def end_before( - self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - ) -> QueryType: + self, + document_fields_or_snapshot, + ): """End query results before a particular document value. The result set will **exclude** the document specified by @@ -778,9 +776,9 @@ def end_before( ) def end_at( - self: QueryType, - document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple], - ) -> QueryType: + self, + document_fields_or_snapshot, + ): """End query results at a particular document value. The result set will **include** the document specified by @@ -811,7 +809,7 @@ def end_at( document_fields_or_snapshot, before=False, start=False ) - def _filters_pb(self) -> Optional[StructuredQuery.Filter]: + def _filters_pb(self): """Convert all the filters into a single generic Filter protobuf. This may be a lone field filter or unary filter, may be a composite @@ -845,7 +843,7 @@ def _filters_pb(self) -> Optional[StructuredQuery.Filter]: return query.StructuredQuery.Filter(composite_filter=composite_filter) @staticmethod - def _normalize_projection(projection) -> StructuredQuery.Projection: + def _normalize_projection(projection): """Helper: convert field paths to message.""" if projection is not None: fields = list(projection.fields) @@ -856,7 +854,7 @@ def _normalize_projection(projection) -> StructuredQuery.Projection: return projection - def _normalize_orders(self) -> list: + def _normalize_orders(self): """Helper: adjust orders based on cursors, where clauses.""" orders = list(self._orders) _has_snapshot_cursor = False @@ -886,7 +884,7 @@ def _normalize_orders(self) -> list: return orders - def _normalize_cursor(self, cursor, orders) -> Optional[Tuple[Any, Any]]: + def _normalize_cursor(self, cursor, orders): """Helper: convert cursor to a list of values based on orders.""" if cursor is None: return None @@ -942,7 +940,7 @@ def _normalize_cursor(self, cursor, orders) -> Optional[Tuple[Any, Any]]: return document_fields, before - def _to_protobuf(self) -> StructuredQuery: + def _to_protobuf(self): """Convert the current query into the equivalent protobuf. 
Returns: @@ -972,35 +970,29 @@ def _to_protobuf(self) -> StructuredQuery: query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit) return query.StructuredQuery(**query_kwargs) - def count( - self, alias: str | None = None - ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + def count(self, alias=None): raise NotImplementedError - def sum( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + def sum(self, field_ref, alias=None): raise NotImplementedError - def avg( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]: + def avg(self, field_ref, alias=None): raise NotImplementedError def get( self, transaction=None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Iterable[DocumentSnapshot]: + retry=None, + timeout=None, + ): raise NotImplementedError def _prep_stream( self, transaction=None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Tuple[dict, str, dict]: + retry=None, + timeout=None, + ): """Shared setup for async / sync :meth:`stream`""" if self._limit_to_last: raise ValueError( @@ -1021,15 +1013,15 @@ def _prep_stream( def stream( self, transaction=None, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Generator[document.DocumentSnapshot, Any, None]: + retry=None, + timeout=None, + ): raise NotImplementedError - def on_snapshot(self, callback) -> NoReturn: + def on_snapshot(self, callback): raise NotImplementedError - def recursive(self: QueryType) -> QueryType: + def recursive(self): """Returns a copy of this query whose iterator will yield all matching documents as well as each of their descendent subcollections and documents. @@ -1069,7 +1061,7 @@ def recursive(self: QueryType) -> QueryType: return copied - def _comparator(self, doc1, doc2) -> int: + def _comparator(self, doc1, doc2): _orders = self._orders # Add implicit sorting by name, using the last specified direction. @@ -1116,7 +1108,7 @@ def _comparator(self, doc1, doc2) -> int: return 0 -def _enum_from_op_string(op_string: str) -> int: +def _enum_from_op_string(op_string): """Convert a string representation of a binary operator to an enum. These enums come from the protobuf message definition @@ -1141,7 +1133,7 @@ def _enum_from_op_string(op_string: str) -> int: raise ValueError(msg) -def _isnan(value) -> bool: +def _isnan(value): """Check if a value is NaN. This differs from ``math.isnan`` in that **any** input type is @@ -1159,7 +1151,7 @@ def _isnan(value) -> bool: return False -def _enum_from_direction(direction: str) -> int: +def _enum_from_direction(direction): """Convert a string representation of a direction to an enum. Args: @@ -1187,7 +1179,7 @@ def _enum_from_direction(direction: str) -> int: raise ValueError(msg) -def _filter_pb(field_or_unary) -> StructuredQuery.Filter: +def _filter_pb(field_or_unary): """Convert a specific protobuf filter to the generic filter type. Args: @@ -1211,7 +1203,7 @@ def _filter_pb(field_or_unary) -> StructuredQuery.Filter: raise ValueError("Unexpected filter type", type(field_or_unary), field_or_unary) -def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: +def _cursor_pb(cursor_pair): """Convert a cursor pair to a protobuf. If ``cursor_pair`` is :data:`None`, just returns :data:`None`. 
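[Editor's note, not part of the patch: the hunks above only drop type annotations from the BaseQuery builder methods (where, order_by, limit, the cursor helpers, stream, and friends); the signatures' defaults and the immutable-copy chaining via _copy() are unchanged. A quick usage sketch of that unchanged public surface, with illustrative collection and field names:

    from google.cloud import firestore
    from google.cloud.firestore_v1.base_query import FieldFilter

    client = firestore.Client()
    query = (
        client.collection("cities")  # hypothetical collection
        .where(filter=FieldFilter("population", ">", 1_000_000))
        .order_by("population", direction=firestore.Query.DESCENDING)
        .limit(5)
    )
    for snapshot in query.stream():
        print(snapshot.id, snapshot.to_dict())

Each builder call returns a new query copy, so removing the annotations here is purely cosmetic at runtime.]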
@@ -1232,9 +1224,7 @@ def _cursor_pb(cursor_pair: Tuple[list, bool]) -> Optional[Cursor]: return query.Cursor(values=value_pbs, before=before) -def _query_response_to_snapshot( - response_pb: RunQueryResponse, collection, expected_prefix: str -) -> Optional[document.DocumentSnapshot]: +def _query_response_to_snapshot(response_pb, collection, expected_prefix): """Parse a query response protobuf to a document snapshot. Args: @@ -1268,9 +1258,7 @@ def _query_response_to_snapshot( return snapshot -def _collection_group_query_response_to_snapshot( - response_pb: RunQueryResponse, collection -) -> Optional[document.DocumentSnapshot]: +def _collection_group_query_response_to_snapshot(response_pb, collection): """Parse a query response protobuf to a document snapshot. Args: @@ -1331,7 +1319,7 @@ def __init__( end_at=None, all_descendants=True, recursive=False, - ) -> None: + ): if not all_descendants: raise ValueError("all_descendants must be True for collection group query.") @@ -1368,9 +1356,9 @@ def _get_query_class(self): def _prep_get_partitions( self, partition_count, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> Tuple[dict, dict]: + retry=None, + timeout=None, + ): self._validate_partition_query() parent_path, expected_prefix = self._parent._parent_info() klass = self._get_query_class() @@ -1393,13 +1381,13 @@ def _prep_get_partitions( def get_partitions( self, partition_count, - retry: Optional[retries.Retry] = None, - timeout: Optional[float] = None, - ) -> NoReturn: + retry=None, + timeout=None, + ): raise NotImplementedError @staticmethod - def _get_collection_reference_class() -> Type["BaseCollectionGroup"]: + def _get_collection_reference_class(): raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index b4e5dd038..6cb943660 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -19,33 +19,33 @@ from google.cloud.firestore_v1 import types from typing import Any, Coroutine, NoReturn, Optional, Union -_CANT_BEGIN: str -_CANT_COMMIT: str -_CANT_RETRY_READ_ONLY: str -_CANT_ROLLBACK: str -_EXCEED_ATTEMPTS_TEMPLATE: str -_INITIAL_SLEEP: float -_MAX_SLEEP: float -_MISSING_ID_TEMPLATE: str -_MULTIPLIER: float -_WRITE_READ_ONLY: str +# _CANT_BEGIN: str +# _CANT_COMMIT: str +# _CANT_RETRY_READ_ONLY: str +# _CANT_ROLLBACK: str +# _EXCEED_ATTEMPTS_TEMPLATE: str +# _INITIAL_SLEEP: float +# _MAX_SLEEP: float +# _MISSING_ID_TEMPLATE: str +# _MULTIPLIER: float +# _WRITE_READ_ONLY: str MAX_ATTEMPTS = 5 """int: Default number of transaction attempts (with retries).""" -_CANT_BEGIN: str = "The transaction has already begun. Current transaction ID: {!r}." -_MISSING_ID_TEMPLATE: str = "The transaction has no transaction ID, so it cannot be {}." -_CANT_ROLLBACK: str = _MISSING_ID_TEMPLATE.format("rolled back") -_CANT_COMMIT: str = _MISSING_ID_TEMPLATE.format("committed") -_WRITE_READ_ONLY: str = "Cannot perform write operation in read-only transaction." -_INITIAL_SLEEP: float = 1.0 +_CANT_BEGIN = "The transaction has already begun. Current transaction ID: {!r}." +_MISSING_ID_TEMPLATE = "The transaction has no transaction ID, so it cannot be {}." +_CANT_ROLLBACK = _MISSING_ID_TEMPLATE.format("rolled back") +_CANT_COMMIT = _MISSING_ID_TEMPLATE.format("committed") +_WRITE_READ_ONLY = "Cannot perform write operation in read-only transaction." +_INITIAL_SLEEP = 1.0 """float: Initial "max" for sleep interval. 
To be used in :func:`_sleep`.""" -_MAX_SLEEP: float = 30.0 +_MAX_SLEEP = 30.0 """float: Eventual "max" sleep time. To be used in :func:`_sleep`.""" -_MULTIPLIER: float = 2.0 +_MULTIPLIER = 2.0 """float: Multiplier for exponential backoff. To be used in :func:`_sleep`.""" -_EXCEED_ATTEMPTS_TEMPLATE: str = "Failed to commit transaction in {:d} attempts." -_CANT_RETRY_READ_ONLY: str = "Only read-write transactions can be retried." +_EXCEED_ATTEMPTS_TEMPLATE = "Failed to commit transaction in {:d} attempts." +_CANT_RETRY_READ_ONLY = "Only read-write transactions can be retried." class BaseTransaction(object): @@ -60,17 +60,15 @@ class BaseTransaction(object): :data:`False`. """ - def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: + def __init__(self, max_attempts=MAX_ATTEMPTS, read_only=False): self._max_attempts = max_attempts self._read_only = read_only self._id = None - def _add_write_pbs(self, write_pbs) -> NoReturn: + def _add_write_pbs(self, write_pbs): raise NotImplementedError - def _options_protobuf( - self, retry_id: Union[bytes, None] - ) -> Optional[types.common.TransactionOptions]: + def _options_protobuf(self, retry_id): """Convert the current object to protobuf. The ``retry_id`` value is used when retrying a transaction that @@ -125,7 +123,7 @@ def id(self): """ return self._id - def _clean_up(self) -> None: + def _clean_up(self): """Clean up the instance after :meth:`_rollback`` or :meth:`_commit``. This intended to occur on success or failure of the associated RPCs. @@ -133,29 +131,29 @@ def _clean_up(self) -> None: self._write_pbs = [] self._id = None - def _begin(self, retry_id=None) -> NoReturn: + def _begin(self, retry_id=None): raise NotImplementedError - def _rollback(self) -> NoReturn: + def _rollback(self): raise NotImplementedError - def _commit(self) -> Union[list, Coroutine[Any, Any, list]]: + def _commit(self): raise NotImplementedError def get_all( self, - references: list, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + references, + retry=None, + timeout=None, + ): raise NotImplementedError def get( self, ref_or_query, - retry: retries.Retry = None, - timeout: float = None, - ) -> NoReturn: + retry=None, + timeout=None, + ): raise NotImplementedError @@ -170,19 +168,19 @@ class _BaseTransactional(object): A callable that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap) -> None: + def __init__(self, to_wrap): self.to_wrap = to_wrap self.current_id = None """Optional[bytes]: The current transaction ID.""" self.retry_id = None """Optional[bytes]: The ID of the first attempted transaction.""" - def _reset(self) -> None: + def _reset(self): """Unset the transaction IDs.""" self.current_id = None self.retry_id = None - def _pre_commit(self, transaction, *args, **kwargs) -> NoReturn: + def _pre_commit(self, transaction, *args, **kwargs): raise NotImplementedError def __call__(self, transaction, *args, **kwargs): diff --git a/google/cloud/firestore_v1/batch.py b/google/cloud/firestore_v1/batch.py index 5fa788041..a411eef24 100644 --- a/google/cloud/firestore_v1/batch.py +++ b/google/cloud/firestore_v1/batch.py @@ -34,12 +34,10 @@ class WriteBatch(BaseWriteBatch): The client that created this batch. 
""" - def __init__(self, client) -> None: + def __init__(self, client): super(WriteBatch, self).__init__(client=client) - def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None - ) -> list: + def commit(self, retry=gapic_v1.method.DEFAULT, timeout=None): """Commit the changes accumulated in this batch. Args: diff --git a/google/cloud/firestore_v1/bulk_batch.py b/google/cloud/firestore_v1/bulk_batch.py index 7df48e586..b4c1978e0 100644 --- a/google/cloud/firestore_v1/bulk_batch.py +++ b/google/cloud/firestore_v1/bulk_batch.py @@ -42,12 +42,10 @@ class BulkWriteBatch(BaseBatch): The client that created this batch. """ - def __init__(self, client) -> None: + def __init__(self, client): super(BulkWriteBatch, self).__init__(client=client) - def commit( - self, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None - ) -> BatchWriteResponse: + def commit(self, retry=gapic_v1.method.DEFAULT, timeout=None): """Writes the changes accumulated in this batch. Write operations are not guaranteed to be applied in order and must not @@ -70,7 +68,7 @@ def commit( request, kwargs = self._prep_commit(retry, timeout) _api = self._client._firestore_api - save_response: BatchWriteResponse = _api.batch_write( + save_response = _api.batch_write( request=request, metadata=self._client._rpc_metadata, **kwargs, @@ -81,7 +79,7 @@ def commit( return save_response - def _prep_commit(self, retry: retries.Retry, timeout: float): + def _prep_commit(self, retry, timeout): request = { "database": self._client._database_string, "writes": self._write_pbs, diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index 9f7d0f624..ef04e50b0 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -118,9 +118,7 @@ def wrapper(self, *args, **kwargs): return wrapper @_with_send_mode - def _send_batch( - self, batch: BulkWriteBatch, operations: List["BulkWriterOperation"] - ): + def _send_batch(self, batch, operations): """Sends a batch without regard to rate limits, meaning limits must have already been checked. To that end, do not call this directly; instead, call `_send_until_queue_is_empty`. @@ -128,9 +126,9 @@ def _send_batch( Args: batch(:class:`~google.cloud.firestore_v1.base_batch.BulkWriteBatch`) """ - _len_batch: int = len(batch) + _len_batch = len(batch) self._in_flight_documents += _len_batch - response: BatchWriteResponse = self._send(batch) + response = self._send(batch) self._in_flight_documents -= _len_batch # Update bookkeeping totals @@ -141,20 +139,20 @@ def _send_batch( def _process_response( self, - batch: BulkWriteBatch, - response: BatchWriteResponse, - operations: List["BulkWriterOperation"], - ) -> None: + batch, + response, + operations, + ): """Invokes submitted callbacks for each batch and each operation within each batch. As this is called from `_send_batch()`, this is parallelized if we are in that mode. 
""" - batch_references: List[BaseDocumentReference] = list( + batch_references = list( batch._document_references.values(), ) self._batch_callback(batch, response, self) - status: status_pb2.Status + # status: status_pb2.Status for index, status in enumerate(response.status): if status.code == 0: self._success_callback( @@ -166,8 +164,8 @@ def _process_response( self, ) else: - operation: BulkWriterOperation = operations[index] - should_retry: bool = self._error_callback( + operation = operations[index] + should_retry = self._error_callback( # BulkWriteFailure BulkWriteFailure( operation=operation, @@ -183,9 +181,9 @@ def _process_response( def _retry_operation( self, - operation: "BulkWriterOperation", - ) -> concurrent.futures.Future: - delay: int = 0 + operation, + ): + delay = 0 if self._options.retry == BulkRetry.exponential: delay = operation.attempts**2 # pragma: NO COVER elif self._options.retry == BulkRetry.linear: @@ -204,7 +202,7 @@ def _retry_operation( OperationRetry(operation=operation, run_at=run_at), ) - def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse: + def _send(self, batch): """Hook for overwriting the sending of batches. As this is only called from `_send_batch()`, this is parallelized if we are in that mode. """ @@ -254,12 +252,12 @@ class BulkWriter(AsyncBulkWriterMixin): The client that created this BulkWriter. """ - batch_size: int = 20 + batch_size = 20 def __init__( self, - client: "BaseClient" = None, - options: Optional["BulkWriterOptions"] = None, + client=None, + options=None, ): # Because `BulkWriter` instances are all synchronous/blocking on the # main thread (instead using other threads for asynchrony), it is @@ -273,11 +271,11 @@ def __init__( self._options = options or BulkWriterOptions() self._send_mode = self._options.mode - self._operations: List[BulkWriterOperation] + self._operations = [] # List of the `_document_path` attribute for each DocumentReference # contained in the current `self._operations`. This is reset every time # `self._operations` is reset. - self._operations_document_paths: List[BaseDocumentReference] + self._operations_document_paths = [] self._reset_operations() # List of all `BulkWriterOperation` objects that are waiting to be retried. @@ -285,62 +283,58 @@ def __init__( # the raw operation with the `datetime` of its next scheduled attempt. # `self._retries` must always remain sorted for efficient reads, so it is # required to only ever add elements via `bisect.insort`. - self._retries: collections.deque["OperationRetry"] = collections.deque([]) + self._retries = collections.deque([]) self._queued_batches = collections.deque([]) - self._is_open: bool = True + self._is_open = True # This list will go on to store the future returned from each submission # to the executor, for the purpose of awaiting all of those futures' # completions in the `flush` method. 
- self._pending_batch_futures: List[concurrent.futures.Future] = [] - - self._success_callback: Callable[ - [BaseDocumentReference, WriteResult, "BulkWriter"], None - ] = BulkWriter._default_on_success - self._batch_callback: Callable[ - [BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None - ] = BulkWriter._default_on_batch - self._error_callback: Callable[ - [BulkWriteFailure, BulkWriter], bool - ] = BulkWriter._default_on_error - - self._in_flight_documents: int = 0 + self._pending_batch_futures = [] + + self._success_callback = BulkWriter._default_on_success + + self._batch_callback = BulkWriter._default_on_batch + + self._error_callback = BulkWriter._default_on_error + + self._in_flight_documents = 0 self._rate_limiter = RateLimiter( initial_tokens=self._options.initial_ops_per_second, global_max_tokens=self._options.max_ops_per_second, ) # Keep track of progress as batches and write operations are completed - self._total_batches_sent: int = 0 - self._total_write_operations: int = 0 + self._total_batches_sent = 0 + self._total_write_operations = 0 self._ensure_executor() @staticmethod def _default_on_batch( - batch: BulkWriteBatch, - response: BatchWriteResponse, - bulk_writer: "BulkWriter", - ) -> None: + batch, + response, + bulk_writer, + ): pass @staticmethod def _default_on_success( - reference: BaseDocumentReference, - result: WriteResult, - bulk_writer: "BulkWriter", - ) -> None: + reference, + result, + bulk_writer, + ): pass @staticmethod - def _default_on_error(error: "BulkWriteFailure", bulk_writer: "BulkWriter") -> bool: + def _default_on_error(error, bulk_writer): # Default number of retries for each operation is 15. This is a scary # number to combine with an exponential backoff, and as such, our default # backoff strategy is linear instead of exponential. return error.attempts < 15 - def _reset_operations(self) -> None: + def _reset_operations(self): self._operations = [] self._operations_document_paths = [] @@ -471,7 +465,7 @@ def _send_until_queue_is_empty(self): while self._queued_batches: # For FIFO order, add to the right of this deque (via `append`) and take # from the left (via `popleft`). - operations: List[BulkWriterOperation] = self._queued_batches.popleft() + operations = self._queued_batches.popleft() # Block until we are cleared for takeoff, which is fine because this # returns instantly unless the rate limiting logic determines that we @@ -480,7 +474,7 @@ def _send_until_queue_is_empty(self): # Handle some bookkeeping, and ultimately put these bits on the wire. batch = BulkWriteBatch(client=self._client) - op: BulkWriterOperation + # op: BulkWriterOperation for op in operations: op.add_to_batch(batch) @@ -502,20 +496,20 @@ def _schedule_ready_retries(self): ) for _ in range(take_until_index): - retry: OperationRetry = self._retries.popleft() + retry = self._retries.popleft() retry.retry(self) - def _request_send(self, batch_size: int) -> bool: + def _request_send(self, batch_size): # Set up this boolean to avoid repeatedly taking tokens if we're only # waiting on the `max_in_flight` limit. - have_received_tokens: bool = False + have_received_tokens = False while True: # To avoid bottlenecks on the server, an additional limit is that no # more write operations can be "in flight" (sent but still awaiting # response) at any given point than the maximum number of writes per # second. 
- under_threshold: bool = ( + under_threshold = ( self._in_flight_documents <= self._rate_limiter._maximum_tokens ) # Ask for tokens each pass through this loop until they are granted, @@ -536,10 +530,10 @@ def _request_send(self, batch_size: int) -> bool: def create( self, - reference: BaseDocumentReference, - document_data: Dict, - attempts: int = 0, - ) -> None: + reference, + document_data, + attempts=0, + ): """Adds a `create` pb to the in-progress batch. If the in-progress batch already contains a write operation involving @@ -575,10 +569,10 @@ def create( def delete( self, - reference: BaseDocumentReference, - option: Optional[_helpers.WriteOption] = None, - attempts: int = 0, - ) -> None: + reference, + option=None, + attempts=0, + ): """Adds a `delete` pb to the in-progress batch. If the in-progress batch already contains a write operation involving @@ -614,11 +608,11 @@ def delete( def set( self, - reference: BaseDocumentReference, - document_data: Dict, - merge: Union[bool, list] = False, - attempts: int = 0, - ) -> None: + reference, + document_data, + merge=False, + attempts=0, + ): """Adds a `set` pb to the in-progress batch. If the in-progress batch already contains a write operation involving @@ -658,11 +652,11 @@ def set( def update( self, - reference: BaseDocumentReference, - field_updates: dict, - option: Optional[_helpers.WriteOption] = None, - attempts: int = 0, - ) -> None: + reference, + field_updates, + option=None, + attempts=0, + ): """Adds an `update` pb to the in-progress batch. If the in-progress batch already contains a write operation involving @@ -706,21 +700,19 @@ def update( def on_write_result( self, - callback: Callable[[BaseDocumentReference, WriteResult, "BulkWriter"], None], - ) -> None: + callback, + ): """Sets a callback that will be invoked once for every successful operation.""" self._success_callback = callback or BulkWriter._default_on_success def on_batch_result( self, - callback: Callable[[BulkWriteBatch, BatchWriteResponse, "BulkWriter"], None], - ) -> None: + callback, + ): """Sets a callback that will be invoked once for every successful batch.""" self._batch_callback = callback or BulkWriter._default_on_batch - def on_write_error( - self, callback: Callable[["BulkWriteFailure", "BulkWriter"], bool] - ) -> None: + def on_write_error(self, callback): """Sets a callback that will be invoked once for every batch that contains an error.""" self._error_callback = callback or BulkWriter._default_on_error @@ -740,7 +732,7 @@ class BulkWriterOperation: similar writes to the same document. """ - def add_to_batch(self, batch: BulkWriteBatch): + def add_to_batch(self, batch): """Adds `self` to the supplied batch.""" assert isinstance(batch, BulkWriteBatch) if isinstance(self, BulkWriterCreateOperation): @@ -782,7 +774,7 @@ class BaseOperationRetry: Python 3.6 is dropped and `dataclasses` becomes universal. 
""" - def __lt__(self, other: "OperationRetry"): + def __lt__(self, other): """Allows use of `bisect` to maintain a sorted list of `OperationRetry` instances, which in turn allows us to cheaply grab all that are ready to run.""" @@ -792,7 +784,7 @@ def __lt__(self, other: "OperationRetry"): return self.run_at < other return NotImplemented # pragma: NO COVER - def retry(self, bulk_writer: BulkWriter) -> None: + def retry(self, bulk_writer): """Call this after waiting any necessary time to re-add the enclosed operation to the supplied BulkWriter's internal queue.""" if isinstance(self.operation, BulkWriterCreateOperation): @@ -832,21 +824,21 @@ def retry(self, bulk_writer: BulkWriter) -> None: @dataclass class BulkWriterOptions: - initial_ops_per_second: int = 500 - max_ops_per_second: int = 500 - mode: SendMode = SendMode.parallel - retry: BulkRetry = BulkRetry.linear + initial_ops_per_second = 500 + max_ops_per_second = 500 + mode = SendMode.parallel + retry = BulkRetry.linear @dataclass class BulkWriteFailure: - operation: BulkWriterOperation + # operation: BulkWriterOperation # https://grpc.github.io/grpc/core/md_doc_statuscodes.html - code: int - message: str + # code: int + # message: str @property - def attempts(self) -> int: + def attempts(self): return self.operation.attempts @@ -855,43 +847,43 @@ class OperationRetry(BaseOperationRetry): """Container for an additional attempt at an operation, scheduled for the future.""" - operation: BulkWriterOperation - run_at: datetime.datetime + # operation: BulkWriterOperation + # run_at: datetime.datetime @dataclass class BulkWriterCreateOperation(BulkWriterOperation): """Container for BulkWriter.create() operations.""" - reference: BaseDocumentReference - document_data: Dict - attempts: int = 0 + # reference: BaseDocumentReference + # document_data: Dict + attempts = 0 @dataclass class BulkWriterUpdateOperation(BulkWriterOperation): """Container for BulkWriter.update() operations.""" - reference: BaseDocumentReference - field_updates: Dict - option: Optional[_helpers.WriteOption] - attempts: int = 0 + # reference: BaseDocumentReference + # field_updates: Dict + # option: Optional[_helpers.WriteOption] + attempts = 0 @dataclass class BulkWriterSetOperation(BulkWriterOperation): """Container for BulkWriter.set() operations.""" - reference: BaseDocumentReference - document_data: Dict - merge: Union[bool, list] = False - attempts: int = 0 + # reference: BaseDocumentReference + # document_data: Dict + merge = False + attempts = 0 @dataclass class BulkWriterDeleteOperation(BulkWriterOperation): """Container for BulkWriter.delete() operations.""" - reference: BaseDocumentReference - option: Optional[_helpers.WriteOption] - attempts: int = 0 + # reference: BaseDocumentReference + # option: Optional[_helpers.WriteOption] + attempts = 0 diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 05c135479..3fceb0f21 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -87,7 +87,7 @@ def __init__( database=None, client_info=_CLIENT_INFO, client_options=None, - ) -> None: + ): super(Client, self).__init__( project=project, credentials=credentials, @@ -119,7 +119,7 @@ def _target(self): """ return self._target_helper(firestore_client.FirestoreClient) - def collection(self, *collection_path: str) -> CollectionReference: + def collection(self, *collection_path): """Get a reference to a collection. 
For a top-level collection: @@ -150,7 +150,7 @@ def collection(self, *collection_path: str) -> CollectionReference: """ return CollectionReference(*_path_helper(collection_path), client=self) - def collection_group(self, collection_id: str) -> CollectionGroup: + def collection_group(self, collection_id): """ Creates and returns a new Query that includes all documents in the database that are contained in a collection or subcollection with the @@ -172,7 +172,7 @@ def collection_group(self, collection_id: str) -> CollectionGroup: """ return CollectionGroup(self._get_collection_reference(collection_id)) - def document(self, *document_path: str) -> DocumentReference: + def document(self, *document_path): """Get a reference to a document in a collection. For a top-level document: @@ -209,12 +209,12 @@ def document(self, *document_path: str) -> DocumentReference: def get_all( self, - references: list, - field_paths: Iterable[str] = None, - transaction: Transaction = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[DocumentSnapshot, Any, None]: + references, + field_paths=None, + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieve a batch of documents. .. note:: @@ -268,9 +268,9 @@ def get_all( def collections( self, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[Any, Any, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List top-level collections of the client's database. Args: @@ -296,11 +296,11 @@ def collections( def recursive_delete( self, - reference: Union[CollectionReference, DocumentReference], + reference, *, - bulk_writer: Optional["BulkWriter"] = None, - chunk_size: Optional[int] = 5000, - ) -> int: + bulk_writer=None, + chunk_size=5000, + ): """Deletes documents and their subcollections, regardless of collection name. @@ -333,30 +333,30 @@ def recursive_delete( def _recursive_delete( self, - reference: Union[CollectionReference, DocumentReference], - bulk_writer: "BulkWriter", + reference, + bulk_writer, *, - chunk_size: Optional[int] = 5000, - depth: Optional[int] = 0, - ) -> int: + chunk_size=5000, + depth=0, + ): """Recursion helper for `recursive_delete.""" - num_deleted: int = 0 + num_deleted = 0 if isinstance(reference, CollectionReference): - chunk: List[DocumentSnapshot] + # chunk: List[DocumentSnapshot] for chunk in ( reference.recursive() .select([FieldPath.document_id()]) ._chunkify(chunk_size) ): - doc_snap: DocumentSnapshot + # doc_snap: DocumentSnapshot for doc_snap in chunk: num_deleted += 1 bulk_writer.delete(doc_snap.reference) elif isinstance(reference, DocumentReference): - col_ref: CollectionReference + # col_ref: CollectionReference for col_ref in reference.collections(): num_deleted += self._recursive_delete( col_ref, @@ -377,7 +377,7 @@ def _recursive_delete( return num_deleted - def batch(self) -> WriteBatch: + def batch(self): """Get a batch instance from this client. Returns: @@ -387,7 +387,7 @@ def batch(self) -> WriteBatch: """ return WriteBatch(self) - def transaction(self, **kwargs) -> Transaction: + def transaction(self, **kwargs): """Get a transaction that uses this client. 
See :class:`~google.cloud.firestore_v1.transaction.Transaction` for diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index f6ba1833d..20c39efe9 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -57,10 +57,10 @@ class CollectionReference(BaseCollectionReference[query_mod.Query]): TypeError: If a keyword other than ``client`` is used. """ - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): super(CollectionReference, self).__init__(*path, **kwargs) - def _query(self) -> query_mod.Query: + def _query(self): """Query factory. Returns: @@ -68,7 +68,7 @@ def _query(self) -> query_mod.Query: """ return query_mod.Query(self) - def _aggregation_query(self) -> aggregation.AggregationQuery: + def _aggregation_query(self): """AggregationQuery factory. Returns: @@ -78,11 +78,11 @@ def _aggregation_query(self) -> aggregation.AggregationQuery: def add( self, - document_data: dict, - document_id: Union[str, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: Union[float, None] = None, - ) -> Tuple[Any, Any]: + document_data, + document_id=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Create a document in the Firestore database with the provided data. Args: @@ -121,10 +121,10 @@ def add( def list_documents( self, - page_size: Union[int, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: Union[float, None] = None, - ) -> Generator[Any, Any, None]: + page_size=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List all subdocuments of the current collection. Args: @@ -151,15 +151,15 @@ def list_documents( ) return (_item_to_document_ref(self, i) for i in iterator) - def _chunkify(self, chunk_size: int): + def _chunkify(self, chunk_size): return self._query()._chunkify(chunk_size) def get( self, - transaction: Union[Transaction, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: Union[float, None] = None, - ) -> list: + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in this collection. This sends a ``RunQuery`` RPC and returns a list of documents @@ -187,10 +187,10 @@ def get( def stream( self, - transaction: Union[Transaction, None] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: Union[float, None] = None, - ) -> Generator[document.DocumentSnapshot, Any, None]: + transaction=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in this collection. This sends a ``RunQuery`` RPC and then returns an iterator which @@ -225,7 +225,7 @@ def stream( return query.stream(transaction=transaction, **kwargs) - def on_snapshot(self, callback: Callable) -> Watch: + def on_snapshot(self, callback): """Monitor the documents in this collection. This starts a watch on this collection using a background thread. The diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 00d682d2b..66da7d2b3 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -60,15 +60,15 @@ class DocumentReference(BaseDocumentReference): TypeError: If a keyword other than ``client`` is used. 
""" - def __init__(self, *path, **kwargs) -> None: + def __init__(self, *path, **kwargs): super(DocumentReference, self).__init__(*path, **kwargs) def create( self, - document_data: dict, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + document_data, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Create a document in the Firestore database. >>> document_data = {"a": 1, "b": {"c": "Two"}} @@ -101,11 +101,11 @@ def create( def set( self, - document_data: dict, - merge: bool = False, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + document_data, + merge=False, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Create / replace / merge a document in the Firestore database. - To "upsert" a document (create if it doesn't exist, replace completely @@ -169,11 +169,11 @@ def set( def update( self, - field_updates: dict, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> write.WriteResult: + field_updates, + option=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Update an existing document in the Firestore database. By default, this method verifies that the document exists on the @@ -327,10 +327,10 @@ def update( def delete( self, - option: _helpers.WriteOption = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Timestamp: + option=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Delete the current document in the Firestore database. Args: @@ -361,11 +361,11 @@ def delete( def get( self, - field_paths: Iterable[str] = None, + field_paths=None, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> DocumentSnapshot: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieve a snapshot of the current document. See :meth:`~google.cloud.firestore_v1.base_client.BaseClient.field_path` for @@ -431,10 +431,10 @@ def get( def collections( self, - page_size: int = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[Any, Any, None]: + page_size=None, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """List subcollections of the current document. Args: @@ -463,7 +463,7 @@ def collections( for collection_id in iterator: yield self.collection(collection_id) - def on_snapshot(self, callback: Callable) -> Watch: + def on_snapshot(self, callback): """Watch this document. This starts a watch on this document using a background thread. The diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 24683fb84..ea1cc7b91 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -43,7 +43,7 @@ TOKENS_REGEX = re.compile(TOKENS_PATTERN) -def _tokenize_field_path(path: str): +def _tokenize_field_path(path): """Lex a field path into tokens (including dots). Args: @@ -64,7 +64,7 @@ def _tokenize_field_path(path: str): raise ValueError("Path {} not consumed, residue: {}".format(path, path[pos:])) -def split_field_path(path: str): +def split_field_path(path): """Split a field path into valid elements (without dots). Args: @@ -99,7 +99,7 @@ def split_field_path(path: str): return elements -def parse_field_path(api_repr: str): +def parse_field_path(api_repr): """Parse a **field path** from into a list of nested field names. See :func:`field_path` for more on **field paths**. 
@@ -128,7 +128,7 @@ def parse_field_path(api_repr: str): return field_names -def render_field_path(field_names: Iterable[str]): +def render_field_path(field_names): """Create a **field path** from a list of nested field names. A **field path** is a ``.``-delimited concatenation of the field @@ -172,7 +172,7 @@ def render_field_path(field_names: Iterable[str]): get_field_path = render_field_path # backward-compatibility -def get_nested_value(field_path: str, data: dict): +def get_nested_value(field_path, data): """Get a (potentially nested) value from a dictionary. If the data is nested, for example: @@ -273,7 +273,7 @@ def __init__(self, *parts): self.parts = tuple(parts) @classmethod - def from_api_repr(cls, api_repr: str): + def from_api_repr(cls, api_repr): """Factory: create a FieldPath from the string formatted per the API. Args: @@ -290,7 +290,7 @@ def from_api_repr(cls, api_repr: str): return cls(*parse_field_path(api_repr)) @classmethod - def from_string(cls, path_string: str): + def from_string(cls, path_string): """Factory: create a FieldPath from a unicode string representation. This method splits on the character `.` and disallows the diff --git a/google/cloud/firestore_v1/order.py b/google/cloud/firestore_v1/order.py index 37052f9f5..18bb64259 100644 --- a/google/cloud/firestore_v1/order.py +++ b/google/cloud/firestore_v1/order.py @@ -32,7 +32,7 @@ class TypeOrder(Enum): OBJECT = 9 @staticmethod - def from_value(value) -> Any: + def from_value(value): v = value._pb.WhichOneof("value_type") lut = { @@ -60,7 +60,7 @@ class Order(object): """ @classmethod - def compare(cls, left, right) -> int: + def compare(cls, left, right): """ Main comparison function for all Firestore types. @return -1 is left < right, 0 if left == right, otherwise 1 @@ -102,14 +102,14 @@ def compare(cls, left, right) -> int: raise ValueError(f"Unknown ``value_type`` {value_type}") @staticmethod - def compare_blobs(left, right) -> int: + def compare_blobs(left, right): left_bytes = left.bytes_value right_bytes = right.bytes_value return Order._compare_to(left_bytes, right_bytes) @staticmethod - def compare_timestamps(left, right) -> Any: + def compare_timestamps(left, right): left = left._pb.timestamp_value right = right._pb.timestamp_value @@ -120,7 +120,7 @@ def compare_timestamps(left, right) -> Any: return Order._compare_to(left.nanos or 0, right.nanos or 0) @staticmethod - def compare_geo_points(left, right) -> Any: + def compare_geo_points(left, right): left_value = decode_value(left, None) right_value = decode_value(right, None) cmp = (left_value.latitude > right_value.latitude) - ( @@ -134,7 +134,7 @@ def compare_geo_points(left, right) -> Any: ) @staticmethod - def compare_resource_paths(left, right) -> int: + def compare_resource_paths(left, right): left = left.reference_value right = right.reference_value @@ -153,7 +153,7 @@ def compare_resource_paths(left, right) -> int: return (left_length > right_length) - (left_length < right_length) @staticmethod - def compare_arrays(left, right) -> int: + def compare_arrays(left, right): l_values = left.array_value.values r_values = right.array_value.values @@ -166,7 +166,7 @@ def compare_arrays(left, right) -> int: return Order._compare_to(len(l_values), len(r_values)) @staticmethod - def compare_objects(left, right) -> int: + def compare_objects(left, right): left_fields = left.map_value.fields right_fields = right.map_value.fields @@ -184,13 +184,13 @@ def compare_objects(left, right) -> int: return Order._compare_to(len(left_fields), len(right_fields)) 
@staticmethod - def compare_numbers(left, right) -> int: + def compare_numbers(left, right): left_value = decode_value(left, None) right_value = decode_value(right, None) return Order.compare_doubles(left_value, right_value) @staticmethod - def compare_doubles(left, right) -> int: + def compare_doubles(left, right): if math.isnan(left): if math.isnan(right): return 0 @@ -201,7 +201,7 @@ def compare_doubles(left, right) -> int: return Order._compare_to(left, right) @staticmethod - def _compare_to(left, right) -> int: + def _compare_to(left, right): # We can't just use cmp(left, right) because cmp doesn't exist # in Python 3, so this is an equivalent suggested by # https://docs.python.org/3.0/whatsnew/3.0.html#ordering-comparisons diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index d37964dce..556e3b40c 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -113,7 +113,7 @@ def __init__( end_at=None, all_descendants=False, recursive=False, - ) -> None: + ): super(Query, self).__init__( parent=parent, projection=projection, @@ -131,9 +131,9 @@ def __init__( def get( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> List[DocumentSnapshot]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in the collection that match this query. This sends a ``RunQuery`` RPC and returns a list of documents @@ -174,18 +174,16 @@ def get( return list(result) - def _chunkify( - self, chunk_size: int - ) -> Generator[List[DocumentSnapshot], None, None]: - max_to_return: Optional[int] = self._limit - num_returned: int = 0 - original: Query = self._copy() - last_document: Optional[DocumentSnapshot] = None + def _chunkify(self, chunk_size): + max_to_return = self._limit + num_returned = 0 + original = self._copy() + last_document = None while True: # Optionally trim the `chunk_size` down to honor a previously # applied limits as set by `self.limit()` - _chunk_size: int = original._resolve_chunk_size(num_returned, chunk_size) + _chunk_size = original._resolve_chunk_size(num_returned, chunk_size) # Apply the optionally pruned limit and the cursor, if we are past # the first page. @@ -239,9 +237,7 @@ def _retry_query_after_exception(self, exc, retry, transaction): return False - def count( - self, alias: str | None = None - ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + def count(self, alias=None): """ Adds a count over the query. @@ -251,9 +247,7 @@ def count( """ return aggregation.AggregationQuery(self).count(alias=alias) - def sum( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + def sum(self, field_ref, alias=None): """ Adds a sum over the query. @@ -266,9 +260,7 @@ def sum( """ return aggregation.AggregationQuery(self).sum(field_ref, alias=alias) - def avg( - self, field_ref: str | FieldPath, alias: str | None = None - ) -> Type["firestore_v1.aggregation.AggregationQuery"]: + def avg(self, field_ref, alias=None): """ Adds an avg over the query. @@ -284,9 +276,9 @@ def avg( def stream( self, transaction=None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[document.DocumentSnapshot, Any, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Read the documents in the collection that match this query. 
This sends a ``RunQuery`` RPC and then returns an iterator which @@ -355,7 +347,7 @@ def stream( last_snapshot = snapshot yield snapshot - def on_snapshot(self, callback: Callable) -> Watch: + def on_snapshot(self, callback): """Monitor the documents in this collection that match this query. This starts a watch on this query using a background thread. The @@ -387,9 +379,7 @@ def on_snapshot(docs, changes, read_time): return Watch.for_query(self, callback, document.DocumentSnapshot) @staticmethod - def _get_collection_reference_class() -> ( - Type["firestore_v1.collection.CollectionReference"] - ): + def _get_collection_reference_class(): from google.cloud.firestore_v1.collection import CollectionReference return CollectionReference @@ -420,7 +410,7 @@ def __init__( end_at=None, all_descendants=True, recursive=False, - ) -> None: + ): super(CollectionGroup, self).__init__( parent=parent, projection=projection, @@ -442,9 +432,9 @@ def _get_query_class(): def get_partitions( self, partition_count, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[QueryPartition, None, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Partition a query for parallelization. Partitions a query by returning partition cursors that can be used to run the diff --git a/google/cloud/firestore_v1/rate_limiter.py b/google/cloud/firestore_v1/rate_limiter.py index 8ca98dbe8..1d41c3797 100644 --- a/google/cloud/firestore_v1/rate_limiter.py +++ b/google/cloud/firestore_v1/rate_limiter.py @@ -30,9 +30,9 @@ def utcnow(): return datetime.datetime.utcnow() -default_initial_tokens: int = 500 -default_phase_length: int = 60 * 5 # 5 minutes -microseconds_per_second: int = 1000000 +default_initial_tokens = 500 +default_phase_length = 60 * 5 # 5 minutes +microseconds_per_second = 1000000 class RateLimiter: @@ -70,28 +70,28 @@ class RateLimiter: def __init__( self, - initial_tokens: int = default_initial_tokens, - global_max_tokens: Optional[int] = None, - phase_length: int = default_phase_length, + initial_tokens=default_initial_tokens, + global_max_tokens=None, + phase_length=default_phase_length, ): # Tracks the volume of operations during a given ramp-up phase. - self._operations_this_phase: int = 0 + self._operations_this_phase = 0 # If provided, this enforces a cap on the maximum number of writes per # second we can ever attempt, regardless of how many 50% increases the # 5/5/5 rule would grant. self._global_max_tokens = global_max_tokens - self._start: Optional[datetime.datetime] = None - self._last_refill: Optional[datetime.datetime] = None + self._start = None + self._last_refill = None # Current number of available operations. Decrements with every # permitted request and refills over time. - self._available_tokens: int = initial_tokens + self._available_tokens = initial_tokens # Maximum size of the available operations. Can increase by 50% # every [phase_length] number of seconds. - self._maximum_tokens: int = self._available_tokens + self._maximum_tokens = self._available_tokens if self._global_max_tokens is not None: self._available_tokens = min( @@ -100,17 +100,17 @@ def __init__( self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens) # Number of seconds after which the [_maximum_tokens] can increase by 50%. - self._phase_length: int = phase_length + self._phase_length = phase_length # Tracks how many times the [_maximum_tokens] has increased by 50%. 
- self._phase: int = 0 + self._phase = 0 def _start_clock(self): utcnow = datetime.datetime.now(datetime.timezone.utc) self._start = self._start or utcnow self._last_refill = self._last_refill or utcnow - def take_tokens(self, num: Optional[int] = 1, allow_less: bool = False) -> int: + def take_tokens(self, num=1, allow_less=False): """Returns the number of available tokens, up to the amount requested.""" self._start_clock() self._check_phase() @@ -134,39 +134,37 @@ def _check_phase(self): This is a no-op unless a new [_phase_length] number of seconds since the start was crossed since it was last called. """ - age: datetime.timedelta = ( - datetime.datetime.now(datetime.timezone.utc) - self._start - ) + age = datetime.datetime.now(datetime.timezone.utc) - self._start # Uses integer division to calculate the expected phase. We start in # Phase 0, so until [_phase_length] seconds have passed, this will # not resolve to 1. - expected_phase: int = age.seconds // self._phase_length + expected_phase = age.seconds // self._phase_length # Short-circuit if we are still in the expected phase. if expected_phase == self._phase: return - operations_last_phase: int = self._operations_this_phase + operations_last_phase = self._operations_this_phase self._operations_this_phase = 0 - previous_phase: int = self._phase + previous_phase = self._phase self._phase = expected_phase # No-op if we did nothing for an entire phase if operations_last_phase and self._phase > previous_phase: self._increase_maximum_tokens() - def _increase_maximum_tokens(self) -> NoReturn: + def _increase_maximum_tokens(self): self._maximum_tokens = round(self._maximum_tokens * 1.5) if self._global_max_tokens is not None: self._maximum_tokens = min(self._maximum_tokens, self._global_max_tokens) - def _refill(self) -> NoReturn: + def _refill(self): """Replenishes any tokens that should have regenerated since the last operation.""" - now: datetime.datetime = datetime.datetime.now(datetime.timezone.utc) - time_since_last_refill: datetime.timedelta = now - self._last_refill + now = datetime.datetime.now(datetime.timezone.utc) + time_since_last_refill = now - self._last_refill if time_since_last_refill: self._last_refill = now @@ -179,10 +177,10 @@ def _refill(self) -> NoReturn: # If we have done something in the last 1s, then we know we should # allocate proportional tokens. else: - _percent_of_max: float = ( + _percent_of_max = ( time_since_last_refill.microseconds / microseconds_per_second ) - new_tokens: int = round(_percent_of_max * self._maximum_tokens) + new_tokens = round(_percent_of_max * self._maximum_tokens) # Add the number of provisioned tokens, capped at the maximum size. self._available_tokens = min( diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 3c175a4ce..bbc362431 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -61,11 +61,11 @@ class Transaction(batch.WriteBatch, BaseTransaction): :data:`False`. """ - def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: + def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False): super(Transaction, self).__init__(client) BaseTransaction.__init__(self, max_attempts, read_only) - def _add_write_pbs(self, write_pbs: list) -> None: + def _add_write_pbs(self, write_pbs): """Add `Write`` protobufs to this transaction. 
Args: @@ -80,7 +80,7 @@ def _add_write_pbs(self, write_pbs: list) -> None: super(Transaction, self)._add_write_pbs(write_pbs) - def _begin(self, retry_id: bytes = None) -> None: + def _begin(self, retry_id=None): """Begin the transaction. Args: @@ -103,7 +103,7 @@ def _begin(self, retry_id: bytes = None) -> None: ) self._id = transaction_response.transaction - def _rollback(self) -> None: + def _rollback(self): """Roll back the transaction. Raises: @@ -126,7 +126,7 @@ def _rollback(self) -> None: # clean up, even if rollback fails self._clean_up() - def _commit(self) -> list: + def _commit(self): """Transactionally commit the changes accumulated. Returns: @@ -148,10 +148,10 @@ def _commit(self) -> list: def get_all( self, - references: list, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[DocumentSnapshot, Any, None]: + references, + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieves multiple documents from Firestore. Args: @@ -172,9 +172,9 @@ def get_all( def get( self, ref_or_query, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - ) -> Generator[DocumentSnapshot, Any, None]: + retry=gapic_v1.method.DEFAULT, + timeout=None, + ): """Retrieve a document or a query result from the database. Args: @@ -210,10 +210,10 @@ class _Transactional(_BaseTransactional): A callable that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap) -> None: + def __init__(self, to_wrap): super(_Transactional, self).__init__(to_wrap) - def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: + def _pre_commit(self, transaction, *args, **kwargs): """Begin transaction and call the wrapped callable. Args: @@ -241,7 +241,7 @@ def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: self.retry_id = self.current_id return self.to_wrap(transaction, *args, **kwargs) - def __call__(self, transaction: Transaction, *args, **kwargs): + def __call__(self, transaction, *args, **kwargs): """Execute the wrapped callable within a transaction. Args: @@ -291,7 +291,7 @@ def __call__(self, transaction: Transaction, *args, **kwargs): raise -def transactional(to_wrap: Callable) -> _Transactional: +def transactional(to_wrap): """Decorate a callable so that it runs in a transaction. Args: @@ -306,9 +306,7 @@ def transactional(to_wrap: Callable) -> _Transactional: return _Transactional(to_wrap) -def _commit_with_retry( - client, write_pbs: list, transaction_id: bytes -) -> CommitResponse: +def _commit_with_retry(client, write_pbs, transaction_id): """Call ``Commit`` on the GAPIC client with retry / sleep. Retries the ``Commit`` RPC on Unavailable. Usually this RPC-level @@ -351,9 +349,7 @@ def _commit_with_retry( current_sleep = _sleep(current_sleep) -def _sleep( - current_sleep: float, max_sleep: float = _MAX_SLEEP, multiplier: float = _MULTIPLIER -) -> float: +def _sleep(current_sleep, max_sleep=_MAX_SLEEP, multiplier=_MULTIPLIER): """Sleep and produce a new sleep time. .. 
_Exponential Backoff And Jitter: https://www.awsarchitectureblog.com/\ diff --git a/google/cloud/firestore_v1/transforms.py b/google/cloud/firestore_v1/transforms.py index f1361c951..83b644608 100644 --- a/google/cloud/firestore_v1/transforms.py +++ b/google/cloud/firestore_v1/transforms.py @@ -20,7 +20,7 @@ class Sentinel(object): __slots__ = ("description",) - def __init__(self, description) -> None: + def __init__(self, description): self.description = description def __repr__(self): @@ -44,7 +44,7 @@ class _ValueList(object): slots = ("_values",) - def __init__(self, values) -> None: + def __init__(self, values): if not isinstance(values, (list, tuple)): raise ValueError("'values' must be a list or tuple.") @@ -97,7 +97,7 @@ class _NumericValue(object): value (int | float): value held in the helper. """ - def __init__(self, value) -> None: + def __init__(self, value): if not isinstance(value, (int, float)): raise ValueError("Pass an integer / float value.") From 678ef229156416ec73f8abd7ea36a39775c6718b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Dec 2023 16:35:35 -0800 Subject: [PATCH 3/4] fixed some types --- google/cloud/firestore_v1/__init__.py | 2 +- google/cloud/firestore_v1/bulk_writer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/__init__.py b/google/cloud/firestore_v1/__init__.py index a7322fc24..f1c4303e6 100644 --- a/google/cloud/firestore_v1/__init__.py +++ b/google/cloud/firestore_v1/__init__.py @@ -21,7 +21,7 @@ from google.cloud.firestore_v1 import gapic_version as package_version -__version__ = package_version.__version__ +__version__: str = package_version.__version__ from google.cloud.firestore_v1 import types from google.cloud.firestore_v1._helpers import GeoPoint diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index ef04e50b0..db51b9962 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -83,7 +83,7 @@ class AsyncBulkWriterMixin: wrapped in a decorator which ensures that the `SendMode` is honored. """ - def _with_send_mode(fn): + def _with_send_mode(fn: Callable) -> Callable: # type: ignore """Decorates a method to ensure it is only called via the executor (IFF the SendMode value is SendMode.parallel!). From 6ffcdde33308246a1b3c49334ce4be2fa2b804a5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 12 Dec 2023 16:35:45 -0800 Subject: [PATCH 4/4] updated mypy config to ignore google auth --- mypy.ini | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 4505b4854..e16c38a1c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,6 @@ [mypy] -python_version = 3.6 +python_version = 3.7 namespace_packages = True + +[mypy-google.auth.*] +ignore_missing_imports = True
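
Note on the decorator typing touched in patch 3/4: annotating the argument and return of `_with_send_mode` as `Callable` follows the common Callable-in, Callable-out pattern for method decorators. A minimal standalone sketch of that pattern is below for reference; it is illustrative only and deliberately omits the real SendMode/executor dispatch that the bulk_writer decorator performs, the `wrapper` name and pass-through body are placeholders.

    from typing import Callable
    import functools

    def _with_send_mode(fn: Callable) -> Callable:
        # Sketch only: the real decorator routes calls through an executor
        # when SendMode.parallel is active; this toy version calls through.
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            return fn(self, *args, **kwargs)

        return wrapper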