diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index 9ea8e08bb0128..aeb0e9f297bce 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -26,37 +26,41 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils): - def _test_expanding_func(self, f): + def _test_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") psser = ps.from_pandas(pser) - self.assert_eq( - getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)(), almost=True - ) - self.assert_eq( - getattr(psser.expanding(2), f)().sum(), - getattr(pser.expanding(2), f)().sum(), - almost=True, - ) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2)), almost=True) + self.assert_eq( + ps_func(psser.expanding(2)).sum(), + pd_func(pser.expanding(2)).sum(), + almost=True, + ) # Multiindex pser = pd.Series( [1, 2, 3], index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]) ) psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.expanding(2), f)(), getattr(pser.expanding(2), f)()) + self.assert_eq(ps_func(psser.expanding(2)), pd_func(pser.expanding(2))) pdf = pd.DataFrame( {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) ) psdf = ps.from_pandas(pdf) - self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)()) - self.assert_eq(getattr(psdf.expanding(2), f)().sum(), getattr(pdf.expanding(2), f)().sum()) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) + self.assert_eq(ps_func(psdf.expanding(2)).sum(), pd_func(pdf.expanding(2)).sum()) # Multiindex column columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) pdf.columns = columns psdf.columns = 
columns - self.assert_eq(getattr(psdf.expanding(2), f)(), getattr(pdf.expanding(2), f)()) + self.assert_eq(ps_func(psdf.expanding(2)), pd_func(pdf.expanding(2))) def test_expanding_error(self): with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): @@ -97,16 +101,22 @@ def test_expanding_skew(self): def test_expanding_kurt(self): self._test_expanding_func("kurt") - def _test_groupby_expanding_func(self, f): + def _test_groupby_expanding_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sort_index(), - getattr(pser.groupby(pser).expanding(2), f)().sort_index(), + ps_func(psser.groupby(psser).expanding(2)).sort_index(), + pd_func(pser.groupby(pser).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sum(), - getattr(pser.groupby(pser).expanding(2), f)().sum(), + ps_func(psser.groupby(psser).expanding(2)).sum(), + pd_func(pser.groupby(pser).expanding(2)).sum(), ) # Multiindex @@ -117,8 +127,8 @@ def _test_groupby_expanding_func(self, f): ) psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).expanding(2), f)().sort_index(), - getattr(pser.groupby(pser).expanding(2), f)().sort_index(), + ps_func(psser.groupby(psser).expanding(2)).sort_index(), + pd_func(pser.groupby(pser).expanding(2)).sort_index(), ) pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) @@ -127,42 +137,42 @@ def _test_groupby_expanding_func(self, f): # The behavior of GroupBy.expanding is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sum(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sum(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sum(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).expanding(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).expanding(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(), - getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"), + ps_func(psdf.groupby(psdf.a).expanding(2)).sum(), + pd_func(pdf.groupby(pdf.a).expanding(2)).sum().drop("a"), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a + 1).expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).expanding(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(), - getattr(pdf.b.groupby(pdf.a).expanding(2), f)().sort_index(), + ps_func(psdf.b.groupby(psdf.a).expanding(2)).sort_index(), + pd_func(pdf.b.groupby(pdf.a).expanding(2)).sort_index(), ) self.assert_eq( - 
getattr(psdf.groupby(psdf.a)["b"].expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)["b"].expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)["b"].expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)["b"].expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)[["b"]].expanding(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)[["b"]].expanding(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)[["b"]].expanding(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)[["b"]].expanding(2)).sort_index(), ) # Multiindex column @@ -173,25 +183,23 @@ def _test_groupby_expanding_func(self, f): # The behavior of GroupBy.expanding is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).expanding(2), f)().sort_index(), + ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).expanding(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).expanding(2), f)() - .drop(("a", "x"), axis=1) - .sort_index(), + ps_func(psdf.groupby(("a", "x")).expanding(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).expanding(2)).drop(("a", "x"), axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)() + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", 
"y")]).expanding(2)) .drop([("a", "x"), ("a", "y")], axis=1) .sort_index(), ) diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index bf793765655b0..3f92eba79ce99 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -36,11 +36,17 @@ def test_rolling_error(self): ): Rolling(1, 2) - def _test_rolling_func(self, f): + def _test_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a") psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)()) - self.assert_eq(getattr(psser.rolling(2), f)().sum(), getattr(pser.rolling(2), f)().sum()) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) + self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum()) # Multiindex pser = pd.Series( @@ -49,20 +55,20 @@ def _test_rolling_func(self, f): name="a", ) psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.rolling(2), f)(), getattr(pser.rolling(2), f)()) + self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2))) pdf = pd.DataFrame( {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) ) psdf = ps.from_pandas(pdf) - self.assert_eq(getattr(psdf.rolling(2), f)(), getattr(pdf.rolling(2), f)()) - self.assert_eq(getattr(psdf.rolling(2), f)().sum(), getattr(pdf.rolling(2), f)().sum()) + self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) + self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum()) # Multiindex column columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(getattr(psdf.rolling(2), f)(), 
getattr(pdf.rolling(2), f)()) + self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2))) def test_rolling_min(self): self._test_rolling_func("min") @@ -91,16 +97,22 @@ def test_rolling_skew(self): def test_rolling_kurt(self): self._test_rolling_func("kurt") - def _test_groupby_rolling_func(self, f): + def _test_groupby_rolling_func(self, ps_func, pd_func=None): + if not pd_func: + pd_func = ps_func + if isinstance(pd_func, str): + pd_func = self.convert_str_to_lambda(pd_func) + if isinstance(ps_func, str): + ps_func = self.convert_str_to_lambda(ps_func) pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sort_index(), - getattr(pser.groupby(pser).rolling(2), f)().sort_index(), + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sum(), - getattr(pser.groupby(pser).rolling(2), f)().sum(), + ps_func(psser.groupby(psser).rolling(2)).sum(), + pd_func(pser.groupby(pser).rolling(2)).sum(), ) # Multiindex @@ -111,8 +123,8 @@ def _test_groupby_rolling_func(self, f): ) psser = ps.from_pandas(pser) self.assert_eq( - getattr(psser.groupby(psser).rolling(2), f)().sort_index(), - getattr(pser.groupby(pser).rolling(2), f)().sort_index(), + ps_func(psser.groupby(psser).rolling(2)).sort_index(), + pd_func(pser.groupby(pser).rolling(2)).sort_index(), ) pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) @@ -121,42 +133,42 @@ def _test_groupby_rolling_func(self, f): # The behavior of GroupBy.rolling is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sum(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sum(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a).rolling(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(), - getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"), + ps_func(psdf.groupby(psdf.a).rolling(2)).sum(), + pd_func(pdf.groupby(pdf.a).rolling(2)).sum().drop("a"), ) self.assert_eq( - getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a", axis=1).sort_index(), + ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a + 1).rolling(2)).drop("a", axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(), - getattr(pdf.b.groupby(pdf.a).rolling(2), f)().sort_index(), + ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(), + pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)["b"].rolling(2), f)().sort_index(), - 
getattr(pdf.groupby(pdf.a)["b"].rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby(psdf.a)[["b"]].rolling(2), f)().sort_index(), - getattr(pdf.groupby(pdf.a)[["b"]].rolling(2), f)().sort_index(), + ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(), + pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(), ) # Multiindex column @@ -167,25 +179,23 @@ def _test_groupby_rolling_func(self, f): # The behavior of GroupBy.rolling is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): self.assert_eq( - getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(), + ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), ) else: self.assert_eq( - getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), - getattr(pdf.groupby(("a", "x")).rolling(2), f)() - .drop(("a", "x"), axis=1) - .sort_index(), + ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(), + pd_func(pdf.groupby(("a", "x")).rolling(2)).drop(("a", "x"), axis=1).sort_index(), ) self.assert_eq( - getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), - getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)() + ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(), + pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)) .drop([("a", "x"), ("a", "y")], axis=1) .sort_index(), ) diff --git a/python/pyspark/testing/pandasutils.py 
b/python/pyspark/testing/pandasutils.py index baa43e5b9d5c2..ad2f74e8af411 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -65,6 +65,12 @@ def setUpClass(cls): super(PandasOnSparkTestCase, cls).setUpClass() cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) + def convert_str_to_lambda(self, func): + """ + This function converts `func` str to lambda call + """ + return lambda x: getattr(x, func)() + def assertPandasEqual(self, left, right, check_exact=True): if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): try: