Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 98 additions & 18 deletions key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from datetime import datetime
from typing import Any, overload

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.errors import DeserializationError
from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string
from key_value.shared.utils.time_to_live import timezone
from typing_extensions import Self, override

from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore
def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry:
    """Convert a MongoDB document back to a ManagedEntry.

    This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a
    ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy
    JSON string storage (string in value.string field) for migration support.

    Args:
        document: The MongoDB document to convert.

    Returns:
        A ManagedEntry object reconstructed from the document.

    Raises:
        DeserializationError: If the `value` field is missing or empty, is not an object, carries
            neither an `object` nor a `string` subfield, or if `created_at` / `expires_at` is
            present but is not a datetime.
    """
    if not (value_field := document.get("value")):
        msg = "Value field not found"
        raise DeserializationError(msg)

    if not isinstance(value_field, dict):
        msg = "Expected `value` field to be an object"
        raise DeserializationError(msg)

    # The Value field is an object with two possible fields: `object` and `string`
    # - `object`: The value is a native BSON dict
    # - `string`: The value is a JSON string
    value_holder: dict[str, Any] = verify_dict(obj=value_field)

    data: dict[str, Any] = {}

    # Mongo stores datetimes without timezones as UTC so we mark them as UTC
    for field_name in ("created_at", "expires_at"):
        if field_value := document.get(field_name):
            if not isinstance(field_value, datetime):
                msg = f"Expected `{field_name}` field to be a datetime"
                raise DeserializationError(msg)
            data[field_name] = field_value.replace(tzinfo=timezone.utc)

    # Test for presence rather than truthiness: an empty dict is a legitimate stored value and
    # must not be mistaken for a missing `object` subfield.
    # NOTE(review): confirm ManagedEntry.from_dict itself accepts an empty mapping as the value.
    if (value_object := value_holder.get("object")) is not None:
        return ManagedEntry.from_dict(data={"value": value_object, **data})

    if value_string := value_holder.get("string"):
        return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True)

    msg = "Expected `value` field to be an object with `object` or `string` subfield"
    raise DeserializationError(msg)


def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]:
    """Convert a ManagedEntry to a MongoDB document for storage.

    This function serializes a ManagedEntry to a MongoDB document format, including the key and the
    creation and expiration timestamps when they are set. The value storage format depends on the
    native_storage parameter.

    Args:
        key: The key associated with this entry.
        managed_entry: The ManagedEntry to serialize.
        native_storage: If True (default), store value as native BSON dict in value.object field.
            If False, store as JSON string in value.string field for backward compatibility.

    Returns:
        A MongoDB document dict containing the key, the value holder (`value.object` or
        `value.string`), and `created_at` / `expires_at` when present on the entry.
    """
    # We convert to JSON even if we don't need to, this ensures that the value we were provided
    # can be serialized to JSON which helps ensure compatibility across stores. For example,
    # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON,
    # then using py-key-value with Mongo will return different values than if we used another store.
    json_str = managed_entry.value_as_json

    # Store in appropriate field based on mode
    value_holder: dict[str, Any] = {"object": managed_entry.value_as_dict} if native_storage else {"string": json_str}

    document: dict[str, Any] = {"key": key, "value": value_holder}

    # Add metadata fields only when they are set, so absent timestamps are not written as nulls
    if managed_entry.created_at is not None:
        document["created_at"] = managed_entry.created_at
    if managed_entry.expires_at is not None:
        document["expires_at"] = managed_entry.expires_at

    return document


class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore):
"""MongoDB-based key-value store using Motor (async MongoDB driver)."""
"""MongoDB-based key-value store using pymongo."""

_client: AsyncMongoClient[dict[str, Any]]
_db: AsyncDatabase[dict[str, Any]]
_collections_by_name: dict[str, AsyncCollection[dict[str, Any]]]
_native_storage: bool

@overload
def __init__(
Expand All @@ -83,6 +141,7 @@ def __init__(
client: AsyncMongoClient[dict[str, Any]],
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
) -> None:
"""Initialize the MongoDB store.
Expand All @@ -91,19 +150,27 @@ def __init__(
client: The MongoDB client to use.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
"""

@overload
def __init__(
self, *, url: str, db_name: str | None = None, coll_name: str | None = None, default_collection: str | None = None
self,
*,
url: str,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
) -> None:
"""Initialize the MongoDB store.

Args:
url: The url of the MongoDB cluster.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
"""

Expand All @@ -114,9 +181,21 @@ def __init__(
url: str | None = None,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
) -> None:
"""Initialize the MongoDB store."""
"""Initialize the MongoDB store.

Args:
client: The MongoDB client to use (mutually exclusive with url).
url: The url of the MongoDB cluster (mutually exclusive with client).
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
Native storage stores values as BSON dicts for better query support.
Legacy mode stores values as JSON strings for backward compatibility.
default_collection: The default collection to use if no collection is provided.
"""

if client:
self._client = client
Expand All @@ -131,6 +210,7 @@ def __init__(

self._db = self._client[db_name]
self._collections_by_name = {}
self._native_storage = native_storage

super().__init__(default_collection=default_collection)

Expand Down Expand Up @@ -158,7 +238,7 @@ def _sanitize_collection_name(self, collection: str) -> str:
Returns:
A sanitized collection name that meets MongoDB requirements.
"""
return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=ALPHANUMERIC_CHARACTERS)
return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=COLLECTION_ALLOWED_CHARACTERS)

@override
async def _setup_collection(self, *, collection: str) -> None:
Expand Down Expand Up @@ -187,7 +267,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
sanitized_collection = self._sanitize_collection_name(collection=collection)

if doc := await self._collections_by_name[sanitized_collection].find_one(filter={"key": key}):
return ManagedEntry.from_dict(data=doc, stringified_value=True)
return document_to_managed_entry(document=doc)

return None

Expand Down Expand Up @@ -217,7 +297,7 @@ async def _put_managed_entry(
collection: str,
managed_entry: ManagedEntry,
) -> None:
mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry)
mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage)

sanitized_collection = self._sanitize_collection_name(collection=collection)

Expand Down Expand Up @@ -248,7 +328,7 @@ async def _put_managed_entries(

operations: list[UpdateOne] = []
for key, managed_entry in zip(keys, managed_entries, strict=True):
mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry)
mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage)

operations.append(
UpdateOne(
Expand Down
1 change: 1 addition & 0 deletions key-value/key-value-aio/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def docker_container(
finally:
docker_stop(name, raise_on_error=False)
docker_rm(name, raise_on_error=False)
docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0)

logger.info(f"Container {name} stopped and removed")
return
Expand Down
Loading
Loading