Skip to content

Use generic defaults for Counter value #12344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions stdlib/@tests/test_cases/collections/check_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from collections import Counter
from decimal import Decimal
from typing import Union
from typing_extensions import assert_type

# Initialize a Counter for strings with integer values
word_counts: Counter[str] = Counter()
word_counts["foo"] += 3
word_counts["bar"] += 2
assert_type(word_counts, "Counter[str, int]")

# Initialize a Counter for strings with float values
floating_point_counts: Counter[str, float] = Counter()
floating_point_counts["foo"] += 3.0
floating_point_counts["bar"] += 5.0

# Initialize a Counter for strings with Decimal values
decimal_counts: Counter[str, Decimal] = Counter()
decimal_counts["foo"] += Decimal("3.0")
decimal_counts["bar"] += Decimal("5.0")
# Each key defualts to an int.
assert_type(decimal_counts["test"], Union[Decimal, int])
assert_type(decimal_counts.get("test"), Union[Decimal, int, None])
assert_type(decimal_counts.pop("test"), Union[Decimal, int])

# Using kwargs for `__init__`
word_counts = Counter(foo=3, bar=2)
assert_type(word_counts, "Counter[str, int]")
floating_point_counts = Counter(foo=3.0, bar=5.0)
assert_type(floating_point_counts, "Counter[str, float]")
decimal_counts = Counter(foo=Decimal("3.0"), bar=Decimal("5.0"))
assert_type(decimal_counts, "Counter[str, Decimal]")

# Counter combining integers and floats
mixed_type_counter = Counter({"foo": 3, "bar": 2.5})
mixed_type_counter["baz"] += 1.5
mixed_type_counter # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float]"`; pyright: `Counter[str, int | float]`

# Check ORing and ANDing Counters with different value types
# MyPy and Pyright infer the types differently for these, so we can't use assert_type.

_ = mixed_type_counter or decimal_counts
_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float] | Counter[str, Decimal]"`; pyright: `Counter[str, int | float] | Counter[str, Decimal]`

_ = decimal_counts or mixed_type_counter
_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, Decimal] | Counter[str, float]"`; pyright: `Counter[str, Decimal] | Counter[str, int | float]`

_ = mixed_type_counter and decimal_counts
_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float] | Counter[str, Decimal]"`; pyright: `Counter[str, int | float] | Counter[str, Decimal]`

_ = decimal_counts and mixed_type_counter
_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, Decimal] | Counter[str, float]"`; pyright: `Counter[str, Decimal] | Counter[str, int | float]`

# We shouldn't be able to add Counters with incompatible value types
_ = mixed_type_counter + decimal_counts # type: ignore
mixed_type_counter += decimal_counts # type: ignore

# Adding Counters with compatible types
_wc = word_counts + Counter({"foo": 2, "baz": 1})
word_counts += Counter({"foo": 2, "baz": 1})

# Combining Counters of different key types
integer_key_counter = Counter({1: 2, 2: 3})
combined_word_and_integer_keys = word_counts + integer_key_counter
assert_type(combined_word_and_integer_keys, "Counter[str | int, int]")
54 changes: 28 additions & 26 deletions stdlib/collections/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
from typing import Any, Generic, NoReturn, SupportsIndex, TypeVar, final, overload
from typing_extensions import Self
from typing import Any, Generic, NoReturn, SupportsIndex, final, overload
from typing_extensions import Self, TypeVar

if sys.version_info >= (3, 9):
from types import GenericAlias
Expand All @@ -28,6 +28,8 @@ __all__ = ["ChainMap", "Counter", "OrderedDict", "UserDict", "UserList", "UserSt

_S = TypeVar("_S")
_T = TypeVar("_T")
_V = TypeVar("_V")
_V_I = TypeVar("_V_I", default=int)
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_KT = TypeVar("_KT")
Expand Down Expand Up @@ -273,24 +275,24 @@ class deque(MutableSequence[_T]):
if sys.version_info >= (3, 9):
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...

class Counter(dict[_T, int], Generic[_T]):
class Counter(dict[_T, _V_I | int]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might prefer Literal[0] here, curious what we think.

@overload
def __init__(self, iterable: None = None, /) -> None: ...
@overload
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ...
def __init__(self: Counter[str, Any], iterable: None = None, /, **kwargs: _V_I) -> None: ...
@overload
def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ...
def __init__(self, mapping: SupportsKeysAndGetItem[_T, _V_I], /) -> None: ...
@overload
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def copy(self) -> Self: ...
def elements(self) -> Iterator[_T]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, _V_I]]: ...
@classmethod
def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override]
def fromkeys(cls, iterable: Any, v: _V_I | None = None) -> NoReturn: ... # type: ignore[override]
@overload
def subtract(self, iterable: None = None, /) -> None: ...
@overload
def subtract(self, mapping: Mapping[_T, int], /) -> None: ...
def subtract(self, mapping: Mapping[_T, _V_I], /) -> None: ...
@overload
def subtract(self, iterable: Iterable[_T], /) -> None: ...
# Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload
Expand All @@ -300,34 +302,34 @@ class Counter(dict[_T, int], Generic[_T]):
# (if it were `Iterable[_T] | Iterable[tuple[_T, int]]`,
# the tuples would be added as keys, breaking type safety)
@overload # type: ignore[override]
def update(self, m: Mapping[_T, int], /, **kwargs: int) -> None: ...
def update(self, m: Mapping[_T, int], /, **kwargs: _V_I) -> None: ...
@overload
def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ...
def update(self, iterable: Iterable[_T], /, **kwargs: _V_I) -> None: ...
@overload
def update(self, iterable: None = None, /, **kwargs: int) -> None: ...
def __missing__(self, key: _T) -> int: ...
def update(self, iterable: None = None, /, **kwargs: _V_I) -> None: ...
def __missing__(self, key: _T) -> _V_I | int: ...
def __delitem__(self, elem: object) -> None: ...
if sys.version_info >= (3, 10):
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...

def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ...
def __sub__(self, other: Counter[_T]) -> Counter[_T]: ...
def __and__(self, other: Counter[_T]) -> Counter[_T]: ...
def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T]: ...
def __neg__(self) -> Counter[_T]: ...
def __add__(self, other: Counter[_S, _V_I]) -> Counter[_T | _S, _V_I]: ...
def __sub__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ...
def __and__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ...
def __or__(self, other: Counter[_S, _V]) -> Counter[_T | _S, _V_I | _V]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T, _V_I]: ...
def __neg__(self) -> Counter[_T, _V_I]: ...
# several type: ignores because __iadd__ is supposedly incompatible with __add__, etc.
def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, int]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, int]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc]
def __iadd__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc]
def __iand__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc]
def __ior__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[override,misc]
if sys.version_info >= (3, 10):
def total(self) -> int: ...
def __le__(self, other: Counter[Any]) -> bool: ...
def __lt__(self, other: Counter[Any]) -> bool: ...
def __ge__(self, other: Counter[Any]) -> bool: ...
def __gt__(self, other: Counter[Any]) -> bool: ...
def __le__(self, other: Counter[Any, _V_I]) -> bool: ...
def __lt__(self, other: Counter[Any, _V_I]) -> bool: ...
def __ge__(self, other: Counter[Any, _V_I]) -> bool: ...
def __gt__(self, other: Counter[Any, _V_I]) -> bool: ...

# The pure-Python implementations of the "views" classes
# These are exposed at runtime in `collections/__init__.py`
Expand Down