From e1659c4bdfd3c4ef3ebbf78cd8482d47d27d736c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 30 Oct 2024 20:08:04 +0100 Subject: [PATCH] Complete typing of `connresource` --- asyncpg/connection.py | 2 +- asyncpg/connresource.py | 38 +++++++++++++++++++++++++++++++++---- asyncpg/exceptions/_base.py | 12 ++++++++++-- pyproject.toml | 1 - 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 3a86466c..040e460f 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1482,7 +1482,7 @@ async def set_builtin_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - def is_closed(self): + def is_closed(self) -> bool: """Return ``True`` if the connection is closed, ``False`` otherwise. :return bool: ``True`` if the connection is closed, ``False`` diff --git a/asyncpg/connresource.py b/asyncpg/connresource.py index 3b0c1d3c..eb1e74e2 100644 --- a/asyncpg/connresource.py +++ b/asyncpg/connresource.py @@ -5,17 +5,47 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import functools +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar +from typing_extensions import ParamSpec from . import exceptions +if TYPE_CHECKING: + from . import connection -def guarded(meth): +_ConnectionResourceT = TypeVar( + "_ConnectionResourceT", bound="ConnectionResource", contravariant=True +) +_P = ParamSpec("_P") +_R = TypeVar("_R", covariant=True) + + +class _ConnectionResourceMethod( + Protocol, + Generic[_ConnectionResourceT, _R, _P], +): + # This indicates that the Protocol is a function and not a lambda + __name__: str + + # Type signature of a method on an instance of _ConnectionResourceT + def __call__( + _, self: _ConnectionResourceT, *args: _P.args, **kwds: _P.kwargs + ) -> _R: + ... + + +def guarded( + meth: _ConnectionResourceMethod[_ConnectionResourceT, _R, _P] +) -> _ConnectionResourceMethod[_ConnectionResourceT, _R, _P]: """A decorator to add a sanity check to ConnectionResource methods.""" @functools.wraps(meth) - def _check(self, *args, **kwargs): + def _check( + self: _ConnectionResourceT, *args: _P.args, **kwargs: _P.kwargs + ) -> _R: self._check_conn_validity(meth.__name__) return meth(self, *args, **kwargs) @@ -25,11 +55,11 @@ def _check(self, *args, **kwargs): class ConnectionResource: __slots__ = ('_connection', '_con_release_ctr') - def __init__(self, connection): + def __init__(self, connection: connection.Connection) -> None: self._connection = connection self._con_release_ctr = connection._pool_release_ctr - def _check_conn_validity(self, meth_name): + def _check_conn_validity(self, meth_name: str) -> None: con_release_ctr = self._connection._pool_release_ctr if con_release_ctr != self._con_release_ctr: raise exceptions.InterfaceError( diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 00e9699a..2a16ee4f 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -4,6 +4,8 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations +from typing import Optional import asyncpg import sys @@ -208,11 +210,17 @@ def __str__(self): class InterfaceError(InterfaceMessage, Exception): """An error caused by improper use of asyncpg API.""" - def __init__(self, msg, *, detail=None, hint=None): + def __init__( + self, + msg: str, + *, + detail: Optional[str] = None, + hint: Optional[str] = None, + ) -> None: InterfaceMessage.__init__(self, detail=detail, hint=hint) Exception.__init__(self, msg) - def with_msg(self, msg): + def with_msg(self, msg: str) -> InterfaceError: return type(self)( msg, detail=self.detail, diff --git a/pyproject.toml b/pyproject.toml index dabb7d8b..c9d99af3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,6 @@ module = [ "asyncpg.cluster", "asyncpg.connect_utils", "asyncpg.connection", - "asyncpg.connresource", "asyncpg.cursor", "asyncpg.exceptions", "asyncpg.exceptions.*",