diff --git a/pandas/_typing.py b/pandas/_typing.py index 6059bced4a7d4..b54b8d6adb50e 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -125,6 +125,9 @@ # Series is passed into a function, a Series is always returned and if a DataFrame is # passed in, a DataFrame is always returned. NDFrameT = TypeVar("NDFrameT", bound="NDFrame") +# same as NDFrameT, needed when binding two pairs of parameters to potentially +# separate NDFrame-subclasses (see NDFrame.align) +NDFrameTb = TypeVar("NDFrameTb", bound="NDFrame") NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index") diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3a48c3b88e071..70019030da182 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -225,6 +225,7 @@ Level, MergeHow, NaPosition, + NDFrameT, PythonFuncType, QuantileInterpolation, ReadBuffer, @@ -4997,7 +4998,7 @@ def _reindex_multi( @doc(NDFrame.align, **_shared_doc_kwargs) def align( self, - other: DataFrame, + other: NDFrameT, join: AlignJoin = "outer", axis: Axis | None = None, level: Level = None, @@ -5007,7 +5008,7 @@ def align( limit: int | None = None, fill_axis: Axis = 0, broadcast_axis: Axis | None = None, - ) -> DataFrame: + ) -> tuple[DataFrame, NDFrameT]: return super().align( other, join=join, @@ -7771,9 +7772,7 @@ def to_series(right): ) left, right = left.align( - # error: Argument 1 to "align" of "DataFrame" has incompatible - # type "Series"; expected "DataFrame" - right, # type: ignore[arg-type] + right, join="outer", axis=axis, level=level, diff --git a/pandas/core/generic.py b/pandas/core/generic.py index a4dfb085c766f..95ac522833b35 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -67,6 +67,7 @@ Manager, NaPosition, NDFrameT, + NDFrameTb, RandomState, Renamer, Scalar, @@ -198,7 +199,6 @@ from pandas.core.indexers.objects import BaseIndexer from pandas.core.resample import Resampler - # goal is to be able to define the docs close to function, while still being # able to share _shared_docs = {**_shared_docs} @@ -9297,7 +9297,7 @@ def compare( @doc(**_shared_doc_kwargs) def align( self: NDFrameT, - other: NDFrameT, + other: NDFrameTb, join: AlignJoin = "outer", axis: Axis | None = None, level: Level = None, @@ -9307,7 +9307,7 @@ def align( limit: int | None = None, fill_axis: Axis = 0, broadcast_axis: Axis | None = None, - ) -> NDFrameT: + ) -> tuple[NDFrameT, NDFrameTb]: """ Align two objects on their axes with the specified join method. @@ -9428,8 +9428,10 @@ def align( df = cons( {c: self for c in other.columns}, **other._construct_axes_dict() ) - return df._align_frame( - other, + # error: Incompatible return value type (got "Tuple[DataFrame, + # DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]") + return df._align_frame( # type: ignore[return-value] + other, # type: ignore[arg-type] join=join, axis=axis, level=level, @@ -9446,7 +9448,9 @@ def align( df = cons( {c: other for c in self.columns}, **self._construct_axes_dict() ) - return self._align_frame( + # error: Incompatible return value type (got "Tuple[NDFrameT, + # DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]") + return self._align_frame( # type: ignore[return-value] df, join=join, axis=axis, @@ -9461,7 +9465,9 @@ def align( if axis is not None: axis = self._get_axis_number(axis) if isinstance(other, ABCDataFrame): - return self._align_frame( + # error: Incompatible return value type (got "Tuple[NDFrameT, DataFrame]", + # expected "Tuple[NDFrameT, NDFrameTb]") + return self._align_frame( # type: ignore[return-value] other, join=join, axis=axis, @@ -9473,7 +9479,9 @@ def align( fill_axis=fill_axis, ) elif isinstance(other, ABCSeries): - return self._align_series( + # error: Incompatible return value type (got "Tuple[NDFrameT, Series]", + # expected "Tuple[NDFrameT, NDFrameTb]") + return self._align_series( # type: ignore[return-value] other, join=join, axis=axis, @@ -9489,8 +9497,8 @@ def align( @final def _align_frame( - self, - other, + self: NDFrameT, + other: DataFrame, join: AlignJoin = "outer", axis: Axis | None = None, level=None, @@ -9499,7 +9507,7 @@ def _align_frame( method=None, limit=None, fill_axis: Axis = 0, - ): + ) -> tuple[NDFrameT, DataFrame]: # defaults join_index, join_columns = None, None ilidx, iridx = None, None @@ -9553,8 +9561,8 @@ def _align_frame( @final def _align_series( - self, - other, + self: NDFrameT, + other: Series, join: AlignJoin = "outer", axis: Axis | None = None, level=None, @@ -9563,7 +9571,7 @@ def _align_series( method=None, limit=None, fill_axis: Axis = 0, - ): + ) -> tuple[NDFrameT, Series]: is_series = isinstance(self, ABCSeries) if copy and using_copy_on_write(): copy = False @@ -12798,8 +12806,8 @@ def _doc_params(cls): def _align_as_utc( - left: NDFrameT, right: NDFrameT, join_index: Index | None -) -> tuple[NDFrameT, NDFrameT]: + left: NDFrameT, right: NDFrameTb, join_index: Index | None +) -> tuple[NDFrameT, NDFrameTb]: """ If we are aligning timezone-aware DatetimeIndexes and the timezones do not match, convert both to UTC. diff --git a/pandas/core/series.py b/pandas/core/series.py index 38a5d94db207c..1dac028d7b54a 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -166,6 +166,7 @@ IndexLabel, Level, NaPosition, + NDFrameT, NumpySorter, NumpyValueArrayLike, QuantileInterpolation, @@ -4571,7 +4572,7 @@ def _needs_reindex_multi(self, axes, method, level) -> bool: ) def align( self, - other: Series, + other: NDFrameT, join: AlignJoin = "outer", axis: Axis | None = None, level: Level = None, @@ -4581,7 +4582,7 @@ def align( limit: int | None = None, fill_axis: Axis = 0, broadcast_axis: Axis | None = None, - ) -> Series: + ) -> tuple[Series, NDFrameT]: return super().align( other, join=join,