Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full support for dotted key names in distinct #250

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,22 @@ def distinct(

Args:
key: the name of the field whose value is inspected across documents.
Keys can use dot-notation to descend to deeper document levels.
Example of acceptable `key` values:
"field"
"field.subfield"
"field.3"
"field.3.subfield"
if lists are encountered and no numeric index is specified,
all items in the list are visited.
Keys can use dot-notation to descend to deeper document levels.
Example of acceptable `key` values:
"field"
"field.subfield"
"field.3"
"field.3.subfield"
if lists are encountered and no numeric index is specified,
all items in the list are visited.
filter: a predicate expressed as a dictionary according to the
Data API filter syntax. Examples are:
{}
Expand Down Expand Up @@ -1737,6 +1753,14 @@ async def distinct(

Args:
key: the name of the field whose value is inspected across documents.
Keys can use dot-notation to descend to deeper document levels.
Example of acceptable `key` values:
"field"
"field.subfield"
"field.3"
"field.3.subfield"
if lists are encountered and no numeric index is specified,
all items in the list are visited.
filter: a predicate expressed as a dictionary according to the
Data API filter syntax. Examples are:
{}
Expand Down
153 changes: 142 additions & 11 deletions astrapy/idiomatic/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@

from __future__ import annotations

import hashlib
import json
from collections.abc import Iterator, AsyncIterator
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
TYPE_CHECKING,
)

from astrapy.utils import _normalize_payload_value
from astrapy.idiomatic.types import (
DocumentType,
ProjectionType,
Expand All @@ -38,10 +44,88 @@

BC = TypeVar("BC", bound="BaseCursor")
T = TypeVar("T")
IndexPairType = Tuple[str, Optional[int]]

FIND_PREFETCH = 20


def _maybe_valid_list_index(key_block: str) -> Optional[int]:
# '0', '1' is good. '00', '01', '-30' are not.
try:
kb_index = int(key_block)
if kb_index >= 0 and key_block == str(kb_index):
return kb_index
else:
return None
except ValueError:
return None


def _create_document_key_extractor(
key: str,
) -> Callable[[Dict[str, Any]], Iterable[Any]]:

key_blocks0: List[IndexPairType] = [
(kb_str, _maybe_valid_list_index(kb_str)) for kb_str in key.split(".")
]
if key_blocks0 == []:
raise ValueError("Field path specification cannot be empty")
if any(kb[0] == "" for kb in key_blocks0):
raise ValueError("Field path components cannot be empty")

def _extract_with_key_blocks(
key_blocks: List[IndexPairType], value: Any
) -> Iterable[Any]:
if key_blocks == []:
if isinstance(value, list):
for item in value:
yield item
else:
yield value
return
else:
# go deeper as requested
rest_key_blocks = key_blocks[1:]
key_block = key_blocks[0]
k_str, k_int = key_block
if isinstance(value, dict):
if k_str in value:
new_value = value[k_str]
for item in _extract_with_key_blocks(rest_key_blocks, new_value):
yield item
return
elif isinstance(value, list):
if k_int is not None and len(value) > k_int:
new_value = value[k_int]
for item in _extract_with_key_blocks(rest_key_blocks, new_value):
yield item
else:
for item in value:
for item in _extract_with_key_blocks(key_blocks, item):
yield item
return
else:
# keyblocks are deeper than the document. Nothing to extract.
return

def _item_extractor(document: Dict[str, Any]) -> Iterable[Any]:
return _extract_with_key_blocks(key_blocks=key_blocks0, value=document)

return _item_extractor


def _reduce_distinct_key_to_safe(distinct_key: str) -> str:
"""
In light of the twofold interpretation of "0" as index and dict key
in selection (for distinct), and the auto-unroll of lists, it is not
safe to go beyond the first level. See this example:
document = {'x': [{'y': 'Y', '0': 'ZERO'}]}
key = "x.0"
With full key as projection, we would lose the `"y": "Y"` part (mistakenly).
"""
return distinct_key.split(".")[0]


class BaseCursor:
"""
Represents a generic Cursor over query results, regardless of whether
Expand Down Expand Up @@ -119,6 +203,7 @@ def _ensure_not_started(self) -> None:
def _copy(
self: BC,
*,
projection: Optional[ProjectionType] = None,
limit: Optional[int] = None,
skip: Optional[int] = None,
started: Optional[bool] = None,
Expand All @@ -127,7 +212,7 @@ def _copy(
new_cursor = self.__class__(
collection=self._collection,
filter=self._filter,
projection=self._projection,
projection=projection or self._projection,
)
# Cursor treated as mutable within this function scope:
new_cursor._limit = limit if limit is not None else self._limit
Expand Down Expand Up @@ -363,6 +448,17 @@ def distinct(self, key: str) -> List[Any]:
Invoking this method has no effect on the cursor state, i.e.
the position of the cursor is unchanged.

Args:
key: the name of the field whose value is inspected across documents.
Keys can use dot-notation to descend to deeper document levels.
Example of acceptable `key` values:
"field"
"field.subfield"
"field.3"
"field.3.subfield"
if lists are encountered and no numeric index is specified,
all items in the list are visited.

Note:
this operation works at client-side by scrolling through all
documents matching the cursor parameters (such as `filter`).
Expand All @@ -371,9 +467,23 @@ def distinct(self, key: str) -> List[Any]:
network traffic and possibly billing.
"""

return list(
{document[key] for document in self._copy(started=False) if key in document}
)
_item_hashes = set()
distinct_items = []

_extractor = _create_document_key_extractor(key)
_key = _reduce_distinct_key_to_safe(key)

d_cursor = self._copy(projection={_key: True}, started=False)
for document in d_cursor:
for item in _extractor(document):
_normalized_item = _normalize_payload_value(path=[], value=item)
_normalized_json = json.dumps(_normalized_item, separators=(",", ":"))
_item_hash = hashlib.md5(_normalized_json.encode()).hexdigest()
if _item_hash not in _item_hashes:
_item_hashes.add(_item_hash)
distinct_items.append(item)

return distinct_items


class AsyncCursor(BaseCursor):
Expand Down Expand Up @@ -507,6 +617,17 @@ async def distinct(self, key: str) -> List[Any]:
Invoking this method has no effect on the cursor state, i.e.
the position of the cursor is unchanged.

Args:
key: the name of the field whose value is inspected across documents.
Keys can use dot-notation to descend to deeper document levels.
Example of acceptable `key` values:
"field"
"field.subfield"
"field.3"
"field.3.subfield"
if lists are encountered and no numeric index is specified,
all items in the list are visited.

Note:
this operation works at client-side by scrolling through all
documents matching the cursor parameters (such as `filter`).
Expand All @@ -515,13 +636,23 @@ async def distinct(self, key: str) -> List[Any]:
network traffic and possibly billing.
"""

return list(
{
document[key]
async for document in self._copy(started=False)
if key in document
}
)
_item_hashes = set()
distinct_items = []

_extractor = _create_document_key_extractor(key)
_key = _reduce_distinct_key_to_safe(key)

d_cursor = self._copy(projection={_key: True}, started=False)
async for document in d_cursor:
for item in _extractor(document):
_normalized_item = _normalize_payload_value(path=[], value=item)
_normalized_json = json.dumps(_normalized_item, separators=(",", ":"))
_item_hash = hashlib.md5(_normalized_json.encode()).hexdigest()
if _item_hash not in _item_hashes:
_item_hashes.add(_item_hash)
distinct_items.append(item)

return distinct_items


class CommandCursor(Generic[T]):
Expand Down
72 changes: 71 additions & 1 deletion tests/idiomatic/integration/test_dml_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import datetime

from typing import Any, Dict, List

import pytest

Expand Down Expand Up @@ -438,6 +440,74 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]:
cursor7.rewind()
cursor7["wrong"]

@pytest.mark.describe("test of distinct with non-hashable items, async")
async def test_collection_distinct_nonhashable_async(
self,
async_empty_collection: AsyncCollection,
) -> None:
acol = async_empty_collection
documents: List[Dict[str, Any]] = [
{},
{"f": 1},
{"f": "a"},
{"f": {"subf": 99}},
{"f": {"subf": 99, "another": {"subsubf": [True, False]}}},
{"f": [10, 11]},
{"f": [11, 10]},
{"f": [10]},
{"f": datetime.datetime(2000, 1, 1, 12, 00, 00)},
{"f": None},
]
await acol.insert_many(documents * 2)

d_items = await acol.distinct("f")
expected = [
1,
"a",
{"subf": 99},
{"subf": 99, "another": {"subsubf": [True, False]}},
10,
11,
datetime.datetime(2000, 1, 1, 12, 0),
None,
]
assert len(d_items) == len(expected)
for doc in documents:
if "f" in doc:
if isinstance(doc["f"], list):
for item in doc["f"]:
assert item in d_items
else:
assert doc["f"] in d_items

@pytest.mark.describe("test of usage of projection in distinct, async")
async def test_collection_projections_distinct_async(
self,
async_empty_collection: AsyncCollection,
) -> None:
acol = async_empty_collection
await acol.insert_one({"x": [{"y": "Y", "0": "ZERO"}]})

assert await acol.distinct("x.y") == ["Y"]
# the one below shows that if index-in-list, then browse-whole-list is off
assert await acol.distinct("x.0") == [{"y": "Y", "0": "ZERO"}]
assert await acol.distinct("x.0.y") == ["Y"]
assert await acol.distinct("x.0.0") == ["ZERO"]

@pytest.mark.describe("test of unacceptable paths for distinct, async")
async def test_collection_wrong_paths_distinc_async(
self,
async_empty_collection: AsyncCollection,
) -> None:
with pytest.raises(ValueError):
await async_empty_collection.distinct("root.1..subf")
with pytest.raises(ValueError):
await async_empty_collection.distinct("root..1.subf")
with pytest.raises(ValueError):
await async_empty_collection.distinct("root..subf.subsubf")
with pytest.raises(ValueError):
await async_empty_collection.distinct("root.subf..subsubf")

@pytest.mark.describe("test of collection insert_many, async")
async def test_collection_insert_many_async(
self,
Expand Down
Loading
Loading