Skip to content

Commit 4f1fb49

Browse files
committed
fix: work in progress
1 parent 7da086e commit 4f1fb49

File tree

2 files changed

+60
-59
lines changed

2 files changed

+60
-59
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ from builtins import (
22
bool as _bool,
33
str as _str,
44
)
5-
from collections import defaultdict
5+
from collections import (
6+
OrderedDict,
7+
defaultdict,
8+
)
69
from collections.abc import (
710
Callable,
811
Hashable,
@@ -398,7 +401,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
398401
@overload
399402
def to_dict(
400403
self,
401-
orient=...,
404+
orient: str = ...,
402405
*,
403406
into: type[defaultdict],
404407
index: Literal[True] = ...,
@@ -424,33 +427,33 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
424427
self,
425428
orient: Literal["index"],
426429
*,
427-
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
430+
into: defaultdict,
428431
index: Literal[True] = ...,
429-
) -> MutableMapping[Hashable, _T_MUTABLE_MAPPING]: ...
432+
) -> defaultdict[Hashable, dict[Hashable, Any]]: ...
430433
@overload
431434
def to_dict(
432435
self,
433436
orient: Literal["index"],
434437
*,
435-
into: type[dict] = ...,
438+
into: OrderedDict | type[OrderedDict],
436439
index: Literal[True] = ...,
437-
) -> dict[Hashable, dict[Hashable, Any]]: ...
440+
) -> OrderedDict[Hashable, dict[Hashable, Any]]: ...
438441
@overload
439442
def to_dict(
440443
self,
441-
orient: Literal["dict", "list", "series"],
444+
orient: Literal["index"],
442445
*,
443446
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
444447
index: Literal[True] = ...,
445448
) -> _T_MUTABLE_MAPPING: ...
446449
@overload
447450
def to_dict(
448451
self,
449-
orient: Literal["split", "tight"],
452+
orient: Literal["index"],
450453
*,
451-
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
452-
index: bool = ...,
453-
) -> _T_MUTABLE_MAPPING: ...
454+
into: type[dict] = ...,
455+
index: Literal[True] = ...,
456+
) -> dict[Hashable, dict[Hashable, Any]]: ...
454457
@overload
455458
def to_dict(
456459
self,
@@ -460,14 +463,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
460463
index: Literal[True] = ...,
461464
) -> _T_MUTABLE_MAPPING: ...
462465
@overload
463-
def to_dict(
464-
self,
465-
orient: Literal["split", "tight"] = ...,
466-
*,
467-
into: _T_MUTABLE_MAPPING | type[_T_MUTABLE_MAPPING],
468-
index: bool = ...,
469-
) -> _T_MUTABLE_MAPPING: ...
470-
@overload
471466
def to_dict(
472467
self,
473468
orient: Literal["dict", "list", "series"] = ...,
@@ -478,11 +473,19 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
478473
@overload
479474
def to_dict(
480475
self,
481-
orient: Literal["split", "tight"] = ...,
476+
orient: Literal["split", "tight"],
477+
*,
478+
into: MutableMapping | type[MutableMapping],
479+
index: bool = ...,
480+
) -> MutableMapping[str, list]: ...
481+
@overload
482+
def to_dict(
483+
self,
484+
orient: Literal["split", "tight"],
482485
*,
483486
into: type[dict] = ...,
484487
index: bool = ...,
485-
) -> dict[Hashable, Any]: ...
488+
) -> dict[str, list]: ...
486489
def to_gbq(
487490
self,
488491
destination_table: str,

tests/test_frame.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
3+
from collections import (
4+
OrderedDict,
5+
defaultdict,
6+
)
47
from collections.abc import (
58
Callable,
69
Hashable,
@@ -3641,8 +3644,13 @@ def test_to_records() -> None:
36413644

36423645
def test_to_dict_simple() -> None:
36433646
check(assert_type(DF.to_dict(), dict[Hashable, Any]), dict)
3644-
check(assert_type(DF.to_dict("split"), dict[Hashable, Any]), dict)
36453647
check(assert_type(DF.to_dict("records"), list[dict[Hashable, Any]]), list)
3648+
check(assert_type(DF.to_dict("index"), dict[Hashable, dict[Hashable, Any]]), dict)
3649+
check(assert_type(DF.to_dict("dict"), dict[Hashable, Any]), dict)
3650+
check(assert_type(DF.to_dict("list"), dict[Hashable, Any]), dict)
3651+
check(assert_type(DF.to_dict("series"), dict[Hashable, Any]), dict)
3652+
check(assert_type(DF.to_dict("split"), dict[str, list]), dict, str)
3653+
check(assert_type(DF.to_dict("tight"), dict[str, list]), dict, str)
36463654

36473655
if TYPE_CHECKING_INVALID_USAGE:
36483656

@@ -3661,69 +3669,59 @@ def test(mapping: Mapping) -> None: # pyright: ignore[reportUnusedFunction]
36613669
assert_type(DF.to_dict("tight", into=defaultdict), Never)
36623670

36633671

3664-
def test_to_dict_into_defaultdict_any() -> None:
3665-
"""Test DataFrame.to_dict with `into=defaultdict[Any, list]`"""
3672+
def test_to_dict_into_defaultdict() -> None:
3673+
"""Test DataFrame.to_dict with `into` is an instance of defaultdict[Any, list]"""
36663674

36673675
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3668-
target: defaultdict[Hashable, list[Any]] = defaultdict(list)
3676+
target: defaultdict[Any, list] = defaultdict(list)
36693677

36703678
check(
3671-
assert_type(data.to_dict(into=target), defaultdict[Hashable, list[Any]]),
3679+
assert_type(data.to_dict(into=target), defaultdict[Any, list]),
36723680
defaultdict,
3681+
tuple,
36733682
)
36743683
check(
3675-
assert_type(
3684+
assert_type( # type: ignore[assert-type]
36763685
data.to_dict("index", into=target),
3677-
MutableMapping[Hashable, defaultdict[Hashable, list[Any]]],
3686+
defaultdict[Hashable, dict[Hashable, Any]],
36783687
),
36793688
defaultdict,
36803689
)
36813690
check(
3682-
assert_type(
3683-
data.to_dict("tight", into=target), defaultdict[Hashable, list[Any]]
3684-
),
3691+
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]),
36853692
defaultdict,
3693+
str,
36863694
)
36873695
check(
3688-
assert_type(
3689-
data.to_dict("records", into=target), list[defaultdict[Hashable, list[Any]]]
3690-
),
3696+
assert_type(data.to_dict("records", into=target), list[defaultdict[Any, list]]),
36913697
list,
3698+
defaultdict,
36923699
)
36933700

36943701

3695-
def test_to_dict_into_defaultdict_typed() -> None:
3696-
"""Test DataFrame.to_dict with `into=defaultdict[tuple[str, str], list[int]]`"""
3702+
def test_to_dict_into_ordered_dict() -> None:
3703+
"""Test DataFrame.to_dict with `into=OrderedDict`"""
36973704

36983705
data = pd.DataFrame({("str", "rts"): [[1, 2, 4], [2, 3], [3]]})
3699-
target: defaultdict[tuple[str, str], list[int]] = defaultdict(list)
3700-
target[("str", "rts")].append(1)
3706+
target = OrderedDict
37013707

3702-
check(
3703-
assert_type(data.to_dict(into=target), defaultdict[tuple[str, str], list[int]]),
3704-
defaultdict,
3705-
tuple,
3706-
)
3708+
check(assert_type(data.to_dict(into=target), OrderedDict), OrderedDict, tuple) # type: ignore[assert-type]
37073709
check(
37083710
assert_type(
37093711
data.to_dict("index", into=target),
3710-
MutableMapping[Hashable, defaultdict[tuple[str, str], list[int]]],
3712+
OrderedDict[Hashable, dict[Hashable, Any]],
37113713
),
3712-
defaultdict,
3714+
OrderedDict,
37133715
)
37143716
check(
3715-
assert_type(
3716-
data.to_dict("tight", into=target), defaultdict[tuple[str, str], list[int]]
3717-
),
3718-
defaultdict,
3717+
assert_type(data.to_dict("tight", into=target), MutableMapping[str, list]),
3718+
OrderedDict,
3719+
str,
37193720
)
37203721
check(
3721-
assert_type(
3722-
data.to_dict("records", into=target),
3723-
list[defaultdict[tuple[str, str], list[int]]],
3724-
),
3722+
assert_type(data.to_dict("records", into=target), list[OrderedDict]), # type: ignore[assert-type]
37253723
list,
3726-
defaultdict,
3724+
OrderedDict,
37273725
)
37283726

37293727

@@ -4177,16 +4175,16 @@ def test_to_dict_index() -> None:
41774175
dict,
41784176
)
41794177
check(
4180-
assert_type(df.to_dict(orient="split", index=True), dict[Hashable, Any]), dict
4178+
assert_type(df.to_dict(orient="split", index=True), dict[str, list]), dict, str
41814179
)
41824180
check(
4183-
assert_type(df.to_dict(orient="tight", index=True), dict[Hashable, Any]), dict
4181+
assert_type(df.to_dict(orient="tight", index=True), dict[str, list]), dict, str
41844182
)
41854183
check(
4186-
assert_type(df.to_dict(orient="tight", index=False), dict[Hashable, Any]), dict
4184+
assert_type(df.to_dict(orient="tight", index=False), dict[str, list]), dict, str
41874185
)
41884186
check(
4189-
assert_type(df.to_dict(orient="split", index=False), dict[Hashable, Any]), dict
4187+
assert_type(df.to_dict(orient="split", index=False), dict[str, list]), dict, str
41904188
)
41914189
if TYPE_CHECKING_INVALID_USAGE:
41924190
check(assert_type(df.to_dict(orient="records", index=False), list[dict[Hashable, Any]]), list) # type: ignore[assert-type, call-overload] # pyright: ignore[reportArgumentType,reportAssertTypeFailure,reportCallIssue]

0 commit comments

Comments
 (0)