Skip to content

Commit a10278f

Browse files
Refactor PostgreSQL store to use helper methods for pool management
- Add _ensure_pool_initialized() helper to centralize pool initialization checks - Add _acquire_connection() async context manager for acquiring connections - Remove 7 duplicate pool initialization checks across all methods - Fix linting issues in test file (S105, TRY300) - Add pyright type ignores for asyncpg's incomplete type stubs This refactoring improves code maintainability by following DRY principles and centralizing pool management logic. Co-authored-by: William Easton <strawgate@users.noreply.github.com>
1 parent 4249b1b commit a10278f

File tree

2 files changed

+74
-80
lines changed

2 files changed

+74
-80
lines changed

key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py

Lines changed: 67 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
This is safe because table names are validated in __init__ to be alphanumeric (plus underscores).
55
"""
66

7-
from collections.abc import Sequence
7+
from collections.abc import AsyncIterator, Sequence
8+
from contextlib import asynccontextmanager
89
from datetime import datetime
910
from typing import Any, overload
1011

@@ -62,7 +63,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
6263
... )
6364
"""
6465

65-
_pool: asyncpg.Pool[asyncpg.Record] | None
66+
_pool: asyncpg.Pool | None # type: ignore[type-arg]
6667
_url: str | None
6768
_host: str
6869
_port: int
@@ -75,7 +76,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
7576
def __init__(
7677
self,
7778
*,
78-
pool: asyncpg.Pool[asyncpg.Record],
79+
pool: asyncpg.Pool, # type: ignore[type-arg]
7980
table_name: str | None = None,
8081
default_collection: str | None = None,
8182
) -> None:
@@ -130,7 +131,7 @@ def __init__(
130131
def __init__(
131132
self,
132133
*,
133-
pool: asyncpg.Pool[asyncpg.Record] | None = None,
134+
pool: asyncpg.Pool | None = None, # type: ignore[type-arg]
134135
url: str | None = None,
135136
host: str = DEFAULT_HOST,
136137
port: int = DEFAULT_PORT,
@@ -158,13 +159,41 @@ def __init__(
158159

159160
super().__init__(default_collection=default_collection)
160161

162+
def _ensure_pool_initialized(self) -> asyncpg.Pool: # type: ignore[type-arg]
163+
"""Ensure the connection pool is initialized.
164+
165+
Returns:
166+
The initialized connection pool.
167+
168+
Raises:
169+
RuntimeError: If the pool is not initialized.
170+
"""
171+
if self._pool is None:
172+
msg = "Pool is not initialized. Use async with or call __aenter__() first."
173+
raise RuntimeError(msg)
174+
return self._pool
175+
176+
@asynccontextmanager
177+
async def _acquire_connection(self) -> AsyncIterator[asyncpg.Connection]: # type: ignore[type-arg]
178+
"""Acquire a connection from the pool.
179+
180+
Yields:
181+
A connection from the pool.
182+
183+
Raises:
184+
RuntimeError: If the pool is not initialized.
185+
"""
186+
pool = self._ensure_pool_initialized()
187+
async with pool.acquire() as conn: # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
188+
yield conn
189+
161190
@override
162191
async def __aenter__(self) -> Self:
163192
if self._pool is None:
164193
if self._url:
165-
self._pool = await asyncpg.create_pool(self._url)
194+
self._pool = await asyncpg.create_pool(self._url) # pyright: ignore[reportUnknownMemberType]
166195
else:
167-
self._pool = await asyncpg.create_pool(
196+
self._pool = await asyncpg.create_pool( # pyright: ignore[reportUnknownMemberType]
168197
host=self._host,
169198
port=self._port,
170199
database=self._database,
@@ -201,10 +230,6 @@ async def _setup_collection(self, *, collection: str) -> None:
201230
"""
202231
_ = self._sanitize_collection_name(collection=collection)
203232

204-
if self._pool is None:
205-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
206-
raise RuntimeError(msg)
207-
208233
# Create the main table if it doesn't exist
209234
create_table_sql = f""" # noqa: S608
210235
CREATE TABLE IF NOT EXISTS {self._table_name} (
@@ -225,9 +250,9 @@ async def _setup_collection(self, *, collection: str) -> None:
225250
WHERE expires_at IS NOT NULL
226251
"""
227252

228-
async with self._pool.acquire() as conn:
229-
await conn.execute(create_table_sql)
230-
await conn.execute(create_index_sql)
253+
async with self._acquire_connection() as conn:
254+
await conn.execute(create_table_sql) # pyright: ignore[reportUnknownMemberType]
255+
await conn.execute(create_index_sql) # pyright: ignore[reportUnknownMemberType]
231256

232257
@override
233258
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
@@ -242,12 +267,8 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
242267
"""
243268
sanitized_collection = self._sanitize_collection_name(collection=collection)
244269

245-
if self._pool is None:
246-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
247-
raise RuntimeError(msg)
248-
249-
async with self._pool.acquire() as conn:
250-
row = await conn.fetchrow(
270+
async with self._acquire_connection() as conn:
271+
row = await conn.fetchrow( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
251272
f"SELECT value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = $2", # noqa: S608
252273
sanitized_collection,
253274
key,
@@ -258,15 +279,15 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
258279

259280
# Parse the managed entry
260281
managed_entry = ManagedEntry(
261-
value=row["value"],
262-
ttl=row["ttl"],
263-
created_at=row["created_at"],
264-
expires_at=row["expires_at"],
282+
value=row["value"], # pyright: ignore[reportUnknownArgumentType]
283+
ttl=row["ttl"], # pyright: ignore[reportUnknownArgumentType]
284+
created_at=row["created_at"], # pyright: ignore[reportUnknownArgumentType]
285+
expires_at=row["expires_at"], # pyright: ignore[reportUnknownArgumentType]
265286
)
266287

267288
# Check if expired and delete if so
268289
if managed_entry.is_expired:
269-
await conn.execute(
290+
await conn.execute( # pyright: ignore[reportUnknownMemberType]
270291
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2", # noqa: S608
271292
sanitized_collection,
272293
key,
@@ -291,13 +312,9 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
291312

292313
sanitized_collection = self._sanitize_collection_name(collection=collection)
293314

294-
if self._pool is None:
295-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
296-
raise RuntimeError(msg)
297-
298-
async with self._pool.acquire() as conn:
315+
async with self._acquire_connection() as conn:
299316
# Use ANY to query for multiple keys
300-
rows = await conn.fetch(
317+
rows = await conn.fetch( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
301318
f"SELECT key, value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", # noqa: S608
302319
sanitized_collection,
303320
list(keys),
@@ -307,23 +324,23 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
307324
entries_by_key: dict[str, ManagedEntry | None] = dict.fromkeys(keys)
308325
expired_keys: list[str] = []
309326

310-
for row in rows:
327+
for row in rows: # pyright: ignore[reportUnknownVariableType]
311328
managed_entry = ManagedEntry(
312-
value=row["value"],
313-
ttl=row["ttl"],
314-
created_at=row["created_at"],
315-
expires_at=row["expires_at"],
329+
value=row["value"], # pyright: ignore[reportUnknownArgumentType]
330+
ttl=row["ttl"], # pyright: ignore[reportUnknownArgumentType]
331+
created_at=row["created_at"], # pyright: ignore[reportUnknownArgumentType]
332+
expires_at=row["expires_at"], # pyright: ignore[reportUnknownArgumentType]
316333
)
317334

318335
if managed_entry.is_expired:
319-
expired_keys.append(row["key"])
336+
expired_keys.append(row["key"]) # pyright: ignore[reportUnknownArgumentType]
320337
entries_by_key[row["key"]] = None
321338
else:
322339
entries_by_key[row["key"]] = managed_entry
323340

324341
# Delete expired entries in batch
325342
if expired_keys:
326-
await conn.execute(
343+
await conn.execute( # pyright: ignore[reportUnknownMemberType]
327344
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", # noqa: S608
328345
sanitized_collection,
329346
expired_keys,
@@ -348,12 +365,8 @@ async def _put_managed_entry(
348365
"""
349366
sanitized_collection = self._sanitize_collection_name(collection=collection)
350367

351-
if self._pool is None:
352-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
353-
raise RuntimeError(msg)
354-
355-
async with self._pool.acquire() as conn:
356-
await conn.execute(
368+
async with self._acquire_connection() as conn:
369+
await conn.execute( # pyright: ignore[reportUnknownMemberType]
357370
f"""
358371
INSERT INTO {self._table_name} (collection, key, value, ttl, created_at, expires_at)
359372
VALUES ($1, $2, $3, $4, $5, $6)
@@ -398,19 +411,15 @@ async def _put_managed_entries(
398411

399412
sanitized_collection = self._sanitize_collection_name(collection=collection)
400413

401-
if self._pool is None:
402-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
403-
raise RuntimeError(msg)
404-
405414
# Prepare data for batch insert
406415
values = [
407416
(sanitized_collection, key, entry.value, entry.ttl, entry.created_at, entry.expires_at)
408417
for key, entry in zip(keys, managed_entries, strict=True)
409418
]
410419

411-
async with self._pool.acquire() as conn:
420+
async with self._acquire_connection() as conn:
412421
# Use executemany for batch insert
413-
await conn.executemany(
422+
await conn.executemany( # pyright: ignore[reportUnknownMemberType]
414423
f"""
415424
INSERT INTO {self._table_name} (collection, key, value, ttl, created_at, expires_at)
416425
VALUES ($1, $2, $3, $4, $5, $6)
@@ -437,12 +446,8 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
437446
"""
438447
sanitized_collection = self._sanitize_collection_name(collection=collection)
439448

440-
if self._pool is None:
441-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
442-
raise RuntimeError(msg)
443-
444-
async with self._pool.acquire() as conn:
445-
result = await conn.execute(
449+
async with self._acquire_connection() as conn:
450+
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
446451
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2", # noqa: S608
447452
sanitized_collection,
448453
key,
@@ -466,12 +471,8 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
466471

467472
sanitized_collection = self._sanitize_collection_name(collection=collection)
468473

469-
if self._pool is None:
470-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
471-
raise RuntimeError(msg)
472-
473-
async with self._pool.acquire() as conn:
474-
result = await conn.execute(
474+
async with self._acquire_connection() as conn:
475+
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
475476
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])", # noqa: S608
476477
sanitized_collection,
477478
list(keys),
@@ -491,17 +492,13 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
491492
"""
492493
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
493494

494-
if self._pool is None:
495-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
496-
raise RuntimeError(msg)
497-
498-
async with self._pool.acquire() as conn:
499-
rows = await conn.fetch(
495+
async with self._acquire_connection() as conn:
496+
rows = await conn.fetch( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
500497
f"SELECT DISTINCT collection FROM {self._table_name} ORDER BY collection LIMIT $1", # noqa: S608
501498
limit,
502499
)
503500

504-
return [row["collection"] for row in rows]
501+
return [row["collection"] for row in rows] # pyright: ignore[reportUnknownVariableType]
505502

506503
@override
507504
async def _delete_collection(self, *, collection: str) -> bool:
@@ -515,12 +512,8 @@ async def _delete_collection(self, *, collection: str) -> bool:
515512
"""
516513
sanitized_collection = self._sanitize_collection_name(collection=collection)
517514

518-
if self._pool is None:
519-
msg = "Pool is not initialized. Use async with or call __aenter__() first."
520-
raise RuntimeError(msg)
521-
522-
async with self._pool.acquire() as conn:
523-
result = await conn.execute(
515+
async with self._acquire_connection() as conn:
516+
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
524517
f"DELETE FROM {self._table_name} WHERE collection = $1", # noqa: S608
525518
sanitized_collection,
526519
)

key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
POSTGRESQL_HOST = "localhost"
2121
POSTGRESQL_HOST_PORT = 5432
2222
POSTGRESQL_USER = "postgres"
23-
POSTGRESQL_PASSWORD = "test"
23+
POSTGRESQL_PASSWORD = "test" # noqa: S105
2424
POSTGRESQL_TEST_DB = "kv_store_test"
2525

2626
WAIT_FOR_POSTGRESQL_TIMEOUT = 30
@@ -37,17 +37,18 @@ async def ping_postgresql() -> bool:
3737
return False
3838

3939
try:
40-
conn = await asyncpg.connect(
40+
conn = await asyncpg.connect( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
4141
host=POSTGRESQL_HOST,
4242
port=POSTGRESQL_HOST_PORT,
4343
user=POSTGRESQL_USER,
4444
password=POSTGRESQL_PASSWORD,
4545
database="postgres",
4646
)
47-
await conn.close()
48-
return True
47+
await conn.close() # pyright: ignore[reportUnknownMemberType]
4948
except Exception:
5049
return False
50+
else:
51+
return True
5152

5253

5354
class PostgreSQLFailedToStartError(Exception):
@@ -96,10 +97,10 @@ async def store(self, setup_postgresql: None) -> PostgreSQLStore:
9697
# Clean up the database before each test
9798
async with store:
9899
if store._pool is not None: # pyright: ignore[reportPrivateUsage]
99-
async with store._pool.acquire() as conn: # pyright: ignore[reportPrivateUsage]
100+
async with store._pool.acquire() as conn: # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportUnknownVariableType]
100101
# Drop and recreate the kv_store table
101102
with contextlib.suppress(Exception):
102-
await conn.execute("DROP TABLE IF EXISTS kv_store")
103+
await conn.execute("DROP TABLE IF EXISTS kv_store") # pyright: ignore[reportUnknownMemberType]
103104

104105
return store
105106

0 commit comments

Comments
 (0)