44This is safe because table names are validated in __init__ to be alphanumeric (plus underscores).
55"""
66
7- from collections .abc import Sequence
7+ from collections .abc import AsyncIterator , Sequence
8+ from contextlib import asynccontextmanager
89from datetime import datetime
910from typing import Any , overload
1011
@@ -62,7 +63,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
6263 ... )
6364 """
6465
65- _pool : asyncpg .Pool [ asyncpg . Record ] | None
66+ _pool : asyncpg .Pool | None # type: ignore[type-arg]
6667 _url : str | None
6768 _host : str
6869 _port : int
@@ -75,7 +76,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
7576 def __init__ (
7677 self ,
7778 * ,
78- pool : asyncpg .Pool [ asyncpg . Record ],
79+ pool : asyncpg .Pool , # type: ignore[type-arg]
7980 table_name : str | None = None ,
8081 default_collection : str | None = None ,
8182 ) -> None :
@@ -130,7 +131,7 @@ def __init__(
130131 def __init__ (
131132 self ,
132133 * ,
133- pool : asyncpg .Pool [ asyncpg . Record ] | None = None ,
134+ pool : asyncpg .Pool | None = None , # type: ignore[type-arg]
134135 url : str | None = None ,
135136 host : str = DEFAULT_HOST ,
136137 port : int = DEFAULT_PORT ,
@@ -158,13 +159,41 @@ def __init__(
158159
159160 super ().__init__ (default_collection = default_collection )
160161
162+ def _ensure_pool_initialized (self ) -> asyncpg .Pool : # type: ignore[type-arg]
163+ """Ensure the connection pool is initialized.
164+
165+ Returns:
166+ The initialized connection pool.
167+
168+ Raises:
169+ RuntimeError: If the pool is not initialized.
170+ """
171+ if self ._pool is None :
172+ msg = "Pool is not initialized. Use async with or call __aenter__() first."
173+ raise RuntimeError (msg )
174+ return self ._pool
175+
176+ @asynccontextmanager
177+ async def _acquire_connection (self ) -> AsyncIterator [asyncpg .Connection ]: # type: ignore[type-arg]
178+ """Acquire a connection from the pool.
179+
180+ Yields:
181+ A connection from the pool.
182+
183+ Raises:
184+ RuntimeError: If the pool is not initialized.
185+ """
186+ pool = self ._ensure_pool_initialized ()
187+ async with pool .acquire () as conn : # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
188+ yield conn
189+
161190 @override
162191 async def __aenter__ (self ) -> Self :
163192 if self ._pool is None :
164193 if self ._url :
165- self ._pool = await asyncpg .create_pool (self ._url )
194+ self ._pool = await asyncpg .create_pool (self ._url ) # pyright: ignore[reportUnknownMemberType]
166195 else :
167- self ._pool = await asyncpg .create_pool (
196+ self ._pool = await asyncpg .create_pool ( # pyright: ignore[reportUnknownMemberType]
168197 host = self ._host ,
169198 port = self ._port ,
170199 database = self ._database ,
@@ -201,10 +230,6 @@ async def _setup_collection(self, *, collection: str) -> None:
201230 """
202231 _ = self ._sanitize_collection_name (collection = collection )
203232
204- if self ._pool is None :
205- msg = "Pool is not initialized. Use async with or call __aenter__() first."
206- raise RuntimeError (msg )
207-
208233 # Create the main table if it doesn't exist
209234 create_table_sql = f""" # noqa: S608
210235 CREATE TABLE IF NOT EXISTS { self ._table_name } (
@@ -225,9 +250,9 @@ async def _setup_collection(self, *, collection: str) -> None:
225250 WHERE expires_at IS NOT NULL
226251 """
227252
228- async with self ._pool . acquire () as conn :
229- await conn .execute (create_table_sql )
230- await conn .execute (create_index_sql )
253+ async with self ._acquire_connection () as conn :
254+ await conn .execute (create_table_sql ) # pyright: ignore[reportUnknownMemberType]
255+ await conn .execute (create_index_sql ) # pyright: ignore[reportUnknownMemberType]
231256
232257 @override
233258 async def _get_managed_entry (self , * , key : str , collection : str ) -> ManagedEntry | None :
@@ -242,12 +267,8 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
242267 """
243268 sanitized_collection = self ._sanitize_collection_name (collection = collection )
244269
245- if self ._pool is None :
246- msg = "Pool is not initialized. Use async with or call __aenter__() first."
247- raise RuntimeError (msg )
248-
249- async with self ._pool .acquire () as conn :
250- row = await conn .fetchrow (
270+ async with self ._acquire_connection () as conn :
271+ row = await conn .fetchrow ( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
251272 f"SELECT value, ttl, created_at, expires_at FROM { self ._table_name } WHERE collection = $1 AND key = $2" , # noqa: S608
252273 sanitized_collection ,
253274 key ,
@@ -258,15 +279,15 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
258279
259280 # Parse the managed entry
260281 managed_entry = ManagedEntry (
261- value = row ["value" ],
262- ttl = row ["ttl" ],
263- created_at = row ["created_at" ],
264- expires_at = row ["expires_at" ],
282+ value = row ["value" ], # pyright: ignore[reportUnknownArgumentType]
283+ ttl = row ["ttl" ], # pyright: ignore[reportUnknownArgumentType]
284+ created_at = row ["created_at" ], # pyright: ignore[reportUnknownArgumentType]
285+ expires_at = row ["expires_at" ], # pyright: ignore[reportUnknownArgumentType]
265286 )
266287
267288 # Check if expired and delete if so
268289 if managed_entry .is_expired :
269- await conn .execute (
290+ await conn .execute ( # pyright: ignore[reportUnknownMemberType]
270291 f"DELETE FROM { self ._table_name } WHERE collection = $1 AND key = $2" , # noqa: S608
271292 sanitized_collection ,
272293 key ,
@@ -291,13 +312,9 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
291312
292313 sanitized_collection = self ._sanitize_collection_name (collection = collection )
293314
294- if self ._pool is None :
295- msg = "Pool is not initialized. Use async with or call __aenter__() first."
296- raise RuntimeError (msg )
297-
298- async with self ._pool .acquire () as conn :
315+ async with self ._acquire_connection () as conn :
299316 # Use ANY to query for multiple keys
300- rows = await conn .fetch (
317+ rows = await conn .fetch ( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
301318 f"SELECT key, value, ttl, created_at, expires_at FROM { self ._table_name } WHERE collection = $1 AND key = ANY($2::text[])" , # noqa: S608
302319 sanitized_collection ,
303320 list (keys ),
@@ -307,23 +324,23 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
307324 entries_by_key : dict [str , ManagedEntry | None ] = dict .fromkeys (keys )
308325 expired_keys : list [str ] = []
309326
310- for row in rows :
327+ for row in rows : # pyright: ignore[reportUnknownVariableType]
311328 managed_entry = ManagedEntry (
312- value = row ["value" ],
313- ttl = row ["ttl" ],
314- created_at = row ["created_at" ],
315- expires_at = row ["expires_at" ],
329+ value = row ["value" ], # pyright: ignore[reportUnknownArgumentType]
330+ ttl = row ["ttl" ], # pyright: ignore[reportUnknownArgumentType]
331+ created_at = row ["created_at" ], # pyright: ignore[reportUnknownArgumentType]
332+ expires_at = row ["expires_at" ], # pyright: ignore[reportUnknownArgumentType]
316333 )
317334
318335 if managed_entry .is_expired :
319- expired_keys .append (row ["key" ])
336+ expired_keys .append (row ["key" ]) # pyright: ignore[reportUnknownArgumentType]
320337 entries_by_key [row ["key" ]] = None
321338 else :
322339 entries_by_key [row ["key" ]] = managed_entry
323340
324341 # Delete expired entries in batch
325342 if expired_keys :
326- await conn .execute (
343+ await conn .execute ( # pyright: ignore[reportUnknownMemberType]
327344 f"DELETE FROM { self ._table_name } WHERE collection = $1 AND key = ANY($2::text[])" , # noqa: S608
328345 sanitized_collection ,
329346 expired_keys ,
@@ -348,12 +365,8 @@ async def _put_managed_entry(
348365 """
349366 sanitized_collection = self ._sanitize_collection_name (collection = collection )
350367
351- if self ._pool is None :
352- msg = "Pool is not initialized. Use async with or call __aenter__() first."
353- raise RuntimeError (msg )
354-
355- async with self ._pool .acquire () as conn :
356- await conn .execute (
368+ async with self ._acquire_connection () as conn :
369+ await conn .execute ( # pyright: ignore[reportUnknownMemberType]
357370 f"""
358371 INSERT INTO { self ._table_name } (collection, key, value, ttl, created_at, expires_at)
359372 VALUES ($1, $2, $3, $4, $5, $6)
@@ -398,19 +411,15 @@ async def _put_managed_entries(
398411
399412 sanitized_collection = self ._sanitize_collection_name (collection = collection )
400413
401- if self ._pool is None :
402- msg = "Pool is not initialized. Use async with or call __aenter__() first."
403- raise RuntimeError (msg )
404-
405414 # Prepare data for batch insert
406415 values = [
407416 (sanitized_collection , key , entry .value , entry .ttl , entry .created_at , entry .expires_at )
408417 for key , entry in zip (keys , managed_entries , strict = True )
409418 ]
410419
411- async with self ._pool . acquire () as conn :
420+ async with self ._acquire_connection () as conn :
412421 # Use executemany for batch insert
413- await conn .executemany (
422+ await conn .executemany ( # pyright: ignore[reportUnknownMemberType]
414423 f"""
415424 INSERT INTO { self ._table_name } (collection, key, value, ttl, created_at, expires_at)
416425 VALUES ($1, $2, $3, $4, $5, $6)
@@ -437,12 +446,8 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
437446 """
438447 sanitized_collection = self ._sanitize_collection_name (collection = collection )
439448
440- if self ._pool is None :
441- msg = "Pool is not initialized. Use async with or call __aenter__() first."
442- raise RuntimeError (msg )
443-
444- async with self ._pool .acquire () as conn :
445- result = await conn .execute (
449+ async with self ._acquire_connection () as conn :
450+ result = await conn .execute ( # pyright: ignore[reportUnknownMemberType]
446451 f"DELETE FROM { self ._table_name } WHERE collection = $1 AND key = $2" , # noqa: S608
447452 sanitized_collection ,
448453 key ,
@@ -466,12 +471,8 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
466471
467472 sanitized_collection = self ._sanitize_collection_name (collection = collection )
468473
469- if self ._pool is None :
470- msg = "Pool is not initialized. Use async with or call __aenter__() first."
471- raise RuntimeError (msg )
472-
473- async with self ._pool .acquire () as conn :
474- result = await conn .execute (
474+ async with self ._acquire_connection () as conn :
475+ result = await conn .execute ( # pyright: ignore[reportUnknownMemberType]
475476 f"DELETE FROM { self ._table_name } WHERE collection = $1 AND key = ANY($2::text[])" , # noqa: S608
476477 sanitized_collection ,
477478 list (keys ),
@@ -491,17 +492,13 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
491492 """
492493 limit = min (limit or DEFAULT_PAGE_SIZE , PAGE_LIMIT )
493494
494- if self ._pool is None :
495- msg = "Pool is not initialized. Use async with or call __aenter__() first."
496- raise RuntimeError (msg )
497-
498- async with self ._pool .acquire () as conn :
499- rows = await conn .fetch (
495+ async with self ._acquire_connection () as conn :
496+ rows = await conn .fetch ( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
500497 f"SELECT DISTINCT collection FROM { self ._table_name } ORDER BY collection LIMIT $1" , # noqa: S608
501498 limit ,
502499 )
503500
504- return [row ["collection" ] for row in rows ]
501+ return [row ["collection" ] for row in rows ] # pyright: ignore[reportUnknownVariableType]
505502
506503 @override
507504 async def _delete_collection (self , * , collection : str ) -> bool :
@@ -515,12 +512,8 @@ async def _delete_collection(self, *, collection: str) -> bool:
515512 """
516513 sanitized_collection = self ._sanitize_collection_name (collection = collection )
517514
518- if self ._pool is None :
519- msg = "Pool is not initialized. Use async with or call __aenter__() first."
520- raise RuntimeError (msg )
521-
522- async with self ._pool .acquire () as conn :
523- result = await conn .execute (
515+ async with self ._acquire_connection () as conn :
516+ result = await conn .execute ( # pyright: ignore[reportUnknownMemberType]
524517 f"DELETE FROM { self ._table_name } WHERE collection = $1" , # noqa: S608
525518 sanitized_collection ,
526519 )
0 commit comments