From 6211f03d984de5561330b3de5d68c767e85de677 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sun, 9 Nov 2025 19:34:23 +0000 Subject: [PATCH 1/2] feat: add PostgreSQL store as async-only implementation - Add PostgreSQL store using asyncpg for native async operations - Configure as async-only (no sync codegen, following DynamoDB/Memcached pattern) - Add comprehensive documentation in docs/stores.md - Include DuckDB in documentation table as well - Add postgresql extra dependency to pyproject.toml - Exclude PostgreSQL from sync codegen in build_sync_library.py The PostgreSQL store provides: - JSONB storage for flexible key-value data - TTL support via expiration timestamps - Single table design with collections as column values - Async-only implementation using asyncpg Co-authored-by: William Easton --- docs/stores.md | 51 ++ key-value/key-value-aio/pyproject.toml | 3 +- .../aio/stores/postgresql/__init__.py | 9 + .../key_value/aio/stores/postgresql/store.py | 551 ++++++++++++++++++ .../tests/stores/postgresql/__init__.py | 1 + .../stores/postgresql/test_postgresql.py | 168 ++++++ scripts/build_sync_library.py | 2 + uv.lock | 55 +- 8 files changed, 835 insertions(+), 5 deletions(-) create mode 100644 key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py create mode 100644 key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py create mode 100644 key-value/key-value-aio/tests/stores/postgresql/__init__.py create mode 100644 key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py diff --git a/docs/stores.md b/docs/stores.md index ee6250a7..7bd3fafc 100644 --- a/docs/stores.md +++ b/docs/stores.md @@ -34,6 +34,7 @@ Local stores are stored in memory or on disk, local to the application. 
| Memory | N/A | ✅ | ✅ | Fast in-memory storage for development and caching | | Disk | Stable | ☑️ | ✅ | Persistent file-based storage in a single file | | Disk (Per-Collection) | Stable | ☑️ | ✅ | Persistent storage with separate files per collection | +| DuckDB | Unstable | ☑️ | ✅ | In-process SQL OLAP database with native JSON storage | | FileTree (test) | Unstable | ☑️ | ✅ | Directory-based storage with JSON files for visual inspection | | Null (test) | N/A | ✅ | ✅ | No-op store for testing without side effects | | RocksDB | Unstable | ☑️ | ✅ | High-performance embedded database | @@ -400,6 +401,7 @@ Distributed stores provide network-based storage for multi-node applications. | Elasticsearch | Unstable | ✅ | ✅ | Full-text search with key-value capabilities | | Memcached | Unstable | ✅ | ✖️ | High-performance distributed memory cache | | MongoDB | Unstable | ✅ | ✅ | Document database used as key-value store | +| PostgreSQL | Unstable | ✅ | ✖️ | PostgreSQL database with JSONB storage | | Redis | Stable | ✅ | ✅ | Popular in-memory data structure store | | Valkey | Stable | ✅ | ✅ | Open-source Redis fork | @@ -569,6 +571,55 @@ pip install py-key-value-aio[mongodb] --- +### PostgreSQLStore + +PostgreSQL database with JSONB storage for flexible key-value data. + +**Note:** PostgreSQL is async-only. This store uses `asyncpg` which provides native async/await operations. 
+ +```python +from key_value.aio.stores.postgresql import PostgreSQLStore + +# Using connection URL +store = PostgreSQLStore(url="postgresql://localhost:5432/mydb") + +# Using connection parameters +store = PostgreSQLStore( + host="localhost", + port=5432, + database="mydb", + user="myuser", + password="mypass" +) + +async with store: + await store.put(key="user_1", value={"name": "Alice"}, collection="users") + user = await store.get(key="user_1", collection="users") +``` + +**Installation:** + +```bash +pip install py-key-value-aio[postgresql] +``` + +**Use Cases:** + +- Applications already using PostgreSQL +- Need for SQL querying on stored data +- ACID transaction requirements +- Complex data relationships + +**Characteristics:** + +- JSONB storage for efficient querying +- TTL support via expiration timestamps +- Single table design (collections as column values) +- Async-only (uses asyncpg) +- Stable storage format: **Unstable** + +--- + ### MemcachedStore High-performance distributed memory caching system. 
diff --git a/key-value/key-value-aio/pyproject.toml b/key-value/key-value-aio/pyproject.toml index 8f41827c..ee70d1a9 100644 --- a/key-value/key-value-aio/pyproject.toml +++ b/key-value/key-value-aio/pyproject.toml @@ -50,6 +50,7 @@ rocksdb = [ "rocksdict>=0.3.2 ; python_version < '3.12'" ] duckdb = ["duckdb>=1.1.1", "pytz>=2025.2"] +postgresql = ["asyncpg>=0.30.0"] wrappers-encryption = ["cryptography>=45.0.0"] [tool.pytest.ini_options] @@ -69,7 +70,7 @@ env_files = [".env"] [dependency-groups] dev = [ - "py-key-value-aio[memory,disk,filetree,redis,elasticsearch,memcached,mongodb,vault,dynamodb,rocksdb,duckdb]", + "py-key-value-aio[memory,disk,filetree,redis,elasticsearch,memcached,mongodb,vault,dynamodb,rocksdb,duckdb,postgresql]", "py-key-value-aio[valkey]; platform_system != 'Windows'", "py-key-value-aio[keyring]", "py-key-value-aio[pydantic]", diff --git a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py new file mode 100644 index 00000000..0af3ae48 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py @@ -0,0 +1,9 @@ +"""PostgreSQL store for py-key-value-aio.""" + +try: + from key_value.aio.stores.postgresql.store import PostgreSQLStore, PostgreSQLV1CollectionSanitizationStrategy +except ImportError as e: + msg = 'PostgreSQLStore requires the "postgresql" extra. Install via: pip install "py-key-value-aio[postgresql]"' + raise ImportError(msg) from e + +__all__ = ["PostgreSQLStore", "PostgreSQLV1CollectionSanitizationStrategy"] diff --git a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py new file mode 100644 index 00000000..ad6620c6 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py @@ -0,0 +1,551 @@ +"""PostgreSQL-based key-value store using asyncpg. 
"""PostgreSQL-based key-value store using asyncpg.

Values are stored in a JSONB column. asyncpg's default codec for ``json``/``jsonb``
exchanges *text* with the server — it does not convert Python dicts automatically —
so this module serializes values with ``json.dumps`` on write and decodes them with
``json.loads`` on read.

Note: SQL queries in this module use f-strings for table names, which triggers S608 warnings.
This is safe because table names are validated in __init__ to be alphanumeric plus underscores.
"""

# ruff: noqa: S608

import hashlib
import json
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, overload

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitization import HybridSanitizationStrategy, SanitizationStrategy
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS
from typing_extensions import Self, override

from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore

try:
    import asyncpg
except ImportError as e:
    msg = "PostgreSQLStore requires py-key-value-aio[postgresql]"
    raise ImportError(msg) from e


DEFAULT_HOST = "localhost"
DEFAULT_PORT = 5432
DEFAULT_DATABASE = "postgres"
DEFAULT_TABLE = "kv_store"

DEFAULT_PAGE_SIZE = 10000
PAGE_LIMIT = 10000

# PostgreSQL table name length limit is 63 characters
# Use 200 for consistency with MongoDB
MAX_COLLECTION_LENGTH = 200
POSTGRES_MAX_IDENTIFIER_LEN = 63
COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_"


def _decode_jsonb(raw: Any) -> Any:  # pyright: ignore[reportAny]
    """Decode a JSONB column value returned by asyncpg.

    asyncpg returns JSONB as JSON text by default. If the caller injected a pool
    with a custom jsonb codec the value may already be a dict, so only strings
    are decoded.
    """
    return json.loads(raw) if isinstance(raw, str) else raw


class PostgreSQLV1CollectionSanitizationStrategy(HybridSanitizationStrategy):
    """Sanitization strategy limiting collections to alphanumerics plus underscores."""

    def __init__(self) -> None:
        super().__init__(
            replacement_character="_",
            max_length=MAX_COLLECTION_LENGTH,
            allowed_characters=COLLECTION_ALLOWED_CHARACTERS,
        )


class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore):
    """PostgreSQL-based key-value store using asyncpg.

    This store uses a single table with columns for collection, key, value (JSONB), and metadata.
    Collections are stored as a column value rather than separate tables.

    By default, collections are not sanitized. This means that there are character and length restrictions on
    collection names that may cause errors when trying to get and put entries.

    To avoid issues, you may want to consider leveraging the `PostgreSQLV1CollectionSanitizationStrategy` strategy.

    Example:
        Basic usage with default connection:
        >>> store = PostgreSQLStore()
        >>> async with store:
        ...     await store.put("user_1", {"name": "Alice"}, collection="users")
        ...     user = await store.get("user_1", collection="users")

        Using a connection URL:
        >>> store = PostgreSQLStore(url="postgresql://user:pass@localhost/mydb")
        >>> async with store:
        ...     await store.put("key", {"data": "value"})

        Using custom connection parameters:
        >>> store = PostgreSQLStore(
        ...     host="db.example.com",
        ...     port=5432,
        ...     database="myapp",
        ...     user="myuser",
        ...     password="mypass"
        ... )
    """

    _pool: asyncpg.Pool | None  # type: ignore[type-arg]
    _owns_pool: bool
    _url: str | None
    _host: str
    _port: int
    _database: str
    _user: str | None
    _password: str | None
    _table_name: str

    @overload
    def __init__(
        self,
        *,
        pool: asyncpg.Pool,  # type: ignore[type-arg]
        table_name: str | None = None,
        default_collection: str | None = None,
        collection_sanitization_strategy: SanitizationStrategy | None = None,
    ) -> None:
        """Initialize the PostgreSQL store with an existing connection pool.

        Args:
            pool: An existing asyncpg connection pool to use.
            table_name: The name of the table to use for storage (default: kv_store).
            default_collection: The default collection to use if no collection is provided.
            collection_sanitization_strategy: The sanitization strategy to use for collections.
        """

    @overload
    def __init__(
        self,
        *,
        url: str,
        table_name: str | None = None,
        default_collection: str | None = None,
        collection_sanitization_strategy: SanitizationStrategy | None = None,
    ) -> None:
        """Initialize the PostgreSQL store with a connection URL.

        Args:
            url: PostgreSQL connection URL (e.g., postgresql://user:pass@localhost/dbname).
            table_name: The name of the table to use for storage (default: kv_store).
            default_collection: The default collection to use if no collection is provided.
            collection_sanitization_strategy: The sanitization strategy to use for collections.
        """

    @overload
    def __init__(
        self,
        *,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        database: str = DEFAULT_DATABASE,
        user: str | None = None,
        password: str | None = None,
        table_name: str | None = None,
        default_collection: str | None = None,
        collection_sanitization_strategy: SanitizationStrategy | None = None,
    ) -> None:
        """Initialize the PostgreSQL store with connection parameters.

        Args:
            host: PostgreSQL server host (default: localhost).
            port: PostgreSQL server port (default: 5432).
            database: Database name (default: postgres).
            user: Database user (default: current user).
            password: Database password (default: None).
            table_name: The name of the table to use for storage (default: kv_store).
            default_collection: The default collection to use if no collection is provided.
            collection_sanitization_strategy: The sanitization strategy to use for collections.
        """

    def __init__(
        self,
        *,
        pool: asyncpg.Pool | None = None,  # type: ignore[type-arg]
        url: str | None = None,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        database: str = DEFAULT_DATABASE,
        user: str | None = None,
        password: str | None = None,
        table_name: str | None = None,
        default_collection: str | None = None,
        collection_sanitization_strategy: SanitizationStrategy | None = None,
    ) -> None:
        """Initialize the PostgreSQL store."""
        self._pool = pool
        self._owns_pool = pool is None  # Only own the pool if we create it
        self._url = url
        self._host = host
        self._port = port
        self._database = database
        self._user = user
        self._password = password

        # Validate and sanitize table name to prevent SQL injection and invalid identifiers
        table_name = table_name or DEFAULT_TABLE
        if not table_name.replace("_", "").isalnum():
            msg = f"Table name must be alphanumeric (with underscores): {table_name}"
            raise ValueError(msg)
        if table_name[0].isdigit():
            msg = f"Table name must not start with a digit: {table_name}"
            raise ValueError(msg)
        # PostgreSQL identifier limit is 63 bytes
        if len(table_name) > POSTGRES_MAX_IDENTIFIER_LEN:
            msg = f"Table name too long (>{POSTGRES_MAX_IDENTIFIER_LEN}): {table_name}"
            raise ValueError(msg)
        self._table_name = table_name

        super().__init__(
            default_collection=default_collection,
            collection_sanitization_strategy=collection_sanitization_strategy,
        )

    def _ensure_pool_initialized(self) -> asyncpg.Pool:  # type: ignore[type-arg]
        """Ensure the connection pool is initialized.

        Returns:
            The initialized connection pool.

        Raises:
            RuntimeError: If the pool is not initialized.
        """
        if self._pool is None:
            msg = "Pool is not initialized. Use async with or call __aenter__() first."
            raise RuntimeError(msg)
        return self._pool

    @asynccontextmanager
    async def _acquire_connection(self) -> AsyncIterator[asyncpg.Connection]:  # type: ignore[type-arg]
        """Acquire a connection from the pool.

        Yields:
            A connection from the pool.

        Raises:
            RuntimeError: If the pool is not initialized.
        """
        pool = self._ensure_pool_initialized()
        async with pool.acquire() as conn:  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            yield conn

    @override
    async def __aenter__(self) -> Self:
        if self._pool is None:
            if self._url:
                self._pool = await asyncpg.create_pool(self._url)  # pyright: ignore[reportUnknownMemberType]
            else:
                self._pool = await asyncpg.create_pool(  # pyright: ignore[reportUnknownMemberType]
                    host=self._host,
                    port=self._port,
                    database=self._database,
                    user=self._user,
                    password=self._password,
                )
            self._owns_pool = True

        await super().__aenter__()
        return self

    @override
    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # pyright: ignore[reportAny]
        await super().__aexit__(exc_type, exc_val, exc_tb)
        if self._pool is not None and self._owns_pool:
            await self._pool.close()

    @override
    async def _setup_collection(self, *, collection: str) -> None:
        """Set up the database table and indexes if they don't exist.

        Args:
            collection: The collection name (used for validation, but all collections share the same table).
        """
        _ = self._sanitize_collection(collection=collection)

        # Create the main table if it doesn't exist
        table_sql = (
            f"CREATE TABLE IF NOT EXISTS {self._table_name} ("
            "collection VARCHAR(255) NOT NULL, "
            "key VARCHAR(255) NOT NULL, "
            "value JSONB NOT NULL, "
            "ttl DOUBLE PRECISION, "
            "created_at TIMESTAMPTZ, "
            "expires_at TIMESTAMPTZ, "
            "PRIMARY KEY (collection, key))"
        )

        # Create index on expires_at for efficient TTL queries
        # Ensure index name <= 63 chars (PostgreSQL identifier limit)
        index_name = f"idx_{self._table_name}_expires_at"
        if len(index_name) > POSTGRES_MAX_IDENTIFIER_LEN:
            index_name = "idx_" + hashlib.sha256(self._table_name.encode()).hexdigest()[:16] + "_exp"

        index_sql = f"CREATE INDEX IF NOT EXISTS {index_name} ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL"

        async with self._acquire_connection() as conn:
            await conn.execute(table_sql)  # pyright: ignore[reportUnknownMemberType]
            await conn.execute(index_sql)  # pyright: ignore[reportUnknownMemberType]

    @override
    async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
        """Retrieve a managed entry by key from the specified collection.

        Args:
            key: The key to retrieve.
            collection: The collection to retrieve from.

        Returns:
            The managed entry if found and not expired, None otherwise.
        """
        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            row = await conn.fetchrow(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
                f"SELECT value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = $2",
                sanitized_collection,
                key,
            )

            if row is None:
                return None

            # Parse the managed entry; JSONB comes back as text by default
            managed_entry = ManagedEntry(
                value=_decode_jsonb(row["value"]),  # pyright: ignore[reportUnknownArgumentType]
                created_at=row["created_at"],  # pyright: ignore[reportUnknownArgumentType]
                expires_at=row["expires_at"],  # pyright: ignore[reportUnknownArgumentType]
            )

            # Check if expired and delete if so (lazy expiration)
            if managed_entry.is_expired:
                await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                    f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2",
                    sanitized_collection,
                    key,
                )
                return None

            return managed_entry

    @override
    async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
        """Retrieve multiple managed entries by key from the specified collection.

        Args:
            collection: The collection to retrieve from.
            keys: The keys to retrieve.

        Returns:
            A list of managed entries in the same order as keys, with None for missing/expired entries.
        """
        if not keys:
            return []

        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            # Use ANY to query for multiple keys
            rows = await conn.fetch(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
                f"SELECT key, value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
                sanitized_collection,
                list(keys),
            )

            # Build a map of key -> managed entry
            entries_by_key: dict[str, ManagedEntry | None] = dict.fromkeys(keys)
            expired_keys: list[str] = []

            for row in rows:  # pyright: ignore[reportUnknownVariableType]
                managed_entry = ManagedEntry(
                    value=_decode_jsonb(row["value"]),  # pyright: ignore[reportUnknownArgumentType]
                    created_at=row["created_at"],  # pyright: ignore[reportUnknownArgumentType]
                    expires_at=row["expires_at"],  # pyright: ignore[reportUnknownArgumentType]
                )

                if managed_entry.is_expired:
                    expired_keys.append(row["key"])  # pyright: ignore[reportUnknownArgumentType]
                    entries_by_key[row["key"]] = None
                else:
                    entries_by_key[row["key"]] = managed_entry

            # Delete expired entries in batch
            if expired_keys:
                await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                    f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
                    sanitized_collection,
                    expired_keys,
                )

            return [entries_by_key[key] for key in keys]

    @override
    async def _put_managed_entry(
        self,
        *,
        key: str,
        collection: str,
        managed_entry: ManagedEntry,
    ) -> None:
        """Store a managed entry by key in the specified collection.

        Args:
            key: The key to store.
            collection: The collection to store in.
            managed_entry: The managed entry to store.
        """
        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            # created_at is intentionally not updated on conflict to preserve the original creation time
            upsert_sql = (
                f"INSERT INTO {self._table_name} "
                "(collection, key, value, ttl, created_at, expires_at) "
                "VALUES ($1, $2, $3, $4, $5, $6) "
                "ON CONFLICT (collection, key) "
                "DO UPDATE SET value = EXCLUDED.value, ttl = EXCLUDED.ttl, expires_at = EXCLUDED.expires_at"
            )
            await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                upsert_sql,
                sanitized_collection,
                key,
                # asyncpg expects JSON text for JSONB parameters
                json.dumps(managed_entry.value),
                managed_entry.ttl,
                managed_entry.created_at,
                managed_entry.expires_at,
            )

    @override
    async def _put_managed_entries(
        self,
        *,
        collection: str,
        keys: Sequence[str],
        managed_entries: Sequence[ManagedEntry],
        ttl: float | None,
        created_at: datetime,
        expires_at: datetime | None,
    ) -> None:
        """Store multiple managed entries by key in the specified collection.

        Args:
            collection: The collection to store in.
            keys: The keys to store.
            managed_entries: The managed entries to store.
            ttl: The TTL in seconds (None for no expiration).
            created_at: The creation timestamp for all entries.
            expires_at: The expiration timestamp for all entries (None if no TTL).
        """
        if not keys:
            return

        sanitized_collection = self._sanitize_collection(collection=collection)

        # Prepare data for batch insert using method-level ttl/created_at/expires_at;
        # values are JSON-encoded because asyncpg expects text for JSONB parameters
        values = [
            (sanitized_collection, key, json.dumps(entry.value), ttl, created_at, expires_at)
            for key, entry in zip(keys, managed_entries, strict=True)
        ]

        async with self._acquire_connection() as conn:
            # Use executemany for batch insert
            batch_upsert_sql = (
                f"INSERT INTO {self._table_name} "
                "(collection, key, value, ttl, created_at, expires_at) "
                "VALUES ($1, $2, $3, $4, $5, $6) "
                "ON CONFLICT (collection, key) "
                "DO UPDATE SET value = EXCLUDED.value, ttl = EXCLUDED.ttl, expires_at = EXCLUDED.expires_at"
            )
            await conn.executemany(  # pyright: ignore[reportUnknownMemberType]
                batch_upsert_sql,
                values,
            )

    @override
    async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
        """Delete a managed entry by key from the specified collection.

        Args:
            key: The key to delete.
            collection: The collection to delete from.

        Returns:
            True if the entry was deleted, False if it didn't exist.
        """
        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            result = await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2",
                sanitized_collection,
                key,
            )
            # PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted
            return result.split()[-1] != "0"

    @override
    async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) -> int:
        """Delete multiple managed entries by key from the specified collection.

        Args:
            keys: The keys to delete.
            collection: The collection to delete from.

        Returns:
            The number of entries that were deleted.
        """
        if not keys:
            return 0

        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            result = await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
                sanitized_collection,
                list(keys),
            )
            # PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted
            return int(result.split()[-1])

    @override
    async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
        """List all collection names.

        Args:
            limit: Maximum number of collection names to return.

        Returns:
            A list of collection names.
        """
        if limit is None or limit <= 0:
            limit = DEFAULT_PAGE_SIZE
        limit = min(limit, PAGE_LIMIT)

        async with self._acquire_connection() as conn:
            rows = await conn.fetch(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
                f"SELECT DISTINCT collection FROM {self._table_name} ORDER BY collection LIMIT $1",
                limit,
            )

            return [row["collection"] for row in rows]  # pyright: ignore[reportUnknownVariableType]

    @override
    async def _delete_collection(self, *, collection: str) -> bool:
        """Delete all entries in a collection.

        Args:
            collection: The collection to delete.

        Returns:
            True if any entries were deleted, False otherwise.
        """
        sanitized_collection = self._sanitize_collection(collection=collection)

        async with self._acquire_connection() as conn:
            result = await conn.execute(  # pyright: ignore[reportUnknownMemberType]
                f"DELETE FROM {self._table_name} WHERE collection = $1",
                sanitized_collection,
            )
            # Return True if any rows were deleted
            return result.split()[-1] != "0"

    @override
    async def _close(self) -> None:
        """Close the connection pool."""
        # Connection pool is closed in __aexit__
async def ping_postgresql() -> bool:
    """Check if PostgreSQL is available and responsive."""
    if asyncpg is None:
        return False

    try:
        conn = await asyncpg.connect(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
            host=POSTGRESQL_HOST,
            port=POSTGRESQL_HOST_PORT,
            user=POSTGRESQL_USER,
            password=POSTGRESQL_PASSWORD,
            database="postgres",
        )
        await conn.close()  # pyright: ignore[reportUnknownMemberType]
    except Exception:
        return False
    else:
        return True


class PostgreSQLFailedToStartError(Exception):
    """Raised when PostgreSQL fails to start in tests."""


@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available")
class TestPostgreSQLStore(ContextManagerStoreTestMixin, BaseStoreTests):
    """Test suite for PostgreSQL store, run against each version in POSTGRESQL_VERSIONS_TO_TEST."""

    @pytest.fixture(autouse=True, scope="session", params=POSTGRESQL_VERSIONS_TO_TEST)
    async def setup_postgresql(self, request: pytest.FixtureRequest) -> AsyncGenerator[None, None]:
        """Set up a PostgreSQL container for testing."""
        version = request.param

        with docker_container(
            f"postgresql-test-{version}",
            f"postgres:{version}-alpine",
            {str(POSTGRESQL_HOST_PORT): POSTGRESQL_HOST_PORT},
            environment={
                "POSTGRES_PASSWORD": POSTGRESQL_PASSWORD,
                "POSTGRES_DB": POSTGRESQL_TEST_DB,
            },
        ):
            # Import here to avoid issues when asyncpg is not installed
            from key_value.shared.stores.wait import async_wait_for_true

            if not await async_wait_for_true(bool_fn=ping_postgresql, tries=WAIT_FOR_POSTGRESQL_TIMEOUT, wait_time=1):
                msg = f"PostgreSQL {version} failed to start"
                raise PostgreSQLFailedToStartError(msg)

            yield

    @override
    @pytest.fixture
    async def store(self, setup_postgresql: None) -> PostgreSQLStore:
        """Create a PostgreSQL store for testing."""
        store = PostgreSQLStore(
            host=POSTGRESQL_HOST,
            port=POSTGRESQL_HOST_PORT,
            database=POSTGRESQL_TEST_DB,
            user=POSTGRESQL_USER,
            password=POSTGRESQL_PASSWORD,
        )

        # Clean up the database before each test
        async with store:
            if store._pool is not None:  # pyright: ignore[reportPrivateUsage]
                async with store._pool.acquire() as conn:  # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportUnknownVariableType]
                    # Drop the kv_store table; the store recreates it lazily on first use
                    with contextlib.suppress(Exception):
                        await conn.execute("DROP TABLE IF EXISTS kv_store")  # pyright: ignore[reportUnknownMemberType]

        return store

    @pytest.fixture
    async def postgresql_store(self, store: PostgreSQLStore) -> PostgreSQLStore:
        """Provide the PostgreSQL store fixture."""
        return store

    @pytest.fixture
    async def sanitizing_store(self, setup_postgresql: None) -> PostgreSQLStore:
        """Create a PostgreSQL store with collection sanitization enabled."""
        store = PostgreSQLStore(
            host=POSTGRESQL_HOST,
            port=POSTGRESQL_HOST_PORT,
            database=POSTGRESQL_TEST_DB,
            user=POSTGRESQL_USER,
            password=POSTGRESQL_PASSWORD,
            table_name="kv_store_sanitizing",
            collection_sanitization_strategy=PostgreSQLV1CollectionSanitizationStrategy(),
        )

        # Clean up the database before each test
        async with store:
            if store._pool is not None:  # pyright: ignore[reportPrivateUsage]
                async with store._pool.acquire() as conn:  # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportUnknownVariableType]
                    # Drop the kv_store_sanitizing table; the store recreates it lazily on first use
                    with contextlib.suppress(Exception):
                        await conn.execute("DROP TABLE IF EXISTS kv_store_sanitizing")  # pyright: ignore[reportUnknownMemberType]

        return store

    @pytest.mark.skip(reason="Distributed Caches are unbounded")
    @override
    async def test_not_unbounded(self, store: BaseStore): ...

    @override
    async def test_long_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore):  # pyright: ignore[reportIncompatibleMethodOverride]
        """Test that long collection names fail without sanitization but work with it."""
        with pytest.raises(Exception):  # noqa: B017, PT011
            await store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})

        await sanitizing_store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})
        assert await sanitizing_store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"}

    @override
    async def test_special_characters_in_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore):  # pyright: ignore[reportIncompatibleMethodOverride]
        """Test collection names with and without sanitization.

        Collections are stored as column values (not identifiers), so plain names work
        without sanitization; special characters are exercised through the sanitizing store.
        """
        # Without sanitization, plain collection names work (PostgreSQL stores them as column values)
        await store.put(collection="test_collection", key="test_key", value={"test": "test"})
        assert await store.get(collection="test_collection", key="test_key") == {"test": "test"}

        # With sanitization, special characters should work
        await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"})
        assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"}

    async def test_postgresql_collection_name_sanitization(self, sanitizing_store: PostgreSQLStore):
        """Test that the V1 sanitization strategy produces expected collection names."""
        await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"})
        assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"}

        collections = await sanitizing_store.collections()
        # The sanitized collection name should only contain alphanumeric characters and underscores
        assert len(collections) == 1
        assert all(c.isalnum() or c in "_-" for c in collections[0])
hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e", size = 673143, upload-time = "2024-10-20T00:29:08.846Z" }, + { url = "https://files.pythonhosted.org/packages/a0/9a/568ff9b590d0954553c56806766914c149609b828c426c5118d4869111d3/asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0", size = 645035, upload-time = "2024-10-20T00:29:12.02Z" }, + { url = "https://files.pythonhosted.org/packages/de/11/6f2fa6c902f341ca10403743701ea952bca896fc5b07cc1f4705d2bb0593/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f", size = 2912384, upload-time = "2024-10-20T00:29:13.644Z" }, + { url = "https://files.pythonhosted.org/packages/83/83/44bd393919c504ffe4a82d0aed8ea0e55eb1571a1dea6a4922b723f0a03b/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af", size = 2947526, upload-time = "2024-10-20T00:29:15.871Z" }, + { url = "https://files.pythonhosted.org/packages/08/85/e23dd3a2b55536eb0ded80c457b0693352262dc70426ef4d4a6fc994fa51/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75", size = 2895390, upload-time = "2024-10-20T00:29:19.346Z" }, + { url = "https://files.pythonhosted.org/packages/9b/26/fa96c8f4877d47dc6c1864fef5500b446522365da3d3d0ee89a5cce71a3f/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f", size = 3015630, upload-time = "2024-10-20T00:29:21.186Z" }, + { url = "https://files.pythonhosted.org/packages/34/00/814514eb9287614188a5179a8b6e588a3611ca47d41937af0f3a844b1b4b/asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf", size = 
568760, upload-time = "2024-10-20T00:29:22.769Z" }, + { url = "https://files.pythonhosted.org/packages/f0/28/869a7a279400f8b06dd237266fdd7220bc5f7c975348fea5d1e6909588e9/asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50", size = 625764, upload-time = "2024-10-20T00:29:25.882Z" }, + { url = "https://files.pythonhosted.org/packages/4c/0e/f5d708add0d0b97446c402db7e8dd4c4183c13edaabe8a8500b411e7b495/asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a", size = 674506, upload-time = "2024-10-20T00:29:27.988Z" }, + { url = "https://files.pythonhosted.org/packages/6a/a0/67ec9a75cb24a1d99f97b8437c8d56da40e6f6bd23b04e2f4ea5d5ad82ac/asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed", size = 645922, upload-time = "2024-10-20T00:29:29.391Z" }, + { url = "https://files.pythonhosted.org/packages/5c/d9/a7584f24174bd86ff1053b14bb841f9e714380c672f61c906eb01d8ec433/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a", size = 3079565, upload-time = "2024-10-20T00:29:30.832Z" }, + { url = "https://files.pythonhosted.org/packages/a0/d7/a4c0f9660e333114bdb04d1a9ac70db690dd4ae003f34f691139a5cbdae3/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956", size = 3109962, upload-time = "2024-10-20T00:29:33.114Z" }, + { url = "https://files.pythonhosted.org/packages/3c/21/199fd16b5a981b1575923cbb5d9cf916fdc936b377e0423099f209e7e73d/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056", size = 3064791, upload-time = "2024-10-20T00:29:34.677Z" }, + { url = 
"https://files.pythonhosted.org/packages/77/52/0004809b3427534a0c9139c08c87b515f1c77a8376a50ae29f001e53962f/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454", size = 3188696, upload-time = "2024-10-20T00:29:36.389Z" }, + { url = "https://files.pythonhosted.org/packages/52/cb/fbad941cd466117be58b774a3f1cc9ecc659af625f028b163b1e646a55fe/asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d", size = 567358, upload-time = "2024-10-20T00:29:37.915Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0a/0a32307cf166d50e1ad120d9b81a33a948a1a5463ebfa5a96cc5606c0863/asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f", size = 629375, upload-time = "2024-10-20T00:29:39.987Z" }, + { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162, upload-time = "2024-10-20T00:29:41.88Z" }, + { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025, upload-time = "2024-10-20T00:29:43.352Z" }, + { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243, upload-time = "2024-10-20T00:29:44.922Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059, upload-time = "2024-10-20T00:29:46.891Z" }, + { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596, upload-time = "2024-10-20T00:29:49.201Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632, upload-time = "2024-10-20T00:29:50.768Z" }, + { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186, upload-time = "2024-10-20T00:29:52.394Z" }, + { url = "https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064, upload-time = "2024-10-20T00:29:53.757Z" }, + { url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373, upload-time = "2024-10-20T00:29:55.165Z" }, + { url = 
"https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745, upload-time = "2024-10-20T00:29:57.14Z" }, + { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103, upload-time = "2024-10-20T00:29:58.499Z" }, + { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471, upload-time = "2024-10-20T00:30:00.354Z" }, + { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253, upload-time = "2024-10-20T00:30:02.794Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720, upload-time = "2024-10-20T00:30:04.501Z" }, + { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404, upload-time = "2024-10-20T00:30:06.537Z" }, + { url = 
"https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623, upload-time = "2024-10-20T00:30:09.024Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -1806,6 +1849,9 @@ memory = [ mongodb = [ { name = "pymongo" }, ] +postgresql = [ + { name = "asyncpg" }, +] pydantic = [ { name = "pydantic" }, ] @@ -1829,7 +1875,7 @@ wrappers-encryption = [ [package.dev-dependencies] dev = [ { name = "py-key-value", extra = ["dev"] }, - { name = "py-key-value-aio", extra = ["disk", "duckdb", "dynamodb", "elasticsearch", "filetree", "keyring", "memcached", "memory", "mongodb", "pydantic", "redis", "rocksdb", "vault", "wrappers-encryption"] }, + { name = "py-key-value-aio", extra = ["disk", "duckdb", "dynamodb", "elasticsearch", "filetree", "keyring", "memcached", "memory", "mongodb", "postgresql", "pydantic", "redis", "rocksdb", "vault", "wrappers-encryption"] }, { name = "py-key-value-aio", extra = ["valkey"], marker = "sys_platform != 'win32'" }, ] @@ -1840,6 +1886,7 @@ requires-dist = [ { name = "aiohttp", marker = "extra == 'elasticsearch'", specifier = ">=3.12" }, { name = "aiomcache", marker = "extra == 'memcached'", specifier = ">=0.8.0" }, { name = "anyio", marker = "extra == 'filetree'", specifier = ">=4.4.0" }, + { name = "asyncpg", marker = "extra == 'postgresql'", specifier = ">=0.30.0" }, { name = "beartype", specifier = ">=0.20.0" }, { name = "cachetools", marker = "extra == 'memory'", specifier = ">=5.0.0" }, { name = "cryptography", marker = "extra == 'wrappers-encryption'", specifier = ">=45.0.0" }, @@ -1862,13 +1909,13 @@ requires-dist = [ { name = "types-hvac", marker = "extra == 'vault'", specifier = ">=2.3.0" }, { name = "valkey-glide", marker = "extra == 'valkey'", specifier = ">=2.1.0" }, ] -provides-extras = ["memory", "disk", "filetree", "redis", "mongodb", 
"valkey", "vault", "memcached", "elasticsearch", "dynamodb", "keyring", "keyring-linux", "pydantic", "rocksdb", "duckdb", "wrappers-encryption"] +provides-extras = ["memory", "disk", "filetree", "redis", "mongodb", "valkey", "vault", "memcached", "elasticsearch", "dynamodb", "keyring", "keyring-linux", "pydantic", "rocksdb", "duckdb", "postgresql", "wrappers-encryption"] [package.metadata.requires-dev] dev = [ { name = "py-key-value", extras = ["dev"], editable = "." }, { name = "py-key-value-aio", extras = ["keyring"] }, - { name = "py-key-value-aio", extras = ["memory", "disk", "filetree", "redis", "elasticsearch", "memcached", "mongodb", "vault", "dynamodb", "rocksdb", "duckdb"] }, + { name = "py-key-value-aio", extras = ["memory", "disk", "filetree", "redis", "elasticsearch", "memcached", "mongodb", "vault", "dynamodb", "rocksdb", "duckdb", "postgresql"] }, { name = "py-key-value-aio", extras = ["pydantic"] }, { name = "py-key-value-aio", extras = ["valkey"], marker = "sys_platform != 'win32'" }, { name = "py-key-value-aio", extras = ["wrappers-encryption"] }, From 38a059dea14c387061997e693d08e9d6bb74068c Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 14:53:04 +0000 Subject: [PATCH 2/2] refactor: optimize PostgreSQL store to use one-time table setup - Move table/index creation from _setup_collection to _setup (called once) - Remove collection sanitization since collection names are column values - Remove PostgreSQLV1CollectionSanitizationStrategy class and exports - Update tests to verify collection names work without restrictions - Update docstrings to clarify collections are stored as values This simplifies the implementation since all collections share a single table, eliminating unnecessary per-collection setup overhead. 
Co-authored-by: William Easton --- .../aio/stores/postgresql/__init__.py | 4 +- .../key_value/aio/stores/postgresql/store.py | 79 ++++--------------- .../stores/postgresql/test_postgresql.py | 72 ++++------------- 3 files changed, 35 insertions(+), 120 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py index 0af3ae48..da09c87a 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py @@ -1,9 +1,9 @@ """PostgreSQL store for py-key-value-aio.""" try: - from key_value.aio.stores.postgresql.store import PostgreSQLStore, PostgreSQLV1CollectionSanitizationStrategy + from key_value.aio.stores.postgresql.store import PostgreSQLStore except ImportError as e: msg = 'PostgreSQLStore requires the "postgresql" extra. Install via: pip install "py-key-value-aio[postgresql]"' raise ImportError(msg) from e -__all__ = ["PostgreSQLStore", "PostgreSQLV1CollectionSanitizationStrategy"] +__all__ = ["PostgreSQLStore"] diff --git a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py index ad6620c6..279d4ae4 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py @@ -12,8 +12,6 @@ from typing import Any, overload from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.sanitization import HybridSanitizationStrategy, SanitizationStrategy -from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS from typing_extensions import Self, override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore @@ -34,31 +32,15 @@ PAGE_LIMIT = 10000 # PostgreSQL table name length 
limit is 63 characters -# Use 200 for consistency with MongoDB -MAX_COLLECTION_LENGTH = 200 POSTGRES_MAX_IDENTIFIER_LEN = 63 -COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" - - -class PostgreSQLV1CollectionSanitizationStrategy(HybridSanitizationStrategy): - def __init__(self) -> None: - super().__init__( - replacement_character="_", - max_length=MAX_COLLECTION_LENGTH, - allowed_characters=COLLECTION_ALLOWED_CHARACTERS, - ) class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): """PostgreSQL-based key-value store using asyncpg. - This store uses a single table with columns for collection, key, value (JSONB), and metadata. - Collections are stored as a column value rather than separate tables. - - By default, collections are not sanitized. This means that there are character and length restrictions on - collection names that may cause errors when trying to get and put entries. - - To avoid issues, you may want to consider leveraging the `PostgreSQLV1CollectionSanitizationStrategy` strategy. + This store uses a single shared table with columns for collection, key, value (JSONB), and metadata. + Collections are stored as values in the collection column, not as separate tables or SQL identifiers, + so there are no character restrictions on collection names. Example: Basic usage with default connection: @@ -99,7 +81,6 @@ def __init__( pool: asyncpg.Pool, # type: ignore[type-arg] table_name: str | None = None, default_collection: str | None = None, - collection_sanitization_strategy: SanitizationStrategy | None = None, ) -> None: """Initialize the PostgreSQL store with an existing connection pool. @@ -107,7 +88,6 @@ def __init__( pool: An existing asyncpg connection pool to use. table_name: The name of the table to use for storage (default: kv_store). default_collection: The default collection to use if no collection is provided. 
- collection_sanitization_strategy: The sanitization strategy to use for collections. """ @overload @@ -117,7 +97,6 @@ def __init__( url: str, table_name: str | None = None, default_collection: str | None = None, - collection_sanitization_strategy: SanitizationStrategy | None = None, ) -> None: """Initialize the PostgreSQL store with a connection URL. @@ -125,7 +104,6 @@ def __init__( url: PostgreSQL connection URL (e.g., postgresql://user:pass@localhost/dbname). table_name: The name of the table to use for storage (default: kv_store). default_collection: The default collection to use if no collection is provided. - collection_sanitization_strategy: The sanitization strategy to use for collections. """ @overload @@ -139,7 +117,6 @@ def __init__( password: str | None = None, table_name: str | None = None, default_collection: str | None = None, - collection_sanitization_strategy: SanitizationStrategy | None = None, ) -> None: """Initialize the PostgreSQL store with connection parameters. @@ -151,7 +128,6 @@ def __init__( password: Database password (default: None). table_name: The name of the table to use for storage (default: kv_store). default_collection: The default collection to use if no collection is provided. - collection_sanitization_strategy: The sanitization strategy to use for collections. 
""" def __init__( @@ -166,7 +142,6 @@ def __init__( password: str | None = None, table_name: str | None = None, default_collection: str | None = None, - collection_sanitization_strategy: SanitizationStrategy | None = None, ) -> None: """Initialize the PostgreSQL store.""" self._pool = pool @@ -178,7 +153,7 @@ def __init__( self._user = user self._password = password - # Validate and sanitize table name to prevent SQL injection and invalid identifiers + # Validate table name to prevent SQL injection and invalid identifiers table_name = table_name or DEFAULT_TABLE if not table_name.replace("_", "").isalnum(): msg = f"Table name must be alphanumeric (with underscores): {table_name}" @@ -192,10 +167,7 @@ def __init__( raise ValueError(msg) self._table_name = table_name - super().__init__( - default_collection=default_collection, - collection_sanitization_strategy=collection_sanitization_strategy, - ) + super().__init__(default_collection=default_collection) def _ensure_pool_initialized(self) -> asyncpg.Pool: # type: ignore[type-arg] """Ensure the connection pool is initialized. @@ -250,14 +222,12 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # await self._pool.close() @override - async def _setup_collection(self, *, collection: str) -> None: + async def _setup(self) -> None: """Set up the database table and indexes if they don't exist. - Args: - collection: The collection name (used for validation, but all collections share the same table). + This is called once when the store is first used. Since all collections share the same table, + we only need to set up the schema once. """ - _ = self._sanitize_collection(collection=collection) - # Create the main table if it doesn't exist table_sql = ( f"CREATE TABLE IF NOT EXISTS {self._table_name} (" @@ -295,12 +265,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry Returns: The managed entry if found and not expired, None otherwise. 
""" - sanitized_collection = self._sanitize_collection(collection=collection) - async with self._acquire_connection() as conn: row = await conn.fetchrow( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] f"SELECT value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = $2", - sanitized_collection, + collection, key, ) @@ -318,7 +286,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if managed_entry.is_expired: await conn.execute( # pyright: ignore[reportUnknownMemberType] f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2", - sanitized_collection, + collection, key, ) return None @@ -339,13 +307,11 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> if not keys: return [] - sanitized_collection = self._sanitize_collection(collection=collection) - async with self._acquire_connection() as conn: # Use ANY to query for multiple keys rows = await conn.fetch( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] f"SELECT key, value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", - sanitized_collection, + collection, list(keys), ) @@ -370,7 +336,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> if expired_keys: await conn.execute( # pyright: ignore[reportUnknownMemberType] f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", - sanitized_collection, + collection, expired_keys, ) @@ -391,7 +357,6 @@ async def _put_managed_entry( collection: The collection to store in. managed_entry: The managed entry to store. 
""" - sanitized_collection = self._sanitize_collection(collection=collection) async with self._acquire_connection() as conn: upsert_sql = ( @@ -403,7 +368,7 @@ async def _put_managed_entry( ) await conn.execute( # pyright: ignore[reportUnknownMemberType] upsert_sql, - sanitized_collection, + collection, key, managed_entry.value, managed_entry.ttl, @@ -435,12 +400,8 @@ async def _put_managed_entries( if not keys: return - sanitized_collection = self._sanitize_collection(collection=collection) - # Prepare data for batch insert using method-level ttl/created_at/expires_at - values = [ - (sanitized_collection, key, entry.value, ttl, created_at, expires_at) for key, entry in zip(keys, managed_entries, strict=True) - ] + values = [(collection, key, entry.value, ttl, created_at, expires_at) for key, entry in zip(keys, managed_entries, strict=True)] async with self._acquire_connection() as conn: # Use executemany for batch insert @@ -467,12 +428,10 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: Returns: True if the entry was deleted, False if it didn't exist. 
""" - sanitized_collection = self._sanitize_collection(collection=collection) - async with self._acquire_connection() as conn: result = await conn.execute( # pyright: ignore[reportUnknownMemberType] f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2", - sanitized_collection, + collection, key, ) # PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted @@ -492,12 +451,10 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str) if not keys: return 0 - sanitized_collection = self._sanitize_collection(collection=collection) - async with self._acquire_connection() as conn: result = await conn.execute( # pyright: ignore[reportUnknownMemberType] f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", - sanitized_collection, + collection, list(keys), ) # PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted @@ -535,12 +492,10 @@ async def _delete_collection(self, *, collection: str) -> bool: Returns: True if any entries were deleted, False otherwise. 
""" - sanitized_collection = self._sanitize_collection(collection=collection) - async with self._acquire_connection() as conn: result = await conn.execute( # pyright: ignore[reportUnknownMemberType] f"DELETE FROM {self._table_name} WHERE collection = $1", - sanitized_collection, + collection, ) # Return True if any rows were deleted return result.split()[-1] != "0" diff --git a/key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py b/key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py index 592173ad..b93ea844 100644 --- a/key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py +++ b/key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py @@ -7,7 +7,7 @@ from typing_extensions import override from key_value.aio.stores.base import BaseStore -from key_value.aio.stores.postgresql import PostgreSQLStore, PostgreSQLV1CollectionSanitizationStrategy +from key_value.aio.stores.postgresql import PostgreSQLStore from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -104,65 +104,25 @@ async def store(self, setup_postgresql: None) -> PostgreSQLStore: return store - @pytest.fixture - async def postgresql_store(self, store: PostgreSQLStore) -> PostgreSQLStore: - """Provide the PostgreSQL store fixture.""" - return store - - @pytest.fixture - async def sanitizing_store(self, setup_postgresql: None) -> PostgreSQLStore: - """Create a PostgreSQL store with collection sanitization enabled.""" - store = PostgreSQLStore( - host=POSTGRESQL_HOST, - port=POSTGRESQL_HOST_PORT, - database=POSTGRESQL_TEST_DB, - user=POSTGRESQL_USER, - password=POSTGRESQL_PASSWORD, - table_name="kv_store_sanitizing", - collection_sanitization_strategy=PostgreSQLV1CollectionSanitizationStrategy(), - ) - - # Clean up the database before each test - async with store: - if store._pool is not None: # pyright: ignore[reportPrivateUsage] - async with store._pool.acquire() as 
conn: # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportUnknownVariableType] - # Drop and recreate the kv_store_sanitizing table - with contextlib.suppress(Exception): - await conn.execute("DROP TABLE IF EXISTS kv_store_sanitizing") # pyright: ignore[reportUnknownMemberType] - - return store - @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... @override - async def test_long_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride] - """Test that long collection names fail without sanitization but work with it.""" - with pytest.raises(Exception): # noqa: B017, PT011 - await store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"}) - - await sanitizing_store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"}) - assert await sanitizing_store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"} + async def test_long_collection_name(self, store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride] + """Test that long collection names work since they're just column values.""" + # Long collection names should work fine since they're stored as column values, not SQL identifiers + long_collection = "test_collection" * 100 + await store.put(collection=long_collection, key="test_key", value={"test": "test"}) + assert await store.get(collection=long_collection, key="test_key") == {"test": "test"} @override - async def test_special_characters_in_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride] - """Test that special characters in collection names fail without sanitization but work with it.""" - # Without sanitization, special characters should work (PostgreSQL allows them in column values) - # but may cause issues with certain 
characters - await store.put(collection="test_collection", key="test_key", value={"test": "test"}) - assert await store.get(collection="test_collection", key="test_key") == {"test": "test"} - - # With sanitization, special characters should work - await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} - - async def test_postgresql_collection_name_sanitization(self, sanitizing_store: PostgreSQLStore): - """Test that the V1 sanitization strategy produces expected collection names.""" - await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} - - collections = await sanitizing_store.collections() - # The sanitized collection name should only contain alphanumeric characters and underscores - assert len(collections) == 1 - assert all(c.isalnum() or c in "_-" for c in collections[0]) + async def test_special_characters_in_collection_name(self, store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride] + """Test that special characters in collection names work since they're just column values.""" + # Special characters should work fine since collection names are stored as column values + await store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) + assert await store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + + # Verify the collection name is stored as-is + collections = await store.collections() + assert "test_collection!@#$%^&*()" in collections