diff --git a/aiopg/pool.py b/aiopg/pool.py index 4d6015c6..8d45834f 100644 --- a/aiopg/pool.py +++ b/aiopg/pool.py @@ -72,7 +72,7 @@ class _PoolConnectionContextManager: __slots__ = ("_pool", "_conn") - def __init__(self, pool: "Pool", conn: Connection): + def __init__(self, pool: "Pool", conn: Connection) -> None: self._pool: Optional[Pool] = pool self._conn: Optional[Connection] = conn @@ -130,7 +130,7 @@ class _PoolCursorContextManager: __slots__ = ("_pool", "_conn", "_cursor") - def __init__(self, pool: "Pool", conn: Connection, cursor: Cursor): + def __init__(self, pool: "Pool", conn: Connection, cursor: Cursor) -> None: self._pool = pool self._conn = conn self._cursor = cursor diff --git a/aiopg/sa/connection.py b/aiopg/sa/connection.py index 8280dba7..7a8119fb 100644 --- a/aiopg/sa/connection.py +++ b/aiopg/sa/connection.py @@ -6,6 +6,7 @@ from sqlalchemy.sql.ddl import DDLElement from sqlalchemy.sql.dml import UpdateBase +from ..connection import Cursor from ..utils import _ContextManager, _IterableContextManager from . import exc from .result import ResultProxy @@ -43,7 +44,7 @@ class SAConnection: "_query_compile_kwargs", ) - def __init__(self, connection, engine): + def __init__(self, connection, engine) -> None: self._connection = connection self._transaction = None self._savepoint_seq = 0 @@ -52,7 +53,7 @@ def __init__(self, connection, engine): self._cursors = weakref.WeakSet() self._query_compile_kwargs = dict(self._QUERY_COMPILE_KWARGS) - def execute(self, query, *multiparams, **params): + def execute(self, query, *multiparams, **params) -> _IterableContextManager[ResultProxy]: """Executes a SQL query with optional parameters. query - a SQL query string or any sqlalchemy expression. @@ -92,18 +93,18 @@ def execute(self, query, *multiparams, **params): coro = self._execute(query, *multiparams, **params) return _IterableContextManager[ResultProxy](coro, _close_result_proxy) - async def _open_cursor(self): + async def _open_cursor(self) -> Cursor: if self._connection is None: raise exc.ResourceClosedError("This connection is closed.") cursor = await self._connection.cursor() self._cursors.add(cursor) return cursor - def _close_cursor(self, cursor): + def _close_cursor(self, cursor) -> None: self._cursors.remove(cursor) cursor.close() - async def _execute(self, query, *multiparams, **params): + async def _execute(self, query, *multiparams, **params) -> ResultProxy: cursor = await self._open_cursor() dp = _distill_params(multiparams, params) if len(dp) > 1: @@ -181,7 +182,7 @@ async def scalar(self, query, *multiparams, **params): return await res.scalar() @property - def closed(self): + def closed(self) -> bool: """The readonly property that returns True if connections is closed.""" return self.connection is None or self.connection.closed @@ -231,7 +232,7 @@ def begin(self, isolation_level=None, readonly=False, deferrable=False): coro, _commit_transaction_if_active, _rollback_transaction ) - async def _begin(self, isolation_level, readonly, deferrable): + async def _begin(self, isolation_level, readonly, deferrable) -> Transaction: if self._transaction is None: self._transaction = RootTransaction(self) await self._begin_impl(isolation_level, readonly, deferrable) @@ -377,11 +378,11 @@ async def commit_prepared(self, xid, *, is_prepared=True): await self._commit_impl() @property - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if a transaction is in progress.""" return self._transaction is not None and self._transaction.is_active - async def close(self): + async def close(self) -> None: """Close this SAConnection. This results in a release of the underlying database @@ -401,7 +402,7 @@ async def close(self): await asyncio.shield(self._close()) - async def _close(self): + async def _close(self) -> None: if self._transaction is not None: with contextlib.suppress(Exception): await self._transaction.rollback() diff --git a/aiopg/sa/engine.py b/aiopg/sa/engine.py index b1ed35dd..ffa5940a 100644 --- a/aiopg/sa/engine.py +++ b/aiopg/sa/engine.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import asyncio import json +from sqlalchemy.dialects.postgresql.base import PGDialect + import aiopg from ..connection import TIMEOUT @@ -41,7 +45,7 @@ def _exec_default(self, default): return default.arg -def get_dialect(json_serializer=json.dumps, json_deserializer=lambda x: x): +def get_dialect(json_serializer=json.dumps, json_deserializer=lambda x: x) -> PGDialect: dialect = PGDialect_psycopg2( json_serializer=json_serializer, json_deserializer=json_deserializer ) @@ -69,7 +73,7 @@ def create_engine( timeout=TIMEOUT, pool_recycle=-1, **kwargs -): +) -> _ContextManager[Engine]: """A coroutine for Engine creation. Returns Engine instance with embedded connection pool. @@ -98,7 +102,7 @@ async def _create_engine( timeout=TIMEOUT, pool_recycle=-1, **kwargs -): +) -> Engine: pool = await aiopg.create_pool( dsn, @@ -116,7 +120,7 @@ async def _create_engine( await pool.release(conn) -async def _close_engine(engine: "Engine") -> None: +async def _close_engine(engine: Engine) -> None: engine.close() await engine.wait_closed() @@ -136,19 +140,19 @@ class Engine: __slots__ = ("_dialect", "_pool", "_dsn", "_loop") - def __init__(self, dialect, pool, dsn): + def __init__(self, dialect, pool, dsn) -> None: self._dialect = dialect self._pool = pool self._dsn = dsn self._loop = get_running_loop() @property - def dialect(self): + def dialect(self) -> PGDialect: """An dialect for engine.""" return self._dialect @property - def name(self): + def name(self) -> str: """A name of the dialect.""" return self._dialect.name @@ -186,7 +190,7 @@ def freesize(self): def closed(self): return self._pool.closed - def close(self): + def close(self) -> None: """Close engine. Mark all engine connections to be closed on getting back to pool. @@ -194,7 +198,7 @@ def close(self): """ self._pool.close() - def terminate(self): + def terminate(self) -> None: """Terminate engine. Terminate engine pool with instantly closing all acquired @@ -206,12 +210,12 @@ async def wait_closed(self): """Wait for closing all engine's connections.""" await self._pool.wait_closed() - def acquire(self): + def acquire(self) -> _ContextManager[SAConnection]: """Get a connection from pool.""" coro = self._acquire() return _ContextManager[SAConnection](coro, _close_connection) - async def _acquire(self): + async def _acquire(self) -> SAConnection: raw = await self._pool.acquire() return SAConnection(raw, self) @@ -244,10 +248,10 @@ def __await__(self): conn = yield from self._acquire().__await__() return _ConnectionContextManager(conn, self._loop) - async def __aenter__(self): + async def __aenter__(self) -> Engine: return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: self.close() await self.wait_closed() @@ -269,13 +273,13 @@ class _ConnectionContextManager: __slots__ = ("_conn", "_loop") - def __init__(self, conn: SAConnection, loop: asyncio.AbstractEventLoop): + def __init__(self, conn: SAConnection, loop: asyncio.AbstractEventLoop) -> None: self._conn = conn self._loop = loop - def __enter__(self): + def __enter__(self) -> SAConnection: return self._conn - def __exit__(self, *args): + def __exit__(self, *args) -> None: asyncio.ensure_future(self._conn.close(), loop=self._loop) self._conn = None diff --git a/aiopg/sa/result.py b/aiopg/sa/result.py index b81234c4..b4eb375f 100644 --- a/aiopg/sa/result.py +++ b/aiopg/sa/result.py @@ -1,6 +1,8 @@ import weakref from collections.abc import Mapping, Sequence +from typing import Tuple, Dict, List, Union +from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.sql import expression, sqltypes from . import exc @@ -17,7 +19,7 @@ class RowProxy(Mapping): __slots__ = ("_result_proxy", "_row", "_processors", "_keymap") - def __init__(self, result_proxy, row, processors, keymap): + def __init__(self, result_proxy, row, processors, keymap) -> None: """RowProxy objects are constructed by ResultProxy objects.""" self._result_proxy = result_proxy self._row = row @@ -27,7 +29,7 @@ def __init__(self, result_proxy, row, processors, keymap): def __iter__(self): return iter(self._result_proxy.keys) - def __len__(self): + def __len__(self) -> int: return len(self._row) def __getitem__(self, key): @@ -64,12 +66,12 @@ def __getattr__(self, name): except KeyError as e: raise AttributeError(e.args[0]) - def __contains__(self, key): + def __contains__(self, key) -> bool: return self._result_proxy._has_key(self._row, key) __hash__ = None - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, RowProxy): return self.as_tuple() == other.as_tuple() elif isinstance(other, Sequence): @@ -77,13 +79,13 @@ def __eq__(self, other): else: return NotImplemented - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self == other - def as_tuple(self): + def as_tuple(self) -> Tuple: return tuple(self[k] for k in self) - def __repr__(self): + def __repr__(self) -> str: return repr(self.as_tuple()) @@ -91,7 +93,7 @@ class ResultMetaData: """Handle cursor.description, applying additional info from an execution context.""" - def __init__(self, result_proxy, cursor_description): + def __init__(self, result_proxy, cursor_description) -> None: self._processors = processors = [] map_type, map_column_name = self.result_map(result_proxy._result_map) @@ -171,7 +173,7 @@ def __init__(self, result_proxy, cursor_description): # high precedence keymap. keymap.update(primary_keymap) - def result_map(self, data_map): + def result_map(self, data_map) -> Tuple[Dict, Dict]: data_map = data_map or {} map_type = {} map_column_name = {} @@ -220,7 +222,7 @@ def _key_fallback(self, key, raiseerr=True): map[key] = result return result - def _has_key(self, row, key): + def _has_key(self, row, key) -> bool: if key in self._keymap: return True else: @@ -247,7 +249,7 @@ class ResultProxy: the originating SQL statement that produced this result set. """ - def __init__(self, connection, cursor, dialect, result_map=None): + def __init__(self, connection, cursor, dialect, result_map=None) -> None: self._dialect = dialect self._result_map = result_map self._cursor = cursor @@ -258,7 +260,7 @@ def __init__(self, connection, cursor, dialect, result_map=None): self._init_metadata() @property - def dialect(self): + def dialect(self) -> PGDialect: """SQLAlchemy dialect.""" return self._dialect @@ -266,7 +268,7 @@ def dialect(self): def cursor(self): return self._cursor - def keys(self): + def keys(self) -> tuple: """Return the current set of string keys for rows.""" if self._metadata: return tuple(self._metadata.keys) @@ -274,7 +276,7 @@ def keys(self): return () @property - def rowcount(self): + def rowcount(self) -> int: """Return the 'rowcount' for this result. The 'rowcount' reports the number of rows *matched* @@ -313,7 +315,7 @@ def _init_metadata(self): self.close() @property - def returns_rows(self): + def returns_rows(self) -> bool: """True if this ResultProxy returns rows. I.e. if it is legal to call the methods .fetchone(), @@ -322,13 +324,13 @@ def returns_rows(self): return self._metadata is not None @property - def closed(self): + def closed(self) -> bool: if self._cursor is None: return True return bool(self._cursor.closed) - def close(self): + def close(self) -> None: """Close this ResultProxy. Closes the underlying DBAPI cursor corresponding to the execution. @@ -380,7 +382,7 @@ def _process_rows(self, rows): processors = metadata._processors return [process_row(metadata, row, processors, keymap) for row in rows] - async def fetchall(self): + async def fetchall(self) -> List[RowProxy]: """Fetch all rows, just like DB-API cursor.fetchall().""" try: rows = await self.cursor.fetchall() @@ -391,7 +393,7 @@ async def fetchall(self): self.close() return res - async def fetchone(self): + async def fetchone(self) -> Union[RowProxy, None]: """Fetch one row, just like DB-API cursor.fetchone(). If a row is present, the cursor remains open after this is called. @@ -408,7 +410,7 @@ async def fetchone(self): self.close() return None - async def fetchmany(self, size=None): + async def fetchmany(self, size=None) -> List[RowProxy]: """Fetch many rows, just like DB-API cursor.fetchmany(size=cursor.arraysize). @@ -428,7 +430,7 @@ async def fetchmany(self, size=None): self.close() return res - async def first(self): + async def first(self) -> Union[RowProxy, None]: """Fetch the first row and then close the result set unconditionally. Returns None if no row is present. diff --git a/aiopg/sa/transaction.py b/aiopg/sa/transaction.py index 099cd303..dd2a8a4c 100644 --- a/aiopg/sa/transaction.py +++ b/aiopg/sa/transaction.py @@ -26,13 +26,13 @@ class Transaction: __slots__ = ("_connection", "_parent", "_is_active") - def __init__(self, connection, parent): + def __init__(self, connection, parent) -> None: self._connection = connection self._parent = parent or self self._is_active = True @property - def is_active(self): + def is_active(self) -> bool: """Return ``True`` if a transaction is active.""" return self._is_active @@ -41,7 +41,7 @@ def connection(self): """Return transaction's connection (SAConnection instance).""" return self._connection - async def close(self): + async def close(self) -> None: """Close this transaction. If this transaction is the base transaction in a begin/commit @@ -58,17 +58,17 @@ async def close(self): else: self._is_active = False - async def rollback(self): + async def rollback(self) -> None: """Roll back this transaction.""" if not self._parent._is_active: return await self._do_rollback() self._is_active = False - async def _do_rollback(self): + async def _do_rollback(self) -> None: await self._parent.rollback() - async def commit(self): + async def commit(self) -> None: """Commit this transaction.""" if not self._parent._is_active: @@ -76,13 +76,13 @@ async def commit(self): await self._do_commit() self._is_active = False - async def _do_commit(self): + async def _do_commit(self) -> None: pass - async def __aenter__(self): + async def __aenter__(self) -> "Transaction": return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: if exc_type: await self.rollback() elif self._is_active: @@ -92,7 +92,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class RootTransaction(Transaction): __slots__ = () - def __init__(self, connection): + def __init__(self, connection) -> None: super().__init__(connection, None) async def _do_rollback(self): @@ -113,7 +113,7 @@ class NestedTransaction(Transaction): __slots__ = ("_savepoint",) - def __init__(self, connection, parent): + def __init__(self, connection, parent) -> None: super().__init__(connection, parent) self._savepoint = None @@ -144,7 +144,7 @@ class TwoPhaseTransaction(Transaction): __slots__ = ("_is_prepared", "_xid") - def __init__(self, connection, xid): + def __init__(self, connection, xid) -> None: super().__init__(connection, None) self._is_prepared = False self._xid = xid