Skip to content

Fix SQLite Aggregation Protocols #12192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
115 changes: 115 additions & 0 deletions stdlib/@tests/test_cases/sqlite3/check_aggregations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import sqlite3
import sys


class WindowSumInt:
    """Running integer sum exposing the full window-aggregate method set.

    Defines ``step``, ``inverse``, ``value`` and ``finalize``, each taking or
    returning ``int`` and each handling exactly one argument per call.
    """

    def __init__(self) -> None:
        # Running total; every new instance starts from zero.
        self.count = 0

    def step(self, param: int) -> None:
        """Add ``param`` to the running total."""
        self.count += param

    def value(self) -> int:
        """Return the current running total."""
        return self.count

    def inverse(self, param: int) -> None:
        """Subtract ``param`` from the running total (undoes a ``step``)."""
        self.count -= param

    def finalize(self) -> int:
        """Return the accumulated total."""
        return self.count


# Build a small in-memory table to aggregate over.
con = sqlite3.connect(":memory:")
cur = con.execute("CREATE TABLE test(x, y)")
values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)]
cur.executemany("INSERT INTO test VALUES(?, ?)", values)

# The window-function registration is only exercised on Python 3.11+.
if sys.version_info >= (3, 11):
    con.create_window_function("sumint", 1, WindowSumInt)

# The same class is also registered as a plain one-argument aggregate.
con.create_aggregate("sumint", 1, WindowSumInt)
cur.execute(
    """
SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM test ORDER BY x
"""
)
# NOTE(review): later statements in this file still call methods on `con`
# after this close(). That is harmless if the file is only type-checked and
# never executed -- confirm.
con.close()


def _create_window_function() -> WindowSumInt:
    """Return a fresh ``WindowSumInt``.

    A plain zero-argument factory, used in place of the class object itself
    when registering the aggregate.
    """
    return WindowSumInt()


# A zero-argument factory function should be accepted wherever the class
# itself is accepted.
if sys.version_info >= (3, 11):
    con.create_window_function("sumint", 1, _create_window_function)

# create_aggregate is not version-guarded, so it is checked unconditionally.
con.create_aggregate("sumint", 1, _create_window_function)

# With the argument count set to 1, `step` and `inverse` are never called with more than one argument.


class WindowSumIntMultiArgs:
    """Running integer sum whose ``step``/``inverse`` accept any number of ints.

    Each call adds (or removes) the sum of all positional arguments it
    receives, so the same class works for any argument count.
    """

    def __init__(self) -> None:
        # Running total; every new instance starts from zero.
        self.count = 0

    def step(self, *args: int) -> None:
        """Add the sum of ``args`` to the running total."""
        self.count += sum(args)

    def value(self) -> int:
        """Return the current running total."""
        return self.count

    def inverse(self, *args: int) -> None:
        """Subtract the sum of ``args`` from the running total."""
        self.count -= sum(args)

    def finalize(self) -> int:
        """Return the accumulated total."""
        return self.count


# A variadic implementation is registered with both one and two arguments.
if sys.version_info >= (3, 11):
    con.create_window_function("sumint", 1, WindowSumIntMultiArgs)
    con.create_window_function("sumint", 2, WindowSumIntMultiArgs)

con.create_aggregate("sumint", 1, WindowSumIntMultiArgs)
con.create_aggregate("sumint", 2, WindowSumIntMultiArgs)


class WindowSumIntMismatchedArgs:
    """Window aggregate whose method annotations deliberately disagree.

    ``step`` is annotated with ``str`` arguments and ``inverse`` with ``int``
    arguments; ``value`` returns ``int`` while ``finalize`` returns ``str``.
    The arguments themselves are ignored: every call moves the total by a
    fixed 34.
    """

    def __init__(self) -> None:
        # Running total; every new instance starts from zero.
        self.count = 0

    def step(self, *args: str) -> None:
        """Increase the total by 34, ignoring ``args``."""
        self.count += 34

    def value(self) -> int:
        """Return the current total as an ``int``."""
        return self.count

    def inverse(self, *args: int) -> None:
        """Decrease the total by 34, ignoring ``args``."""
        self.count -= 34

    def finalize(self) -> str:
        """Return the total rendered as a string."""
        return str(self.count)


# The parameter and return types of `inverse`, `step`, `finalize`, and `value`
# do not agree with each other, so both registrations below are expected to be
# rejected; the `# type: ignore` comments acknowledge those expected errors.
if sys.version_info >= (3, 11):
    con.create_window_function("sumint", 1, WindowSumIntMismatchedArgs)  # type: ignore
    con.create_window_function("sumint", 2, WindowSumIntMismatchedArgs)  # type: ignore


class AggMismatchedArgs:
    """Plain aggregate whose ``step`` and ``finalize`` annotations disagree.

    ``step`` is annotated with ``str`` arguments while ``finalize`` returns
    ``int``. The arguments themselves are ignored: every ``step`` adds a
    fixed 34.
    """

    def __init__(self) -> None:
        # Running total; every new instance starts from zero.
        self.count = 0

    def step(self, *args: str) -> None:
        """Increase the total by 34, ignoring ``args``."""
        self.count += 34

    def finalize(self) -> int:
        """Return the accumulated total."""
        return self.count


# `step` is annotated with `str` arguments while `finalize` returns `int`, so
# the two types are not compatible and both registrations below should fail;
# the `# type: ignore` comments acknowledge those expected errors.
con.create_aggregate("sumint", 1, AggMismatchedArgs) # type: ignore
con.create_aggregate("sumint", 2, AggMismatchedArgs) # type: ignore
57 changes: 32 additions & 25 deletions stdlib/sqlite3/dbapi2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ _T = TypeVar("_T")
_ConnectionT = TypeVar("_ConnectionT", bound=Connection)
_CursorT = TypeVar("_CursorT", bound=Cursor)
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
_SQLType = TypeVar("_SQLType", bound=_SqliteData)
# Data that is passed through adapters can be of any type accepted by an adapter.
_AdaptedInputData: TypeAlias = _SqliteData | Any
# The Mapping must really be a dict, but making it invariant is too annoying.
Expand Down Expand Up @@ -312,27 +313,25 @@ else:
def register_adapter(type: type[_T], caster: _Adapter[_T], /) -> None: ...
def register_converter(name: str, converter: _Converter, /) -> None: ...

class _AggregateProtocol(Protocol):
def step(self, value: int, /) -> object: ...
def finalize(self) -> int: ...
class _SingleParamAggregateProtocol(Protocol[_SQLType]):
def step(self, param: _SQLType, /) -> object: ...
def finalize(self) -> _SQLType: ...

class _SingleParamWindowAggregateClass(Protocol):
def step(self, param: Any, /) -> object: ...
def inverse(self, param: Any, /) -> object: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...
class _AnyParamAggregateProtocol(Protocol[_SQLType]):
def step(self, *args: _SQLType) -> object: ...
def finalize(self) -> _SQLType: ...

class _AnyParamWindowAggregateClass(Protocol):
def step(self, *args: Any) -> object: ...
def inverse(self, *args: Any) -> object: ...
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...
class _SingleParamWindowAggregateClass(Protocol[_SQLType]):
def step(self, param: _SQLType, /) -> object: ...
def inverse(self, param: _SQLType, /) -> object: ...
def value(self) -> _SQLType: ...
def finalize(self) -> _SQLType: ...

class _WindowAggregateClass(Protocol):
step: Callable[..., object]
Copy link
Contributor Author

@max-muoto max-muoto Jun 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From testing things out, this protocol doesn't seem to work as intended in either mypy or Pyright, except perhaps when annotating a lambda. Because of this, I went ahead and removed it.

Some examples of how it might not work as you would expect:

Pyright Playground

MyPy Playground

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been there since the initial commit in #7625, while the other protocols already used a function. Maybe @JelleZijlstra remembers why we used an attribute instead of a function here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really remember what I was thinking when I wrote that code, but the annotations proposed in this PR mean that protocol implementations must take *args. I am not familiar with how these things are used, but I'd expect concrete implementations to only accept a fixed number of parameters. Maybe that's why I chose to use `Callable[..., object]`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried messing around with TypeVarTuple to do that, but ran into some issues there as well. There might not be a great way, but I'll see if I can figure something out.

inverse: Callable[..., object]
def value(self) -> _SqliteData: ...
def finalize(self) -> _SqliteData: ...
class _AnyParamWindowAggregateClass(Protocol[_SQLType]):
def step(self, *args: _SQLType) -> object: ...
def inverse(self, *args: _SQLType) -> object: ...
def value(self) -> _SQLType: ...
def finalize(self) -> _SQLType: ...

class Connection:
@property
Expand Down Expand Up @@ -398,22 +397,30 @@ class Connection:
def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ...

def commit(self) -> None: ...
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ...
@overload
def create_aggregate(
self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]
) -> None: ...
@overload
def create_aggregate(
self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol[_SQLType]]
) -> None: ...

if sys.version_info >= (3, 11):
# num_params determines how many params will be passed to the aggregate class. We provide an overload
# for the case where num_params = 1, which is expected to be the common case.
@overload
def create_window_function(
self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, /
self,
name: str,
num_params: Literal[1],
aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None,
/,
) -> None: ...
# And for num_params = -1, which means the aggregate must accept any number of parameters.
@overload
def create_window_function(
self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
) -> None: ...
@overload
def create_window_function(
self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, /
self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass[_SQLType]] | None, /
) -> None: ...

def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...
Expand Down