diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 04a35ac2ec897..35e57091f701a 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -459,6 +459,26 @@ def test_corrwith_mixed_dtypes(self): expected = pd.Series(data=corrs, index=['a', 'b']) tm.assert_series_equal(result, expected) + def test_corrwith_index_intersection(self): + df1 = pd.DataFrame(np.random.random(size=(10, 2)), + columns=["a", "b"]) + df2 = pd.DataFrame(np.random.random(size=(10, 3)), + columns=["a", "b", "c"]) + + result = df1.corrwith(df2, drop=True).index.sort_values() + expected = df1.columns.intersection(df2.columns).sort_values() + tm.assert_index_equal(result, expected) + + def test_corrwith_index_union(self): + df1 = pd.DataFrame(np.random.random(size=(10, 2)), + columns=["a", "b"]) + df2 = pd.DataFrame(np.random.random(size=(10, 3)), + columns=["a", "b", "c"]) + + result = df1.corrwith(df2, drop=False).index.sort_values() + expected = df1.columns.union(df2.columns).sort_values() + tm.assert_index_equal(result, expected) + def test_corrwith_dup_cols(self): # GH 21925 df1 = pd.DataFrame(np.vstack([np.arange(10)] * 3).T)