diff --git a/stdlib/@tests/test_cases/collections/check_counter.py b/stdlib/@tests/test_cases/collections/check_counter.py new file mode 100644 index 000000000000..cd3f93852833 --- /dev/null +++ b/stdlib/@tests/test_cases/collections/check_counter.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from collections import Counter +from decimal import Decimal +from typing import Union +from typing_extensions import assert_type + +# Initialize a Counter for strings with integer values +word_counts: Counter[str] = Counter() +word_counts["foo"] += 3 +word_counts["bar"] += 2 +assert_type(word_counts, "Counter[str, int]") + +# Initialize a Counter for strings with float values +floating_point_counts: Counter[str, float] = Counter() +floating_point_counts["foo"] += 3.0 +floating_point_counts["bar"] += 5.0 + +# Initialize a Counter for strings with Decimal values +decimal_counts: Counter[str, Decimal] = Counter() +decimal_counts["foo"] += Decimal("3.0") +decimal_counts["bar"] += Decimal("5.0") +# Each key defualts to an int. +assert_type(decimal_counts["test"], Union[Decimal, int]) +assert_type(decimal_counts.get("test"), Union[Decimal, int, None]) +assert_type(decimal_counts.pop("test"), Union[Decimal, int]) + +# Using kwargs for `__init__` +word_counts = Counter(foo=3, bar=2) +assert_type(word_counts, "Counter[str, int]") +floating_point_counts = Counter(foo=3.0, bar=5.0) +assert_type(floating_point_counts, "Counter[str, float]") +decimal_counts = Counter(foo=Decimal("3.0"), bar=Decimal("5.0")) +assert_type(decimal_counts, "Counter[str, Decimal]") + +# Counter combining integers and floats +mixed_type_counter = Counter({"foo": 3, "bar": 2.5}) +mixed_type_counter["baz"] += 1.5 +mixed_type_counter # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float]"`; pyright: `Counter[str, int | float]` + +# Check ORing and ANDing Counters with different value types +# MyPy and Pyright infer the types differently for these, so we can't use assert_type. + +_ = mixed_type_counter or decimal_counts +_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float] | Counter[str, Decimal]"`; pyright: `Counter[str, int | float] | Counter[str, Decimal]` + +_ = decimal_counts or mixed_type_counter +_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, Decimal] | Counter[str, float]"`; pyright: `Counter[str, Decimal] | Counter[str, int | float]` + +_ = mixed_type_counter and decimal_counts +_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, float] | Counter[str, Decimal]"`; pyright: `Counter[str, int | float] | Counter[str, Decimal]` + +_ = decimal_counts and mixed_type_counter +_ # pyright: ignore[reportUnusedExpression] # mypy: `"Counter[str, Decimal] | Counter[str, float]"`; pyright: `Counter[str, Decimal] | Counter[str, int | float]` + +# We shouldn't be able to add Counters with incompatible value types +_ = mixed_type_counter + decimal_counts # type: ignore +mixed_type_counter += decimal_counts # type: ignore + +# Adding Counters with compatible types +_wc = word_counts + Counter({"foo": 2, "baz": 1}) +word_counts += Counter({"foo": 2, "baz": 1}) + +# Combining Counters of different key types +integer_key_counter = Counter({1: 2, 2: 3}) +combined_word_and_integer_keys = word_counts + integer_key_counter +assert_type(combined_word_and_integer_keys, "Counter[str | int, int]") diff --git a/stdlib/collections/__init__.pyi b/stdlib/collections/__init__.pyi index 71e3c564dd57..e48fcfcd3969 100644 --- a/stdlib/collections/__init__.pyi +++ b/stdlib/collections/__init__.pyi @@ -1,8 +1,8 @@ import sys from _collections_abc import dict_items, dict_keys, dict_values from _typeshed import SupportsItems, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT -from typing import Any, Generic, NoReturn, SupportsIndex, TypeVar, final, overload -from typing_extensions import Self +from typing import Any, Generic, NoReturn, SupportsIndex, final, overload +from typing_extensions import Self, TypeVar if sys.version_info >= (3, 9): from types import GenericAlias @@ -28,6 +28,8 @@ __all__ = ["ChainMap", "Counter", "OrderedDict", "UserDict", "UserList", "UserSt _S = TypeVar("_S") _T = TypeVar("_T") +_V = TypeVar("_V") +_V_I = TypeVar("_V_I", default=int) _T1 = TypeVar("_T1") _T2 = TypeVar("_T2") _KT = TypeVar("_KT") @@ -273,24 +275,24 @@ class deque(MutableSequence[_T]): if sys.version_info >= (3, 9): def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... -class Counter(dict[_T, int], Generic[_T]): +class Counter(dict[_T, _V_I | int]): @overload def __init__(self, iterable: None = None, /) -> None: ... @overload - def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ... + def __init__(self: Counter[str, Any], iterable: None = None, /, **kwargs: _V_I) -> None: ... @overload - def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ... + def __init__(self, mapping: SupportsKeysAndGetItem[_T, _V_I], /) -> None: ... @overload def __init__(self, iterable: Iterable[_T], /) -> None: ... def copy(self) -> Self: ... def elements(self) -> Iterator[_T]: ... - def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ... + def most_common(self, n: int | None = None) -> list[tuple[_T, _V_I]]: ... @classmethod - def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override] + def fromkeys(cls, iterable: Any, v: _V_I | None = None) -> NoReturn: ... # type: ignore[override] @overload def subtract(self, iterable: None = None, /) -> None: ... @overload - def subtract(self, mapping: Mapping[_T, int], /) -> None: ... + def subtract(self, mapping: Mapping[_T, _V_I], /) -> None: ... @overload def subtract(self, iterable: Iterable[_T], /) -> None: ... # Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload @@ -300,34 +302,34 @@ class Counter(dict[_T, int], Generic[_T]): # (if it were `Iterable[_T] | Iterable[tuple[_T, int]]`, # the tuples would be added as keys, breaking type safety) @overload # type: ignore[override] - def update(self, m: Mapping[_T, int], /, **kwargs: int) -> None: ... + def update(self, m: Mapping[_T, int], /, **kwargs: _V_I) -> None: ... @overload - def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ... + def update(self, iterable: Iterable[_T], /, **kwargs: _V_I) -> None: ... @overload - def update(self, iterable: None = None, /, **kwargs: int) -> None: ... - def __missing__(self, key: _T) -> int: ... + def update(self, iterable: None = None, /, **kwargs: _V_I) -> None: ... + def __missing__(self, key: _T) -> _V_I | int: ... def __delitem__(self, elem: object) -> None: ... if sys.version_info >= (3, 10): def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... - def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ... - def __sub__(self, other: Counter[_T]) -> Counter[_T]: ... - def __and__(self, other: Counter[_T]) -> Counter[_T]: ... - def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override] - def __pos__(self) -> Counter[_T]: ... - def __neg__(self) -> Counter[_T]: ... + def __add__(self, other: Counter[_S, _V_I]) -> Counter[_T | _S, _V_I]: ... + def __sub__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ... + def __and__(self, other: Counter[_T, _V_I]) -> Counter[_T, _V_I]: ... + def __or__(self, other: Counter[_S, _V]) -> Counter[_T | _S, _V_I | _V]: ... # type: ignore[override] + def __pos__(self) -> Counter[_T, _V_I]: ... + def __neg__(self) -> Counter[_T, _V_I]: ... # several type: ignores because __iadd__ is supposedly incompatible with __add__, etc. - def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc] - def __isub__(self, other: SupportsItems[_T, int]) -> Self: ... - def __iand__(self, other: SupportsItems[_T, int]) -> Self: ... - def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc] + def __iadd__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc] + def __isub__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc] + def __iand__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[misc] + def __ior__(self, other: SupportsItems[_T, _V_I]) -> Self: ... # type: ignore[override,misc] if sys.version_info >= (3, 10): def total(self) -> int: ... - def __le__(self, other: Counter[Any]) -> bool: ... - def __lt__(self, other: Counter[Any]) -> bool: ... - def __ge__(self, other: Counter[Any]) -> bool: ... - def __gt__(self, other: Counter[Any]) -> bool: ... + def __le__(self, other: Counter[Any, _V_I]) -> bool: ... + def __lt__(self, other: Counter[Any, _V_I]) -> bool: ... + def __ge__(self, other: Counter[Any, _V_I]) -> bool: ... + def __gt__(self, other: Counter[Any, _V_I]) -> bool: ... # The pure-Python implementations of the "views" classes # These are exposed at runtime in `collections/__init__.py`