diff --git a/stdlib/@tests/test_cases/sqlite3/check_aggregations.py b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py new file mode 100644 index 000000000000..8eebe71a39e0 --- /dev/null +++ b/stdlib/@tests/test_cases/sqlite3/check_aggregations.py @@ -0,0 +1,115 @@ +import sqlite3 +import sys + + +class WindowSumInt: + def __init__(self) -> None: + self.count = 0 + + def step(self, param: int) -> None: + self.count += param + + def value(self) -> int: + return self.count + + def inverse(self, param: int) -> None: + self.count -= param + + def finalize(self) -> int: + return self.count + + +con = sqlite3.connect(":memory:") +cur = con.execute("CREATE TABLE test(x, y)") +values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)] +cur.executemany("INSERT INTO test VALUES(?, ?)", values) + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumInt) + +con.create_aggregate("sumint", 1, WindowSumInt) +cur.execute( + """ + SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM test ORDER BY x +""" +) +con.close() + + +def _create_window_function() -> WindowSumInt: + return WindowSumInt() + + +# A callable should work as well. +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, _create_window_function) + con.create_aggregate("sumint", 1, _create_window_function) + +# With num_args set to 1, the callable should not be called with more than one. + + +class WindowSumIntMultiArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: int) -> None: + self.count += sum(args) + + def value(self) -> int: + return self.count + + def inverse(self, *args: int) -> None: + self.count -= sum(args) + + def finalize(self) -> int: + return self.count + + +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMultiArgs) + con.create_window_function("sumint", 2, WindowSumIntMultiArgs) + +con.create_aggregate("sumint", 1, WindowSumIntMultiArgs) +con.create_aggregate("sumint", 2, WindowSumIntMultiArgs) + + +class WindowSumIntMismatchedArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: str) -> None: + self.count += 34 + + def value(self) -> int: + return self.count + + def inverse(self, *args: int) -> None: + self.count -= 34 + + def finalize(self) -> str: + return str(self.count) + + +# Since the types for `inverse`, `step`, `finalize`, and `value` are not compatible, the following should fail. +if sys.version_info >= (3, 11): + con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs) # type: ignore + con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs) # type: ignore + + +class AggMismatchedArgs: + def __init__(self) -> None: + self.count = 0 + + def step(self, *args: str) -> None: + self.count += 34 + + def finalize(self) -> int: + return self.count + + +# Since the types for `step` and `finalize` are not compatible, the following should fail. +con.create_aggregate("sumint", 1, AggMismatchedArgs) # type: ignore +con.create_aggregate("sumint", 2, AggMismatchedArgs) # type: ignore diff --git a/stdlib/@tests/test_cases/check_sqlite3.py b/stdlib/@tests/test_cases/sqlite3/check_connection.py similarity index 100% rename from stdlib/@tests/test_cases/check_sqlite3.py rename to stdlib/@tests/test_cases/sqlite3/check_connection.py diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 3cb4b93e88fe..0fd7740804dd 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -11,6 +11,7 @@ _T = TypeVar("_T") _ConnectionT = TypeVar("_ConnectionT", bound=Connection) _CursorT = TypeVar("_CursorT", bound=Cursor) _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None +_SQLType = TypeVar("_SQLType", bound=_SqliteData) # Data that is passed through adapters can be of any type accepted by an adapter. _AdaptedInputData: TypeAlias = _SqliteData | Any # The Mapping must really be a dict, but making it invariant is too annoying. @@ -312,27 +313,25 @@ else: def register_adapter(type: type[_T], caster: _Adapter[_T], /) -> None: ... def register_converter(name: str, converter: _Converter, /) -> None: ... -class _AggregateProtocol(Protocol): - def step(self, value: int, /) -> object: ... - def finalize(self) -> int: ... +class _SingleParamAggregateProtocol(Protocol[_SQLType]): + def step(self, param: _SQLType, /) -> object: ... + def finalize(self) -> _SQLType: ... -class _SingleParamWindowAggregateClass(Protocol): - def step(self, param: Any, /) -> object: ... - def inverse(self, param: Any, /) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... +class _AnyParamAggregateProtocol(Protocol[_SQLType]): + def step(self, *args: _SQLType) -> object: ... + def finalize(self) -> _SQLType: ... -class _AnyParamWindowAggregateClass(Protocol): - def step(self, *args: Any) -> object: ... - def inverse(self, *args: Any) -> object: ... - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... +class _SingleParamWindowAggregateClass(Protocol[_SQLType]): + def step(self, param: _SQLType, /) -> object: ... + def inverse(self, param: _SQLType, /) -> object: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... -class _WindowAggregateClass(Protocol): - step: Callable[..., object] - inverse: Callable[..., object] - def value(self) -> _SqliteData: ... - def finalize(self) -> _SqliteData: ... +class _AnyParamWindowAggregateClass(Protocol[_SQLType]): + def step(self, *args: _SQLType) -> object: ... + def inverse(self, *args: _SQLType) -> object: ... + def value(self) -> _SQLType: ... + def finalize(self) -> _SQLType: ... class Connection: @property @@ -398,22 +397,30 @@ class Connection: def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ... def commit(self) -> None: ... - def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + @overload + def create_aggregate( + self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]] + ) -> None: ... + @overload + def create_aggregate( + self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]] + ) -> None: ... + if sys.version_info >= (3, 11): # num_params determines how many params will be passed to the aggregate class. We provide an overload # for the case where num_params = 1, which is expected to be the common case. @overload def create_window_function( - self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, / + self, + name: str, + num_params: Literal[1], + aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None, + /, ) -> None: ... # And for num_params = -1, which means the aggregate must accept any number of parameters. @overload def create_window_function( - self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, / - ) -> None: ... - @overload - def create_window_function( - self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, / + self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass[_SQLType]] | None, / ) -> None: ... def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...