Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10' ]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ To install:

`pip install mock-firestore`

Python 3.6+ is required for it to work.
Python 3.8+ is supported.

## Usage

Expand Down
12 changes: 6 additions & 6 deletions mockfirestore/collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from typing import Any, List, Optional, Iterable, Dict, Tuple, Sequence, Union
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.base_query import FieldFilter, And, Or

from mockfirestore import AlreadyExists
from mockfirestore._helpers import (
Expand Down Expand Up @@ -60,15 +60,15 @@ def where(
field: Optional[str] = None,
op: Optional[str] = None,
value: Optional[Any] = None,
filter: Optional[FieldFilter] = None,
filter: Union[FieldFilter, And, Or, None] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not critical
May we use OptionalUnion[FieldFilter, And, Or] ?
As None is not a type. and using Union[type, None] is outdated

Copy link
Contributor Author

@CeeEffEff CeeEffEff Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional is just shorthand for Union[x, None].
In this case it wouldn't make things much shorter as you end up with Optional[Union[x]].
It's not deprecated either to use Union

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wouldn't
Thats why it isnt critical.
But even python's creators decided that it is bad design :)

) -> Query:
if filter is not None:
field, op, value = filter.field_path, filter.op_string, filter.value
if field is None or op is None or value is None:
if filter is None and (field is None or op is None or value is None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter is None and not any([field, op, value])?
Or we need i direct check is None ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes as we cannot rely on the truthiness of value.
We are explicitly checking for if the params have not been provided.

raise ValueError(
"field, op, and value must be provided (or a FieldFilter instance)"
)
query = Query(self, field_filters=[(field, op, value)])
if filter is None:
filter = FieldFilter(field, op, value)
query = Query(self, field_filter=filter)
return query

def order_by(self, key: str, direction: Optional[str] = None) -> Query:
Expand Down
52 changes: 31 additions & 21 deletions mockfirestore/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from itertools import islice, tee
from typing import Iterator, Any, Optional, List, Callable, Union
from google.cloud.firestore_v1.base_query import FieldFilter
from typing import Iterator, Any, Optional, Callable, Union
from google.cloud.firestore_v1.base_query import FieldFilter, And, Or
from mockfirestore.document import DocumentSnapshot
from mockfirestore._helpers import T

Expand All @@ -21,7 +21,7 @@ def __init__(
self,
parent: "CollectionReference",
projection=None,
field_filters=(),
field_filter: Union[FieldFilter, And, Or, None] = None,
orders=(),
limit=None,
offset=None,
Expand All @@ -31,28 +31,42 @@ def __init__(
) -> None:
self.parent = parent
self.projection = projection
self._field_filters = []
self._field_filter = field_filter
self.orders = list(orders)
self._limit = limit
self._offset = offset
self._start_at = start_at
self._end_at = end_at
self.all_descendants = all_descendants

if field_filters:
for field_filter in field_filters:
self._add_field_filter(*field_filter)
def _query_filter(self, doc_snapshots, field_filter: Union[FieldFilter, And, Or, None]):
if field_filter is None:
return doc_snapshots

def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
doc_snapshots = self.parent.stream()

for field, compare, value in self._field_filters:
doc_snapshots = [
if isinstance(field_filter, FieldFilter):
compare = self._compare_func(field_filter.op_string)
return [
doc_snapshot
for doc_snapshot in doc_snapshots
if compare(doc_snapshot._get_by_field_path(field), value)
if compare(doc_snapshot._get_by_field_path(field_filter.field_path), field_filter.value)
]

if isinstance(field_filter, And):
for and_filter in field_filter.filters:
doc_snapshots = self._query_filter(doc_snapshots, and_filter)
return doc_snapshots

if isinstance(field_filter, Or):
# Collect results for each filter in the OR condition
or_results = set()
for or_filter in field_filter.filters:
or_results.update(self._query_filter(doc_snapshots, or_filter))
return [doc_snapshot for doc_snapshot in doc_snapshots if doc_snapshot in or_results]

def stream(self, transaction=None) -> Iterator[DocumentSnapshot]:
doc_snapshots = self.parent.stream()
doc_snapshots = self._query_filter(doc_snapshots, self._field_filter)

if self.orders:
for key, direction in self.orders:
doc_snapshots = sorted(
Expand Down Expand Up @@ -87,24 +101,20 @@ def get(self) -> Iterator[DocumentSnapshot]:
)
return self.stream()

def _add_field_filter(self, field: str, op: str, value: Any):
compare = self._compare_func(op)
self._field_filters.append((field, compare, value))

def where(
self,
field: Optional[str] = None,
op: Optional[str] = None,
value: Optional[Any] = None,
filter: Optional[FieldFilter] = None,
) -> "Query":
if filter is not None:
field, op, value = filter.field_path, filter.op_string, filter.value
if field is None or op is None or value is None:
if filter is None and (field is None or op is None or value is None):
raise ValueError(
"field, op, and value must be provided (or a FieldFilter instance)"
)
self._add_field_filter(field, op, value)
if filter is None:
filter = FieldFilter(field, op, value)
self._field_filter = filter
return self

def order_by(self, key: str, direction: Optional[str] = "ASCENDING") -> "Query":
Expand Down
28 changes: 28 additions & 0 deletions tests/test_collection_reference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest import TestCase

from google.cloud.firestore_v1.base_query import And, FieldFilter, Or

from mockfirestore import MockFirestore, DocumentReference, DocumentSnapshot, AlreadyExists


Expand Down Expand Up @@ -81,6 +83,32 @@ def test_collection_whereEquals(self):
docs = list(fs.collection('foo').where('valid', '==', True).stream())
self.assertEqual({'valid': True}, docs[0].to_dict())

def test_collection_whereEquals_And_Nested_Or(self):
fs = MockFirestore()
fs._data = {'foo': {
'first': {'valid': True, 'name': 'A wonderful test'},
'second': {'valid': False, 'name': 'A test'},
'third': {'valid': True, 'name': 'A different test'},
'fourth': {'valid': False, 'name': 'Another test'},
'fifth': {'valid': True, 'name': 'A wow test'},
}}
filter = And((
FieldFilter('valid', '==', True),
Or((
FieldFilter('name', '==', 'A wonderful test'),
FieldFilter('name', '==', 'A wow test'),
))
))
docs = list(fs.collection('foo').where(filter=filter).stream())
self.assertEqual(len(docs), 2)
names = set()
for doc in docs:
doc_dict = doc.to_dict()
self.assertTrue(doc_dict['valid'])
names.add(doc_dict['name'])
self.assertIn('A wonderful test', names)
self.assertIn('A wow test', names)

def test_collection_whereNotEquals(self):
fs = MockFirestore()
fs._data = {'foo': {
Expand Down
Loading