Skip to content

Make Counter generic over the value #11632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions stdlib/collections/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import sys
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from types import GenericAlias
from typing import Any, ClassVar, Generic, NoReturn, SupportsIndex, TypeVar, final, overload
from typing import Any, ClassVar, NoReturn, SupportsIndex, TypeVar, final, overload
from typing_extensions import Self

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -31,6 +31,7 @@ _KT = TypeVar("_KT")
_VT = TypeVar("_VT")
_KT_co = TypeVar("_KT_co", covariant=True)
_VT_co = TypeVar("_VT_co", covariant=True)
_C = TypeVar("_C", default=int)

# namedtuple is special-cased in the type checker; the initializer is ignored.
def namedtuple(
Expand Down Expand Up @@ -268,55 +269,55 @@ class deque(MutableSequence[_T]):
def __eq__(self, value: object, /) -> bool: ...
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...

class Counter(dict[_T, int], Generic[_T]):
class Counter(dict[_T, _C]):
@overload
def __init__(self, iterable: None = None, /) -> None: ...
@overload
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ...
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: _C) -> None: ...
@overload
def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ...
def __init__(self, mapping: SupportsKeysAndGetItem[_T, _C], /) -> None: ...
@overload
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def copy(self) -> Self: ...
def elements(self) -> Iterator[_T]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ...
@classmethod
def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override]
def fromkeys(cls, iterable: Any, v: _C | None = None) -> NoReturn: ... # type: ignore[override]
@overload
def subtract(self, iterable: None = None, /) -> None: ...
@overload
def subtract(self, mapping: Mapping[_T, int], /) -> None: ...
def subtract(self, mapping: Mapping[_T, _C], /) -> None: ...
@overload
def subtract(self, iterable: Iterable[_T], /) -> None: ...
# Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload
# (source code does an `isinstance(other, Mapping)` check)
#
# The second overload is also deliberately different to dict.update()
# (if it were `Iterable[_T] | Iterable[tuple[_T, int]]`,
# (if it were `Iterable[_T] | Iterable[tuple[_T, _C]]`,
# the tuples would be added as keys, breaking type safety)
@overload # type: ignore[override]
def update(self, m: Mapping[_T, int], /, **kwargs: int) -> None: ...
def update(self, m: Mapping[_T, _C], /, **kwargs: _C) -> None: ...
@overload
def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ...
def update(self, iterable: Iterable[_T], /, **kwargs: _C) -> None: ...
@overload
def update(self, iterable: None = None, /, **kwargs: int) -> None: ...
def __missing__(self, key: _T) -> int: ...
def update(self, iterable: None = None, /, **kwargs: _C) -> None: ...
def __missing__(self, key: _T) -> _C: ...
def __delitem__(self, elem: object) -> None: ...
if sys.version_info >= (3, 10):
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...

def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ...
def __sub__(self, other: Counter[_T]) -> Counter[_T]: ...
def __and__(self, other: Counter[_T]) -> Counter[_T]: ...
def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T]: ...
def __neg__(self) -> Counter[_T]: ...
def __add__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ...
def __sub__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ...
def __and__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ...
def __or__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T, _C]: ...
def __neg__(self) -> Counter[_T, _C]: ...
# several type: ignores because __iadd__ is supposedly incompatible with __add__, etc.
def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, int]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, int]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc]
def __iadd__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, _C]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, _C]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[override,misc]
if sys.version_info >= (3, 10):
def total(self) -> int: ...
def __le__(self, other: Counter[Any]) -> bool: ...
Expand Down
38 changes: 38 additions & 0 deletions test_cases/stdlib/collections/check_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from collections import Counter
from typing import Any, cast
from typing_extensions import assert_type


class Foo: ...


# Test the constructor
# mypy derives Never for the first type argument while, pyright derives Unknown
assert_type(Counter(), "Counter[Any, int]")
assert_type(Counter(foo=42.2), "Counter[str, float]")
assert_type(Counter({42: "bar"}), "Counter[int, str]")
assert_type(Counter([1, 2, 3]), "Counter[int, int]")

int_c: Counter[str] = Counter()
assert_type(int_c, "Counter[str, int]")
assert_type(int_c["a"], int)
int_c["a"] = 1
int_c["a"] += 3
int_c["a"] += 3.5 # type: ignore

float_c = Counter(foo=42.2)
assert_type(float_c, "Counter[str, float]")
assert_type(float_c["a"], float)
float_c["a"] = 1.0
float_c["a"] += 3.0
float_c["a"] += 42
float_c["a"] += "42" # type: ignore

custom_c = cast("Counter[str, Foo]", Counter())
assert_type(custom_c, "Counter[str, Foo]")
assert_type(custom_c["a"], Foo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At runtime this is actually an int though. I wonder if we need to make all these methods return _C | int.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line should probably not be accepted, as Counter() is a Counter[unknown, int], which is incompatible with Counter[..., Foo]. I'm not sure the test makes much sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this sort of problem apply to any Counter with a non-int value type, though? This seems like a fundamental problem with this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At runtime this is actually an int though. I wonder if we need to make all these methods return _C | int.

I wonder whether type checkers support __missing__, in which case, this should happen automatically when we add it to the stubs. But returning _C | int makes some sense to me for getter methods.

custom_c["a"] = Foo()
custom_c["a"] += Foo() # type: ignore
custom_c["a"] += 42 # type: ignore