diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_grouping.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_grouping.py index 8b3f9927c0f2b..0f07c4ffa8225 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_grouping.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_grouping.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_grouping import GroupingTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class GroupingParityTests(GroupingTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class GroupingParityTests( + GroupingTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing.py index f6776d9bac608..d7641ac3ab73b 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_missing.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_missing.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_missing import MissingTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class MissingParityTests(MissingTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class MissingParityTests( + MissingTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_nlargest_nsmallest.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_nlargest_nsmallest.py index 71c388a1d2981..db8e4f94e118b 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_nlargest_nsmallest.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_nlargest_nsmallest.py @@ -19,11 +19,13 @@ from pyspark.pandas.tests.groupby.test_nlargest_nsmallest import NlargestNsmallestTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils class NlargestNsmallestParityTests( - NlargestNsmallestTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase + NlargestNsmallestTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, ): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_raises.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_raises.py index db122a81ebdd1..1694024cf6182 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_raises.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_raises.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_raises import RaisesTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RaisesParityTests(RaisesTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class RaisesParityTests( + RaisesTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_rank.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_rank.py index 2ad5cf07cfcaa..98c6e3fe5c458 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_rank.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_rank.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_rank import RankTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class RankParityTests(RankTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class RankParityTests( + RankTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_size.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_size.py index 2904f0cded276..603bd9408e1f9 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_size.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_size.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_size import SizeTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class SizeParityTests(SizeTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class SizeParityTests( + SizeTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_value_counts.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_value_counts.py index a9c84822006df..7aef16ccea57c 100644 --- a/python/pyspark/pandas/tests/connect/groupby/test_parity_value_counts.py +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_value_counts.py @@ -19,10 +19,14 @@ from pyspark.pandas.tests.groupby.test_value_counts import ValueCountsTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class ValueCountsParityTests(ValueCountsTestsMixin, PandasOnSparkTestCase, ReusedConnectTestCase): +class ValueCountsParityTests( + ValueCountsTestsMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): pass diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index cba5ad5765836..f529c7e5d0cef 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -192,3 +192,8 @@ def setUpClass(cls): def tearDownClass(cls): shutil.rmtree(cls.tempdir.name, ignore_errors=True) cls.spark.stop() + + def test_assert_remote_mode(self): + from pyspark.sql import is_remote + + self.assertTrue(is_remote()) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 5d284ffc7956b..5da6f47174382 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -179,6 +179,11 @@ def setUpClass(cls): def tearDownClass(cls): cls.sc.stop() + def test_assert_vanilla_mode(self): + from pyspark.sql import is_remote + + self.assertFalse(is_remote()) + class ByteArrayOutput: def __init__(self):