diff --git a/stdlib/@tests/test_cases/check_dataclasses.py b/stdlib/@tests/test_cases/check_dataclasses.py index 917331e4c084..1594400e44fe 100644 --- a/stdlib/@tests/test_cases/check_dataclasses.py +++ b/stdlib/@tests/test_cases/check_dataclasses.py @@ -91,6 +91,16 @@ def check_other_isdataclass_overloads(x: type, y: object) -> None: dc.replace(y) +class _D: ... + + +custom_dc = dc.dataclass(_D, init=True) +assert_type(custom_dc, type[_D]) + +custom_dc_2 = dc.dataclass(None, init=True)(_D) +assert_type(custom_dc_2, type[_D]) + + # Regression test for #11653 D = dc.make_dataclass( "D", [("a", Union[int, None]), "y", ("z", Annotated[FrozenSet[bytes], "metadata"], dc.field(default=frozenset({b"foo"})))] diff --git a/stdlib/dataclasses.pyi b/stdlib/dataclasses.pyi index bba76c1af1b4..d69bec7f68d7 100644 --- a/stdlib/dataclasses.pyi +++ b/stdlib/dataclasses.pyi @@ -71,14 +71,28 @@ def asdict(obj: DataclassInstance, *, dict_factory: Callable[[list[tuple[str, An def astuple(obj: DataclassInstance) -> tuple[Any, ...]: ... @overload def astuple(obj: DataclassInstance, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ... -@overload -def dataclass(cls: None, /) -> Callable[[type[_T]], type[_T]]: ... -@overload -def dataclass(cls: type[_T], /) -> type[_T]: ... if sys.version_info >= (3, 11): @overload def dataclass( + cls: type[_T], + /, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, + weakref_slot: bool = False, + ) -> type[_T]: ... + @overload + def dataclass( + cls: None = None, + /, *, init: bool = True, repr: bool = True, @@ -95,6 +109,23 @@ if sys.version_info >= (3, 11): elif sys.version_info >= (3, 10): @overload def dataclass( + cls: type[_T], + /, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, + ) -> type[_T]: ... + @overload + def dataclass( + cls: None = None, + /, *, init: bool = True, repr: bool = True, @@ -110,6 +141,20 @@ elif sys.version_info >= (3, 10): else: @overload def dataclass( + cls: type[_T], + /, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + ) -> type[_T]: ... + @overload + def dataclass( + cls: None = None, + /, *, init: bool = True, repr: bool = True,