From a544b75320e97424d2d927605316383c755cdac0 Mon Sep 17 00:00:00 2001 From: Sebastian Rittau Date: Wed, 15 Mar 2023 08:57:54 +0100 Subject: [PATCH] [SQLAlchemy] Annotate row classes (#9568) Co-authored-by: Avasam --- .../SQLAlchemy/@tests/stubtest_allowlist.txt | 6 ++ stubs/SQLAlchemy/sqlalchemy/cresultproxy.pyi | 22 ++++-- stubs/SQLAlchemy/sqlalchemy/engine/row.pyi | 74 ++++++++++--------- 3 files changed, 62 insertions(+), 40 deletions(-) diff --git a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt index 21c09df701a7..4340d77424fa 100644 --- a/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt +++ b/stubs/SQLAlchemy/@tests/stubtest_allowlist.txt @@ -59,6 +59,12 @@ sqlalchemy.testing.provision.stop_test_class_outside_fixtures sqlalchemy.testing.provision.temp_table_keyword_args sqlalchemy.testing.provision.update_db_opts +# potentially replaced at runtime +sqlalchemy.engine.Row.count +sqlalchemy.engine.Row.index +sqlalchemy.engine.row.Row.count +sqlalchemy.engine.row.Row.index + # KeyError/AttributeError on import due to dynamic initialization from a different module sqlalchemy.testing.fixtures sqlalchemy.testing.pickleable diff --git a/stubs/SQLAlchemy/sqlalchemy/cresultproxy.pyi b/stubs/SQLAlchemy/sqlalchemy/cresultproxy.pyi index ecb5b7062293..4bf9b9ea1931 100644 --- a/stubs/SQLAlchemy/sqlalchemy/cresultproxy.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/cresultproxy.pyi @@ -1,11 +1,23 @@ -from typing import Any +from _typeshed import Incomplete +from collections.abc import Callable, Iterable, Iterator +from typing import Any, overload class BaseRow: - def __init__(self, parent, processors, keymap, key_style, data) -> None: ... - def __reduce__(self): ... - def __iter__(self): ... + def __init__( + self, + __parent, + __processors: Iterable[Callable[[Any], Any]] | None, + __keymap: dict[Incomplete, Incomplete], + __key_style: int, + __row: Iterable[Any], + ) -> None: ... + def __reduce__(self) -> tuple[Incomplete, tuple[Incomplete, Incomplete]]: ... + def __iter__(self) -> Iterator[Any]: ... def __len__(self) -> int: ... def __hash__(self) -> int: ... - __getitem__: Any + @overload + def __getitem__(self, __key: str | int) -> tuple[Any, ...]: ... + @overload + def __getitem__(self, __key: slice) -> tuple[tuple[Any, ...]]: ... def safe_rowproxy_reconstructor(__cls, __state): ... diff --git a/stubs/SQLAlchemy/sqlalchemy/engine/row.pyi b/stubs/SQLAlchemy/sqlalchemy/engine/row.pyi index b5098c78b5c7..330af3e9ed34 100644 --- a/stubs/SQLAlchemy/sqlalchemy/engine/row.pyi +++ b/stubs/SQLAlchemy/sqlalchemy/engine/row.pyi @@ -1,9 +1,10 @@ -import abc -from collections.abc import ItemsView, KeysView, Mapping, Sequence, ValuesView -from typing import Any +from collections.abc import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView +from typing import Any, Generic, TypeVar from ..cresultproxy import BaseRow as BaseRow +_VT_co = TypeVar("_VT_co", covariant=True) + MD_INDEX: int def rowproxy_reconstructor(cls, state): ... @@ -13,45 +14,48 @@ KEY_OBJECTS_ONLY: int KEY_OBJECTS_BUT_WARN: int KEY_OBJECTS_NO_WARN: int -class Row(BaseRow, Sequence[Any], metaclass=abc.ABCMeta): +class Row(BaseRow, Sequence[Any]): + # The count and index methods are inherited from Sequence. + # If the result set contains columns with the same names, these + # fields contains their respective values, instead. We don't reflect + # this in the stubs. + __hash__ = BaseRow.__hash__ # type: ignore[assignment] + def __lt__(self, other: Row | tuple[Any, ...]) -> bool: ... + def __le__(self, other: Row | tuple[Any, ...]) -> bool: ... + def __ge__(self, other: Row | tuple[Any, ...]) -> bool: ... + def __gt__(self, other: Row | tuple[Any, ...]) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... + def keys(self) -> list[str]: ... + # The following methods are public, but have a leading underscore + # to prevent conflicts with column names. @property - def count(self): ... + def _mapping(self) -> RowMapping: ... @property - def index(self): ... - def __contains__(self, key): ... - __hash__ = BaseRow.__hash__ - def __lt__(self, other): ... - def __le__(self, other): ... - def __ge__(self, other): ... - def __gt__(self, other): ... - def __eq__(self, other): ... - def __ne__(self, other): ... - def keys(self): ... - -class LegacyRow(Row, metaclass=abc.ABCMeta): - def __contains__(self, key): ... - def has_key(self, key): ... - def items(self): ... - def iterkeys(self): ... - def itervalues(self): ... - def values(self): ... + def _fields(self) -> tuple[str, ...]: ... + def _asdict(self) -> dict[str, Any]: ... + +class LegacyRow(Row): + def has_key(self, key: str) -> bool: ... + def items(self) -> list[tuple[str, Any]]: ... + def iterkeys(self) -> Iterator[str]: ... + def itervalues(self) -> Iterator[Any]: ... + def values(self) -> list[Any]: ... BaseRowProxy = BaseRow RowProxy = Row -class ROMappingView(KeysView[Any], ValuesView[Any], ItemsView[Any, Any]): - def __init__(self, mapping, items) -> None: ... +class ROMappingView(KeysView[str], ValuesView[_VT_co], ItemsView[str, _VT_co], Generic[_VT_co]): # type: ignore[misc] + def __init__(self, mapping: RowMapping, items: list[_VT_co]) -> None: ... def __len__(self) -> int: ... - def __iter__(self): ... - def __contains__(self, item): ... - def __eq__(self, other): ... - def __ne__(self, other): ... + def __iter__(self) -> Iterator[_VT_co]: ... # type: ignore[override] + def __eq__(self, other: ROMappingView[_VT_co]) -> bool: ... # type: ignore[override] + def __ne__(self, other: ROMappingView[_VT_co]) -> bool: ... # type: ignore[override] -class RowMapping(BaseRow, Mapping[Any, Any]): +class RowMapping(BaseRow, Mapping[str, Row]): __getitem__: Any - def __iter__(self): ... + def __iter__(self) -> Iterator[str]: ... def __len__(self) -> int: ... - def __contains__(self, key): ... - def items(self): ... - def keys(self): ... - def values(self): ... + def items(self) -> ROMappingView[tuple[str, Any]]: ... # type: ignore[override] + def keys(self) -> list[str]: ... # type: ignore[override] + def values(self) -> ROMappingView[Any]: ... # type: ignore[override]