TYP: reshape.merge #53780

Merged: 1 commit, Jun 22, 2023
92 changes: 51 additions & 41 deletions pandas/core/reshape/merge.py
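The diff below narrows the key-handling annotations from `AnyArrayLike` (which also admits `Index` and `Series`) to `ArrayLike` (a bare `np.ndarray` or `ExtensionArray`), unwraps Index/Series keys at the boundaries via `extract_array` and `.index._values`, and marks the shared `_MergeOperation` helpers `@final`. A minimal sketch of the unwrapping idea, using simplified stand-ins for the `pandas._typing` aliases (`extract_array` is an internal helper, so its import path may vary across versions):

from typing import Union

import numpy as np
import pandas as pd
from pandas.api.extensions import ExtensionArray
from pandas.core.construction import extract_array

# Simplified stand-ins for the pandas._typing aliases used in this diff:
ArrayLike = Union[np.ndarray, ExtensionArray]         # bare arrays only
AnyArrayLike = Union[ArrayLike, pd.Index, pd.Series]  # also admits wrappers

ser = pd.Series([1, 2, 3])
arr = extract_array(ser, extract_numpy=True)  # unwrap Series -> ndarray
print(type(arr))  # <class 'numpy.ndarray'>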
@@ -3,7 +3,6 @@
"""
from __future__ import annotations

-import copy as cp
import datetime
from functools import partial
import string
@@ -13,6 +12,7 @@
Literal,
Sequence,
cast,
+final,
)
import uuid
import warnings
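`final` joins the `typing` imports so the shared `_MergeOperation` helpers below can be marked `@final`. A quick sketch of what that buys: the decorator is a no-op at runtime, but a type checker such as mypy rejects subclass overrides (class names here are made up for illustration):

from typing import final

class Base:
    @final
    def helper(self) -> int:
        return 1

class Child(Base):
    # mypy flags this override of a @final method; it still runs at runtime
    def helper(self) -> int:
        return 2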
@@ -655,8 +655,8 @@ class _MergeOperation:
indicator: str | bool
validate: str | None
join_names: list[Hashable]
-right_join_keys: list[AnyArrayLike]
-left_join_keys: list[AnyArrayLike]
+right_join_keys: list[ArrayLike]
+left_join_keys: list[ArrayLike]

def __init__(
self,
@@ -743,6 +743,7 @@ def __init__(
if validate is not None:
self._validate(validate)

+@final
def _reindex_and_concat(
self,
join_index: Index,
@@ -821,12 +822,14 @@ def get_result(self, copy: bool | None = True) -> DataFrame:

return result.__finalize__(self, method="merge")

+@final
def _maybe_drop_cross_column(
self, result: DataFrame, cross_col: str | None
) -> None:
if cross_col is not None:
del result[cross_col]

+@final
@cache_readonly
def _indicator_name(self) -> str | None:
if isinstance(self.indicator, str):
@@ -838,6 +841,7 @@ def _indicator_name(self) -> str | None:
"indicator option can only accept boolean or string arguments"
)

+@final
def _indicator_pre_merge(
self, left: DataFrame, right: DataFrame
) -> tuple[DataFrame, DataFrame]:
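For context, the indicator machinery these now-`@final` helpers implement backs the public `indicator=True` option; a small usage sketch with made-up frames:

import pandas as pd

left = pd.DataFrame({"key": [1, 2], "a": ["x", "y"]})
right = pd.DataFrame({"key": [2, 3], "b": ["u", "v"]})
out = pd.merge(left, right, on="key", how="outer", indicator=True)
print(out["_merge"].tolist())  # ['left_only', 'both', 'right_only']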
@@ -865,6 +869,7 @@ def _indicator_pre_merge(

return left, right

+@final
def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
result["_left_indicator"] = result["_left_indicator"].fillna(0)
result["_right_indicator"] = result["_right_indicator"].fillna(0)
@@ -880,6 +885,7 @@ def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1)
return result

+@final
def _maybe_restore_index_levels(self, result: DataFrame) -> None:
"""
Restore index levels specified as `on` parameters
@@ -923,11 +929,12 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
if names_to_restore:
result.set_index(names_to_restore, inplace=True)

+@final
def _maybe_add_join_keys(
self,
result: DataFrame,
-left_indexer: np.ndarray | None,
-right_indexer: np.ndarray | None,
+left_indexer: npt.NDArray[np.intp] | None,
+right_indexer: npt.NDArray[np.intp] | None,
) -> None:
left_has_missing = None
right_has_missing = None
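The indexer annotations tighten from `np.ndarray` to `npt.NDArray[np.intp]`: join indexers are always platform-int position arrays, with `-1` marking rows absent from one side. A sketch of that convention using the public `take` helper (data is hypothetical):

import numpy as np
import pandas as pd

values = np.array(["a", "b", "c"], dtype=object)
indexer = np.array([0, 2, -1], dtype=np.intp)  # -1: no matching row
out = pd.api.extensions.take(values, indexer, allow_fill=True)
print(out)  # ['a' 'c' nan]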
@@ -1032,6 +1039,7 @@ def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]
self.left_join_keys, self.right_join_keys, sort=self.sort, how=self.how
)

+@final
def _get_join_info(
self,
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]:
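`_get_join_info` ultimately delegates index joins to `Index.join`; a minimal illustration of the `(join_index, left_indexer, right_indexer)` triple it returns, with made-up indexes:

import pandas as pd

left = pd.Index([1, 2, 3])
right = pd.Index([2, 3, 4])
joined, lidx, ridx = left.join(right, how="inner", return_indexers=True)
print(joined.tolist(), lidx, ridx)  # [2, 3] [1 2] [0 1]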
@@ -1093,6 +1101,7 @@ def _get_join_info(
join_index = default_index(0).set_names(join_index.name)
return join_index, left_indexer, right_indexer

+@final
def _create_join_index(
self,
index: Index,
@@ -1129,7 +1138,7 @@ def _create_join_index(

def _get_merge_keys(
self,
-) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
+) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
"""
Note: has side effects (copy/delete key columns)

@@ -1145,8 +1154,8 @@ def _get_merge_keys(
"""
# left_keys, right_keys entries can actually be anything listlike
# with a 'dtype' attr
-left_keys: list[AnyArrayLike] = []
-right_keys: list[AnyArrayLike] = []
+left_keys: list[ArrayLike] = []
+right_keys: list[ArrayLike] = []
join_names: list[Hashable] = []
right_drop: list[Hashable] = []
left_drop: list[Hashable] = []
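`_get_merge_keys` accepts a mix of column labels and raw array keys, which is why the lists above can only be typed `list[ArrayLike]` once labels are resolved. Both call styles below feed this path (frames are hypothetical):

import numpy as np
import pandas as pd

left = pd.DataFrame({"k": [1, 2], "a": ["x", "y"]})
right = pd.DataFrame({"k2": [2, 3], "b": ["u", "v"]})

by_label = pd.merge(left, right, left_on="k", right_on="k2")
by_array = pd.merge(left, right, left_on=np.array([1, 2]), right_on="k2")
# both join left row 1 (k == 2) to right row 0 (k2 == 2)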
@@ -1169,11 +1178,13 @@ def _get_merge_keys(
# ugh, spaghetti re #733
if _any(self.left_on) and _any(self.right_on):
for lk, rk in zip(self.left_on, self.right_on):
+lk = extract_array(lk, extract_numpy=True)
+rk = extract_array(rk, extract_numpy=True)
if is_lkey(lk):
-lk = cast(AnyArrayLike, lk)
+lk = cast(ArrayLike, lk)
left_keys.append(lk)
if is_rkey(rk):
-rk = cast(AnyArrayLike, rk)
+rk = cast(ArrayLike, rk)
right_keys.append(rk)
join_names.append(None) # what to do?
else:
@@ -1185,7 +1196,7 @@ def _get_merge_keys(
join_names.append(rk)
else:
# work-around for merge_asof(right_index=True)
-right_keys.append(right.index)
+right_keys.append(right.index._values)
join_names.append(right.index.name)
else:
if not is_rkey(rk):
@@ -1196,7 +1207,7 @@ def _get_merge_keys(
right_keys.append(right._get_label_or_level_values(rk))
else:
# work-around for merge_asof(right_index=True)
-right_keys.append(right.index)
+right_keys.append(right.index._values)
if lk is not None and lk == rk: # FIXME: what about other NAs?
# avoid key upcast in corner case (length-0)
lk = cast(Hashable, lk)
@@ -1205,7 +1216,7 @@ def _get_merge_keys(
else:
left_drop.append(lk)
else:
-rk = cast(AnyArrayLike, rk)
+rk = cast(ArrayLike, rk)
right_keys.append(rk)
if lk is not None:
# Then we're either Hashable or a wrong-length arraylike,
@@ -1215,12 +1226,13 @@ def _get_merge_keys(
join_names.append(lk)
else:
# work-around for merge_asof(left_index=True)
-left_keys.append(left.index)
+left_keys.append(left.index._values)
join_names.append(left.index.name)
elif _any(self.left_on):
for k in self.left_on:
if is_lkey(k):
-k = cast(AnyArrayLike, k)
+k = extract_array(k, extract_numpy=True)
+k = cast(ArrayLike, k)
left_keys.append(k)
join_names.append(None)
else:
@@ -1240,8 +1252,9 @@ def _get_merge_keys(
right_keys = [self.right.index._values]
elif _any(self.right_on):
for k in self.right_on:
+k = extract_array(k, extract_numpy=True)
if is_rkey(k):
-k = cast(AnyArrayLike, k)
+k = cast(ArrayLike, k)
right_keys.append(k)
join_names.append(None)
else:
@@ -1268,6 +1281,7 @@ def _get_merge_keys(

return left_keys, right_keys, join_names

+@final
def _maybe_coerce_merge_keys(self) -> None:
# we have valid merges but we may have to further
# coerce these if they are originally incompatible types
@@ -1432,6 +1446,7 @@ def _maybe_coerce_merge_keys(self) -> None:
self.right = self.right.copy()
self.right[name] = self.right[name].astype(typ)

+@final
def _create_cross_configuration(
self, left: DataFrame, right: DataFrame
) -> tuple[DataFrame, DataFrame, JoinHow, str]:
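`_create_cross_configuration` backs `how="cross"`, which joins on a temporary constant key and then drops it; the public behavior is a plain cartesian product (frames are made up):

import pandas as pd

left = pd.DataFrame({"a": [1, 2]})
right = pd.DataFrame({"b": ["x", "y"]})
out = pd.merge(left, right, how="cross")
print(len(out))  # 4 rows: every pairing of left and right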
@@ -1610,11 +1625,10 @@ def _validate(self, validate: str) -> None:


def get_join_indexers(
-left_keys: list[AnyArrayLike],
-right_keys: list[AnyArrayLike],
+left_keys: list[ArrayLike],
+right_keys: list[ArrayLike],
sort: bool = False,
how: MergeHow | Literal["asof"] = "inner",
-**kwargs,
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
"""

@@ -1667,7 +1681,7 @@ def get_join_indexers(

lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort, how=how)
# preserve left frame order if how == 'left' and sort == False
-kwargs = cp.copy(kwargs)
+kwargs = {}
if how in ("left", "right"):
kwargs["sort"] = sort
join_func = {
@@ -1812,8 +1826,8 @@ def get_result(self, copy: bool | None = True) -> DataFrame:
self.left._info_axis, self.right._info_axis, self.suffixes
)

-left_join_indexer: np.ndarray | None
-right_join_indexer: np.ndarray | None
+left_join_indexer: npt.NDArray[np.intp] | None
+right_join_indexer: npt.NDArray[np.intp] | None

if self.fill_method == "ffill":
if left_indexer is None:
@@ -1984,7 +1998,7 @@ def _validate_left_right_on(self, left_on, right_on):

def _get_merge_keys(
self,
-) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
+) -> tuple[list[ArrayLike], list[ArrayLike], list[Hashable]]:
# note this function has side effects
(left_join_keys, right_join_keys, join_names) = super()._get_merge_keys()
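The asof subclass reuses the base key extraction and then validates key dtypes; for reference, the public entry point matches each left row to the nearest earlier right key (data below is hypothetical):

import pandas as pd

left = pd.DataFrame({"t": [2, 5], "a": ["x", "y"]})
right = pd.DataFrame({"t": [1, 4], "b": [10, 20]})
out = pd.merge_asof(left, right, on="t")
print(out["b"].tolist())  # [10, 20]: backward match to t=1, then t=4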

@@ -2016,8 +2030,7 @@ def _get_merge_keys(
# validate tolerance; datetime.timedelta or Timedelta if we have a DTI
if self.tolerance is not None:
if self.left_index:
-# Actually more specifically an Index
-lt = cast(AnyArrayLike, self.left.index)
+lt = self.left.index._values
else:
lt = left_join_keys[-1]
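The dtype checks that follow gate `tolerance`: a `Timedelta` for datetime-like keys, a plain number for integer or float keys. A usage sketch with a datetime key (timestamps are made up):

import pandas as pd

left = pd.DataFrame({"t": pd.to_datetime(["2023-01-01 00:00:01",
                                          "2023-01-01 00:00:09"]), "a": [1, 2]})
right = pd.DataFrame({"t": pd.to_datetime(["2023-01-01 00:00:00",
                                           "2023-01-01 00:00:03"]), "b": [10, 20]})
out = pd.merge_asof(left, right, on="t", tolerance=pd.Timedelta("2s"))
print(out["b"].tolist())  # [10.0, nan]: the second row's match is 6s away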

@@ -2026,19 +2039,19 @@ def _get_merge_keys(
f"with type {repr(lt.dtype)}"
)

-if needs_i8_conversion(getattr(lt, "dtype", None)):
+if needs_i8_conversion(lt.dtype):
if not isinstance(self.tolerance, datetime.timedelta):
raise MergeError(msg)
if self.tolerance < Timedelta(0):
raise MergeError("tolerance must be positive")

-elif is_integer_dtype(lt):
+elif is_integer_dtype(lt.dtype):
if not is_integer(self.tolerance):
raise MergeError(msg)
if self.tolerance < 0:
raise MergeError("tolerance must be positive")

-elif is_float_dtype(lt):
+elif is_float_dtype(lt.dtype):
if not is_number(self.tolerance):
raise MergeError(msg)
# error: Unsupported operand types for > ("int" and "Number")
@@ -2061,10 +2074,10 @@ def _get_merge_keys(
def _get_join_indexers(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
"""return the join indexers"""

-def flip(xs: list[AnyArrayLike]) -> np.ndarray:
+def flip(xs: list[ArrayLike]) -> np.ndarray:
"""unlike np.transpose, this returns an array of tuples"""

-def injection(obj: AnyArrayLike):
+def injection(obj: ArrayLike):
if not isinstance(obj.dtype, ExtensionDtype):
# ndarray
return obj
@@ -2212,11 +2225,11 @@ def injection(obj: AnyArrayLike):


def _get_multiindex_indexer(
-join_keys: list[AnyArrayLike], index: MultiIndex, sort: bool
+join_keys: list[ArrayLike], index: MultiIndex, sort: bool
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
# left & right join labels and num. of levels at each location
mapped = (
-_factorize_keys(index.levels[n], join_keys[n], sort=sort)
+_factorize_keys(index.levels[n]._values, join_keys[n], sort=sort)
for n in range(index.nlevels)
)
zipped = zip(*mapped)
@@ -2249,7 +2262,7 @@


def _get_single_indexer(
-join_key: AnyArrayLike, index: Index, sort: bool = False
+join_key: ArrayLike, index: Index, sort: bool = False
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:
left_key, right_key, count = _factorize_keys(join_key, index._values, sort=sort)

@@ -2294,7 +2307,7 @@ def _get_no_sort_one_missing_indexer(


def _left_join_on_index(
-left_ax: Index, right_ax: Index, join_keys: list[AnyArrayLike], sort: bool = False
+left_ax: Index, right_ax: Index, join_keys: list[ArrayLike], sort: bool = False
) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp]]:
if isinstance(right_ax, MultiIndex):
left_indexer, right_indexer = _get_multiindex_indexer(
@@ -2315,8 +2328,8 @@ def _left_join_on_index(


def _factorize_keys(
-lk: AnyArrayLike,
-rk: AnyArrayLike,
+lk: ArrayLike,
+rk: ArrayLike,
sort: bool = True,
how: MergeHow | Literal["asof"] = "inner",
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
@@ -2327,9 +2340,9 @@ def _factorize_keys(

Parameters
----------
-lk : ndarray, ExtensionArray, Index, or Series
+lk : ndarray, ExtensionArray
Left key.
-rk : ndarray, ExtensionArray, Index, or Series
+rk : ndarray, ExtensionArray
Right key.
sort : bool, defaults to True
If True, the encoding is done such that the unique elements in the
@@ -2370,9 +2383,6 @@
>>> pd.core.reshape.merge._factorize_keys(lk, rk, sort=False)
(array([0, 1, 2]), array([0, 1]), 3)
"""
-# Some pre-processing for non-ndarray lk / rk
-lk = extract_array(lk, extract_numpy=True, extract_range=True)
-rk = extract_array(rk, extract_numpy=True, extract_range=True)
# TODO: if either is a RangeIndex, we can likely factorize more efficiently?

if (