Skip to content

Commit

Permalink
TYP: Correct type annotation for to_dict. (#55130)
Browse files Browse the repository at this point in the history
* Correct type annotation for to_dict.

The `into` argument of DataFrame.to_dict and Series.to_dict can be
either a class or instance of a class of dict; this is covariant -
subclasses of dict can also be used. The argument was annotated as
`type[dict]` though, so type checkers marked passing initialized objects
(required for collections.defaultdict) as an incorrect argument type.

Fix by annotating `into` to take either a subclass of dict or an
initialized instance of a subclass of dict.

* Use generic MutableMapping type for to_dict method.

Unfortunately a generic type annotation with a default triggers an
existing mypy limitation (python/mypy#3737).
The current workaround is to use overloads and then not annotate the
implementation containing the default parameter; this still enables mypy
to deduce correct return types.

Two overloads are added for Series.to_dict, even though they could be
combined using a Union type, as at least two overloads are required for
a single method.

* Fix formatting

* return annotation for non-overload

* no keyword should return dict

* swap overload order to work for dict subclasses that are passed as keywords

* fix tests

---------

Co-authored-by: Torsten Wörtwein <twoertwein@gmail.com>
  • Loading branch information
jsspencer and twoertwein authored Oct 2, 2023
1 parent 9282d9f commit a0a6e04
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 23 deletions.
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Hashable,
Iterator,
Mapping,
MutableMapping,
Sequence,
)
from datetime import (
Expand Down Expand Up @@ -103,6 +104,7 @@
TypeGuard: Any = None

HashableT = TypeVar("HashableT", bound=Hashable)
MutableMappingT = TypeVar("MutableMappingT", bound=MutableMapping)

# array-like

Expand Down
41 changes: 34 additions & 7 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
Level,
MergeHow,
MergeValidate,
MutableMappingT,
NaAction,
NaPosition,
NsmallestNlargestKeep,
Expand Down Expand Up @@ -1927,6 +1928,27 @@ def _create_data_for_split_and_tight_to_dict(
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
into: type[MutableMappingT] | MutableMappingT,
index: bool = ...,
) -> MutableMappingT:
...

@overload
def to_dict(
self,
orient: Literal["records"],
*,
into: type[MutableMappingT] | MutableMappingT,
index: bool = ...,
) -> list[MutableMappingT]:
...

@overload
def to_dict(
self,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
into: type[dict] = ...,
index: bool = ...,
) -> dict:
Expand All @@ -1936,11 +1958,14 @@ def to_dict(
def to_dict(
self,
orient: Literal["records"],
*,
into: type[dict] = ...,
index: bool = ...,
) -> list[dict]:
...

# error: Incompatible default for argument "into" (default has type "type
# [dict[Any, Any]]", argument has type "type[MutableMappingT] | MutableMappingT")
@deprecate_nonkeyword_arguments(
version="3.0", allowed_args=["self", "orient"], name="to_dict"
)
Expand All @@ -1949,9 +1974,10 @@ def to_dict(
orient: Literal[
"dict", "list", "series", "split", "tight", "records", "index"
] = "dict",
into: type[dict] = dict,
into: type[MutableMappingT]
| MutableMappingT = dict, # type: ignore[assignment]
index: bool = True,
) -> dict | list[dict]:
) -> MutableMappingT | list[MutableMappingT]:
"""
Convert the DataFrame to a dictionary.
Expand Down Expand Up @@ -1979,7 +2005,7 @@ def to_dict(
'tight' as an allowed value for the ``orient`` argument
into : class, default dict
The collections.abc.Mapping subclass used for all Mappings
The collections.abc.MutableMapping subclass used for all Mappings
in the return value. Can be the actual class or an empty
instance of the mapping type you want. If you want a
collections.defaultdict, you must pass it initialized.
Expand All @@ -1993,9 +2019,10 @@ def to_dict(
Returns
-------
dict, list or collections.abc.Mapping
Return a collections.abc.Mapping object representing the DataFrame.
The resulting transformation depends on the `orient` parameter.
dict, list or collections.abc.MutableMapping
Return a collections.abc.MutableMapping object representing the
DataFrame. The resulting transformation depends on the `orient`
parameter.
See Also
--------
Expand Down Expand Up @@ -2054,7 +2081,7 @@ def to_dict(
"""
from pandas.core.methods.to_dict import to_dict

return to_dict(self, orient, into, index)
return to_dict(self, orient, into=into, index=index)

@deprecate_nonkeyword_arguments(
version="3.0", allowed_args=["self", "destination_table"], name="to_gbq"
Expand Down
62 changes: 56 additions & 6 deletions pandas/core/methods/to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
Literal,
overload,
)
import warnings

Expand All @@ -16,17 +17,66 @@
from pandas.core import common as com

if TYPE_CHECKING:
from pandas._typing import MutableMappingT

from pandas import DataFrame


@overload
def to_dict(
df: DataFrame,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
into: type[MutableMappingT] | MutableMappingT,
index: bool = ...,
) -> MutableMappingT:
...


@overload
def to_dict(
df: DataFrame,
orient: Literal["records"],
*,
into: type[MutableMappingT] | MutableMappingT,
index: bool = ...,
) -> list[MutableMappingT]:
...


@overload
def to_dict(
df: DataFrame,
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
*,
into: type[dict] = ...,
index: bool = ...,
) -> dict:
...


@overload
def to_dict(
df: DataFrame,
orient: Literal["records"],
*,
into: type[dict] = ...,
index: bool = ...,
) -> list[dict]:
...


# error: Incompatible default for argument "into" (default has type "type[dict
# [Any, Any]]", argument has type "type[MutableMappingT] | MutableMappingT")
def to_dict(
df: DataFrame,
orient: Literal[
"dict", "list", "series", "split", "tight", "records", "index"
] = "dict",
into: type[dict] = dict,
*,
into: type[MutableMappingT] | MutableMappingT = dict, # type: ignore[assignment]
index: bool = True,
) -> dict | list[dict]:
) -> MutableMappingT | list[MutableMappingT]:
"""
Convert the DataFrame to a dictionary.
Expand Down Expand Up @@ -54,7 +104,7 @@ def to_dict(
'tight' as an allowed value for the ``orient`` argument
into : class, default dict
The collections.abc.Mapping subclass used for all Mappings
The collections.abc.MutableMapping subclass used for all Mappings
in the return value. Can be the actual class or an empty
instance of the mapping type you want. If you want a
collections.defaultdict, you must pass it initialized.
Expand All @@ -69,8 +119,8 @@ def to_dict(
Returns
-------
dict, list or collections.abc.Mapping
Return a collections.abc.Mapping object representing the DataFrame.
The resulting transformation depends on the `orient` parameter.
Return a collections.abc.MutableMapping object representing the
DataFrame. The resulting transformation depends on the `orient` parameter.
"""
if not df.columns.is_unique:
warnings.warn(
Expand Down Expand Up @@ -103,7 +153,7 @@ def to_dict(
are_all_object_dtype_cols = len(box_native_indices) == len(df.dtypes)

if orient == "dict":
return into_c((k, v.to_dict(into)) for k, v in df.items())
return into_c((k, v.to_dict(into=into)) for k, v in df.items())

elif orient == "list":
object_dtype_indices_as_set: set[int] = set(box_native_indices)
Expand Down
37 changes: 29 additions & 8 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pandas.util._decorators import (
Appender,
Substitution,
deprecate_nonkeyword_arguments,
doc,
)
from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -167,6 +168,7 @@
IndexKeyFunc,
IndexLabel,
Level,
MutableMappingT,
NaPosition,
NumpySorter,
NumpyValueArrayLike,
Expand Down Expand Up @@ -1922,21 +1924,40 @@ def keys(self) -> Index:
"""
return self.index

def to_dict(self, into: type[dict] = dict) -> dict:
@overload
def to_dict(
self, *, into: type[MutableMappingT] | MutableMappingT
) -> MutableMappingT:
...

@overload
def to_dict(self, *, into: type[dict] = ...) -> dict:
...

# error: Incompatible default for argument "into" (default has type "type[
# dict[Any, Any]]", argument has type "type[MutableMappingT] | MutableMappingT")
@deprecate_nonkeyword_arguments(
version="3.0", allowed_args=["self"], name="to_dict"
)
def to_dict(
self,
into: type[MutableMappingT]
| MutableMappingT = dict, # type: ignore[assignment]
) -> MutableMappingT:
"""
Convert Series to {label -> value} dict or dict-like object.
Parameters
----------
into : class, default dict
The collections.abc.Mapping subclass to use as the return
object. Can be the actual class or an empty
instance of the mapping type you want. If you want a
collections.defaultdict, you must pass it initialized.
The collections.abc.MutableMapping subclass to use as the return
object. Can be the actual class or an empty instance of the mapping
type you want. If you want a collections.defaultdict, you must
pass it initialized.
Returns
-------
collections.abc.Mapping
collections.abc.MutableMapping
Key-value representation of Series.
Examples
Expand All @@ -1945,10 +1966,10 @@ def to_dict(self, into: type[dict] = dict) -> dict:
>>> s.to_dict()
{0: 1, 1: 2, 2: 3, 3: 4}
>>> from collections import OrderedDict, defaultdict
>>> s.to_dict(OrderedDict)
>>> s.to_dict(into=OrderedDict)
OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)])
>>> dd = defaultdict(list)
>>> s.to_dict(dd)
>>> s.to_dict(into=dd)
defaultdict(<class 'list'>, {0: 1, 1: 2, 2: 3, 3: 4})
"""
# GH16122
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/series/methods/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class TestSeriesToDict:
)
def test_to_dict(self, mapping, datetime_series):
# GH#16122
result = Series(datetime_series.to_dict(mapping), name="ts")
result = Series(datetime_series.to_dict(into=mapping), name="ts")
expected = datetime_series.copy()
expected.index = expected.index._with_freq(None)
tm.assert_series_equal(result, expected)

from_method = Series(datetime_series.to_dict(collections.Counter))
from_method = Series(datetime_series.to_dict(into=collections.Counter))
from_constructor = Series(collections.Counter(datetime_series.items()))
tm.assert_series_equal(from_method, from_constructor)

Expand Down

0 comments on commit a0a6e04

Please sign in to comment.