diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 3cfde94a..6db89992 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -2,8 +2,10 @@ from datetime import datetime from typing import Any, overload -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.errors import DeserializationError +from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string +from key_value.shared.utils.time_to_live import timezone from typing_extensions import Self, override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore @@ -37,7 +39,8 @@ def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: """Convert a MongoDB document back to a ManagedEntry. This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object, parsing the stringified value field and preserving all metadata. + ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy + JSON string storage (string in value.string field) for migration support. Args: document: The MongoDB document to convert. @@ -45,36 +48,91 @@ def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: Returns: A ManagedEntry object reconstructed from the document. """ - return ManagedEntry.from_dict(data=document, stringified_value=True) + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) -def managed_entry_to_document(key: str, managed_entry: ManagedEntry) -> dict[str, Any]: + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + data: dict[str, Any] = {} + + # The Value field is an object with two possible fields: `object` and `string` + # - `object`: The value is a native BSON dict + # - `string`: The value is a JSON string + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + +def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: """Convert a ManagedEntry to a MongoDB document for storage. This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value is stringified to ensure proper - storage in MongoDB. The serialization is designed to preserve all entry information for round-trip - conversion back to a ManagedEntry. + metadata (TTL, creation, and expiration timestamps). The value storage format depends on the + native_storage parameter. Args: key: The key associated with this entry. managed_entry: The ManagedEntry to serialize. + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. Returns: A MongoDB document dict containing the key, value, and all metadata. """ - return { - "key": key, - **managed_entry.to_dict(include_metadata=True, include_expiration=True, include_creation=True, stringify_value=True), - } + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores. For example, + # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, + # then using py-key-value with Mongo will return different values than if we used another store. + json_str = managed_entry.value_as_json + + # Store in appropriate field based on mode + if native_storage: + document["value"]["object"] = managed_entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields + if managed_entry.created_at: + document["created_at"] = managed_entry.created_at + if managed_entry.expires_at: + document["expires_at"] = managed_entry.expires_at + + return document class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): - """MongoDB-based key-value store using Motor (async MongoDB driver).""" + """MongoDB-based key-value store using pymongo.""" _client: AsyncMongoClient[dict[str, Any]] _db: AsyncDatabase[dict[str, Any]] _collections_by_name: dict[str, AsyncCollection[dict[str, Any]]] + _native_storage: bool @overload def __init__( @@ -83,6 +141,7 @@ def __init__( client: AsyncMongoClient[dict[str, Any]], db_name: str | None = None, coll_name: str | None = None, + native_storage: bool = True, default_collection: str | None = None, ) -> None: """Initialize the MongoDB store. @@ -91,12 +150,19 @@ def __init__( client: The MongoDB client to use. db_name: The name of the MongoDB database. coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). default_collection: The default collection to use if no collection is provided. """ @overload def __init__( - self, *, url: str, db_name: str | None = None, coll_name: str | None = None, default_collection: str | None = None + self, + *, + url: str, + db_name: str | None = None, + coll_name: str | None = None, + native_storage: bool = True, + default_collection: str | None = None, ) -> None: """Initialize the MongoDB store. @@ -104,6 +170,7 @@ def __init__( url: The url of the MongoDB cluster. db_name: The name of the MongoDB database. coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). default_collection: The default collection to use if no collection is provided. """ @@ -114,9 +181,21 @@ def __init__( url: str | None = None, db_name: str | None = None, coll_name: str | None = None, + native_storage: bool = True, default_collection: str | None = None, ) -> None: - """Initialize the MongoDB store.""" + """Initialize the MongoDB store. + + Args: + client: The MongoDB client to use (mutually exclusive with url). + url: The url of the MongoDB cluster (mutually exclusive with client). + db_name: The name of the MongoDB database. + coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). + Native storage stores values as BSON dicts for better query support. + Legacy mode stores values as JSON strings for backward compatibility. + default_collection: The default collection to use if no collection is provided. + """ if client: self._client = client @@ -131,6 +210,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} + self._native_storage = native_storage super().__init__(default_collection=default_collection) @@ -158,7 +238,7 @@ def _sanitize_collection_name(self, collection: str) -> str: Returns: A sanitized collection name that meets MongoDB requirements. """ - return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=ALPHANUMERIC_CHARACTERS) + return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=COLLECTION_ALLOWED_CHARACTERS) @override async def _setup_collection(self, *, collection: str) -> None: @@ -187,7 +267,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry sanitized_collection = self._sanitize_collection_name(collection=collection) if doc := await self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return ManagedEntry.from_dict(data=doc, stringified_value=True) + return document_to_managed_entry(document=doc) return None @@ -217,7 +297,7 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry) + mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) sanitized_collection = self._sanitize_collection_name(collection=collection) @@ -248,7 +328,7 @@ async def _put_managed_entries( operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry) + mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) operations.append( UpdateOne( diff --git a/key-value/key-value-aio/tests/conftest.py b/key-value/key-value-aio/tests/conftest.py index 5be4d78c..bf9abb00 100644 --- a/key-value/key-value-aio/tests/conftest.py +++ b/key-value/key-value-aio/tests/conftest.py @@ -168,6 +168,7 @@ def docker_container( finally: docker_stop(name, raise_on_error=False) docker_rm(name, raise_on_error=False) + docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0) logger.info(f"Container {name} stopped and removed") return diff --git a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py index d47503a1..a638a152 100644 --- a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py @@ -4,7 +4,8 @@ from typing import Any import pytest -from dirty_equals import IsFloat +from bson import ObjectId +from dirty_equals import IsDatetime, IsFloat, IsInstance from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true from key_value.shared.utils.managed_entry import ManagedEntry @@ -44,19 +45,19 @@ class MongoDBFailedToStartError(Exception): pass -def test_managed_entry_document_conversion(): +def test_managed_entry_document_conversion_native_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) assert document == snapshot( { "key": "test", - "value": '{"test": "test"}', - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", + "value": {"object": {"test": "test"}}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), } ) @@ -68,8 +69,38 @@ def test_managed_entry_document_conversion(): assert round_trip_managed_entry.expires_at == expires_at -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") -class TestMongoDBStore(ContextManagerStoreTestMixin, BaseStoreTests): +def test_managed_entry_document_conversion_legacy_mode(): + created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) + expires_at = created_at + timedelta(seconds=10) + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + + assert document == snapshot( + { + "key": "test", + "value": {"string": '{"test": "test"}'}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), + } + ) + + round_trip_managed_entry = document_to_managed_entry(document=document) + + assert round_trip_managed_entry.value == managed_entry.value + assert round_trip_managed_entry.created_at == created_at + assert round_trip_managed_entry.ttl == IsFloat(lt=0) + assert round_trip_managed_entry.expires_at == expires_at + + +async def clean_mongodb_database(store: MongoDBStore) -> None: + with contextlib.suppress(Exception): + _ = await store._client.drop_database(name_or_database=store._db.name) # pyright: ignore[reportPrivateUsage] + + +class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): + """Base class for MongoDB store tests.""" + @pytest.fixture(autouse=True, scope="session", params=MONGODB_VERSIONS_TO_TEST) async def setup_mongodb(self, request: pytest.FixtureRequest) -> AsyncGenerator[None, None]: version = request.param @@ -81,28 +112,112 @@ async def setup_mongodb(self, request: pytest.FixtureRequest) -> AsyncGenerator[ yield + @pytest.mark.skip(reason="Distributed Caches are unbounded") + @override + async def test_not_unbounded(self, store: BaseStore): ... + + async def test_mongodb_collection_name_sanitization(self, store: MongoDBStore): + """Tests that a special characters in the collection name will not raise an error.""" + await store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) + assert await store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + + collections = await store.collections() + assert collections == snapshot(["test_collection_-daf4a2ec"]) + + +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +class TestMongoDBStoreNativeMode(BaseMongoDBStoreTests): + """Test MongoDBStore with native_storage=True (default).""" + @override @pytest.fixture async def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB) - # Ensure a clean db by dropping our default test collection if it exists - with contextlib.suppress(Exception): - _ = await store._client.drop_database(name_or_database=MONGODB_TEST_DB) # pyright: ignore[reportPrivateUsage] + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=f"{MONGODB_TEST_DB}-native", native_storage=True) - return store + await clean_mongodb_database(store=store) - @pytest.fixture - async def mongodb_store(self, store: MongoDBStore) -> MongoDBStore: return store - @pytest.mark.skip(reason="Distributed Caches are unbounded") + async def test_value_stored_as_bson_dict(self, store: MongoDBStore): + """Verify values are stored as BSON dicts, not JSON strings.""" + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + # Get the raw MongoDB document + await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + doc = await collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"object": {"name": "Alice", "age": 30}}, + } + ) + + async def test_migration_from_legacy_mode(self, store: MongoDBStore): + """Verify native mode can read legacy JSON string data.""" + await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + + await collection.insert_one( + { + "key": "legacy_key", + "value": {"string": '{"legacy": "data"}'}, + } + ) + + result = await store.get(collection="test", key="legacy_key") + assert result == {"legacy": "data"} + + +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +class TestMongoDBStoreNonNativeMode(BaseMongoDBStoreTests): + """Test MongoDBStore with native_storage=False (legacy mode) for backward compatibility.""" + @override - async def test_not_unbounded(self, store: BaseStore): ... + @pytest.fixture + async def store(self, setup_mongodb: None) -> MongoDBStore: + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB, native_storage=False) - async def test_mongodb_collection_name_sanitization(self, mongodb_store: MongoDBStore): - """Tests that a special characters in the collection name will not raise an error.""" - await mongodb_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert await mongodb_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + await clean_mongodb_database(store=store) - collections = await mongodb_store.collections() - assert collections == snapshot(["test_collection_-daf4a2ec"]) + return store + + async def test_value_stored_as_json(self, store: MongoDBStore): + """Verify values are stored as JSON strings.""" + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + # Get the raw MongoDB document + await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + doc = await collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"string": '{"age": 30, "name": "Alice"}'}, + } + ) + + async def test_migration_from_native_mode(self, store: MongoDBStore): + """Verify non-native mode can read native mode data.""" + await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + + await collection.insert_one( + { + "key": "legacy_key", + "value": {"object": {"name": "Alice", "age": 30}}, + } + ) + + result = await store.get(collection="test", key="legacy_key") + assert result == {"name": "Alice", "age": 30} diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 30be70e5..60cafb60 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -79,7 +79,7 @@ def to_json( ) @classmethod - def from_dict( + def from_dict( # noqa: PLR0912 cls, data: dict[str, Any], includes_metadata: bool = True, ttl: SupportsFloat | None = None, stringified_value: bool = False ) -> Self: if not includes_metadata: @@ -87,8 +87,26 @@ def from_dict( value=data, ) - created_at: datetime | None = try_parse_datetime_str(value=data.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=data.get("expires_at")) + created_at: datetime | None = None + expires_at: datetime | None = None + + if created_at_value := data.get("created_at"): + if isinstance(created_at_value, str): + created_at = try_parse_datetime_str(value=created_at_value) + elif isinstance(created_at_value, datetime): + created_at = created_at_value + else: + msg = "Expected `created_at` field to be a string or datetime" + raise DeserializationError(msg) + + if expires_at_value := data.get("expires_at"): + if isinstance(expires_at_value, str): + expires_at = try_parse_datetime_str(value=expires_at_value) + elif isinstance(expires_at_value, datetime): + expires_at = expires_at_value + else: + msg = "Expected `expires_at` field to be a string or datetime" + raise DeserializationError(msg) if not (raw_value := data.get("value")): msg = "Value is None" diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index fafe9bf2..753ffb46 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -5,8 +5,10 @@ from datetime import datetime from typing import Any, overload -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.errors import DeserializationError +from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string +from key_value.shared.utils.time_to_live import timezone from typing_extensions import Self, override from key_value.sync.code_gen.stores.base import ( @@ -44,7 +46,8 @@ def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: """Convert a MongoDB document back to a ManagedEntry. This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object, parsing the stringified value field and preserving all metadata. + ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy + JSON string storage (string in value.string field) for migration support. Args: document: The MongoDB document to convert. @@ -52,36 +55,91 @@ def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: Returns: A ManagedEntry object reconstructed from the document. """ - return ManagedEntry.from_dict(data=document, stringified_value=True) + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) -def managed_entry_to_document(key: str, managed_entry: ManagedEntry) -> dict[str, Any]: + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + data: dict[str, Any] = {} + + # The Value field is an object with two possible fields: `object` and `string` + # - `object`: The value is a native BSON dict + # - `string`: The value is a JSON string + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + +def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: """Convert a ManagedEntry to a MongoDB document for storage. This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value is stringified to ensure proper - storage in MongoDB. The serialization is designed to preserve all entry information for round-trip - conversion back to a ManagedEntry. + metadata (TTL, creation, and expiration timestamps). The value storage format depends on the + native_storage parameter. Args: key: The key associated with this entry. managed_entry: The ManagedEntry to serialize. + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. Returns: A MongoDB document dict containing the key, value, and all metadata. """ - return { - "key": key, - **managed_entry.to_dict(include_metadata=True, include_expiration=True, include_creation=True, stringify_value=True), - } + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores. For example, + # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, + # then using py-key-value with Mongo will return different values than if we used another store. + json_str = managed_entry.value_as_json + + # Store in appropriate field based on mode + if native_storage: + document["value"]["object"] = managed_entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields + if managed_entry.created_at: + document["created_at"] = managed_entry.created_at + if managed_entry.expires_at: + document["expires_at"] = managed_entry.expires_at + + return document class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): - """MongoDB-based key-value store using Motor (sync MongoDB driver).""" + """MongoDB-based key-value store using pymongo.""" _client: MongoClient[dict[str, Any]] _db: Database[dict[str, Any]] _collections_by_name: dict[str, Collection[dict[str, Any]]] + _native_storage: bool @overload def __init__( @@ -90,6 +148,7 @@ def __init__( client: MongoClient[dict[str, Any]], db_name: str | None = None, coll_name: str | None = None, + native_storage: bool = True, default_collection: str | None = None, ) -> None: """Initialize the MongoDB store. @@ -98,12 +157,19 @@ def __init__( client: The MongoDB client to use. db_name: The name of the MongoDB database. coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). default_collection: The default collection to use if no collection is provided. """ @overload def __init__( - self, *, url: str, db_name: str | None = None, coll_name: str | None = None, default_collection: str | None = None + self, + *, + url: str, + db_name: str | None = None, + coll_name: str | None = None, + native_storage: bool = True, + default_collection: str | None = None, ) -> None: """Initialize the MongoDB store. @@ -111,6 +177,7 @@ def __init__( url: The url of the MongoDB cluster. db_name: The name of the MongoDB database. coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). default_collection: The default collection to use if no collection is provided. """ @@ -121,9 +188,21 @@ def __init__( url: str | None = None, db_name: str | None = None, coll_name: str | None = None, + native_storage: bool = True, default_collection: str | None = None, ) -> None: - """Initialize the MongoDB store.""" + """Initialize the MongoDB store. + + Args: + client: The MongoDB client to use (mutually exclusive with url). + url: The url of the MongoDB cluster (mutually exclusive with client). + db_name: The name of the MongoDB database. + coll_name: The name of the MongoDB collection. + native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False). + Native storage stores values as BSON dicts for better query support. + Legacy mode stores values as JSON strings for backward compatibility. + default_collection: The default collection to use if no collection is provided. + """ if client: self._client = client @@ -138,6 +217,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} + self._native_storage = native_storage super().__init__(default_collection=default_collection) @@ -165,7 +245,7 @@ def _sanitize_collection_name(self, collection: str) -> str: Returns: A sanitized collection name that meets MongoDB requirements. """ - return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=ALPHANUMERIC_CHARACTERS) + return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=COLLECTION_ALLOWED_CHARACTERS) @override def _setup_collection(self, *, collection: str) -> None: @@ -194,7 +274,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non sanitized_collection = self._sanitize_collection_name(collection=collection) if doc := self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return ManagedEntry.from_dict(data=doc, stringified_value=True) + return document_to_managed_entry(document=doc) return None @@ -218,7 +298,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry) + mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) sanitized_collection = self._sanitize_collection_name(collection=collection) @@ -245,7 +325,7 @@ def _put_managed_entries( operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry) + mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) operations.append(UpdateOne(filter={"key": key}, update={"$set": mongo_doc}, upsert=True)) diff --git a/key-value/key-value-sync/tests/code_gen/conftest.py b/key-value/key-value-sync/tests/code_gen/conftest.py index 9df1f24e..3d7b2051 100644 --- a/key-value/key-value-sync/tests/code_gen/conftest.py +++ b/key-value/key-value-sync/tests/code_gen/conftest.py @@ -170,6 +170,7 @@ def docker_container( finally: docker_stop(name, raise_on_error=False) docker_rm(name, raise_on_error=False) + docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0) logger.info(f"Container {name} stopped and removed") return diff --git a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py index 086634fb..a9a3a9a3 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py @@ -7,7 +7,8 @@ from typing import Any import pytest -from dirty_equals import IsFloat +from bson import ObjectId +from dirty_equals import IsDatetime, IsFloat, IsInstance from inline_snapshot import snapshot from key_value.shared.stores.wait import wait_for_true from key_value.shared.utils.managed_entry import ManagedEntry @@ -45,15 +46,20 @@ class MongoDBFailedToStartError(Exception): pass -def test_managed_entry_document_conversion(): +def test_managed_entry_document_conversion_native_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) assert document == snapshot( - {"key": "test", "value": '{"test": "test"}', "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00"} + { + "key": "test", + "value": {"object": {"test": "test"}}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), + } ) round_trip_managed_entry = document_to_managed_entry(document=document) @@ -64,8 +70,38 @@ def test_managed_entry_document_conversion(): assert round_trip_managed_entry.expires_at == expires_at -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") -class TestMongoDBStore(ContextManagerStoreTestMixin, BaseStoreTests): +def test_managed_entry_document_conversion_legacy_mode(): + created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) + expires_at = created_at + timedelta(seconds=10) + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + + assert document == snapshot( + { + "key": "test", + "value": {"string": '{"test": "test"}'}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), + } + ) + + round_trip_managed_entry = document_to_managed_entry(document=document) + + assert round_trip_managed_entry.value == managed_entry.value + assert round_trip_managed_entry.created_at == created_at + assert round_trip_managed_entry.ttl == IsFloat(lt=0) + assert round_trip_managed_entry.expires_at == expires_at + + +def clean_mongodb_database(store: MongoDBStore) -> None: + with contextlib.suppress(Exception): + _ = store._client.drop_database(name_or_database=store._db.name) # pyright: ignore[reportPrivateUsage] + + +class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): + """Base class for MongoDB store tests.""" + @pytest.fixture(autouse=True, scope="session", params=MONGODB_VERSIONS_TO_TEST) def setup_mongodb(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param @@ -77,28 +113,102 @@ def setup_mongodb(self, request: pytest.FixtureRequest) -> Generator[None, None, yield + @pytest.mark.skip(reason="Distributed Caches are unbounded") + @override + def test_not_unbounded(self, store: BaseStore): ... + + def test_mongodb_collection_name_sanitization(self, store: MongoDBStore): + """Tests that a special characters in the collection name will not raise an error.""" + store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) + assert store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + + collections = store.collections() + assert collections == snapshot(["test_collection_-daf4a2ec"]) + + +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +class TestMongoDBStoreNativeMode(BaseMongoDBStoreTests): + """Test MongoDBStore with native_storage=True (default).""" + @override @pytest.fixture def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB) - # Ensure a clean db by dropping our default test collection if it exists - with contextlib.suppress(Exception): - _ = store._client.drop_database(name_or_database=MONGODB_TEST_DB) # pyright: ignore[reportPrivateUsage] + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=f"{MONGODB_TEST_DB}-native", native_storage=True) - return store + clean_mongodb_database(store=store) - @pytest.fixture - def mongodb_store(self, store: MongoDBStore) -> MongoDBStore: return store - @pytest.mark.skip(reason="Distributed Caches are unbounded") + def test_value_stored_as_bson_dict(self, store: MongoDBStore): + """Verify values are stored as BSON dicts, not JSON strings.""" + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + # Get the raw MongoDB document + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + doc = collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"object": {"name": "Alice", "age": 30}}, + } + ) + + def test_migration_from_legacy_mode(self, store: MongoDBStore): + """Verify native mode can read legacy JSON string data.""" + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + + collection.insert_one({"key": "legacy_key", "value": {"string": '{"legacy": "data"}'}}) + + result = store.get(collection="test", key="legacy_key") + assert result == {"legacy": "data"} + + +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +class TestMongoDBStoreNonNativeMode(BaseMongoDBStoreTests): + """Test MongoDBStore with native_storage=False (legacy mode) for backward compatibility.""" + @override - def test_not_unbounded(self, store: BaseStore): ... + @pytest.fixture + def store(self, setup_mongodb: None) -> MongoDBStore: + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB, native_storage=False) - def test_mongodb_collection_name_sanitization(self, mongodb_store: MongoDBStore): - """Tests that a special characters in the collection name will not raise an error.""" - mongodb_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert mongodb_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + clean_mongodb_database(store=store) - collections = mongodb_store.collections() - assert collections == snapshot(["test_collection_-daf4a2ec"]) + return store + + def test_value_stored_as_json(self, store: MongoDBStore): + """Verify values are stored as JSON strings.""" + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + # Get the raw MongoDB document + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + doc = collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"string": '{"age": 30, "name": "Alice"}'}, + } + ) + + def test_migration_from_native_mode(self, store: MongoDBStore): + """Verify non-native mode can read native mode data.""" + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] + + collection.insert_one({"key": "legacy_key", "value": {"object": {"name": "Alice", "age": 30}}}) + + result = store.get(collection="test", key="legacy_key") + assert result == {"name": "Alice", "age": 30}