diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 35ca2681cc97..eae2b01c5544 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
-import unittest
 import time
 
 import pyspark.cloudpickle
@@ -25,7 +23,7 @@
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class TestListener(StreamingQueryListener):
+class TestListenerSpark(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
@@ -45,52 +43,53 @@ def onQueryTerminated(self, event):
         df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
-# TODO(SPARK-48089): Reenable this test case
-@unittest.skipIf(
-    "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server"
-)
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
     def test_listener_events(self):
-        test_listener = TestListener()
+        test_listener = TestListenerSpark()
 
         try:
-            self.spark.streams.addListener(test_listener)
-
-            # This ensures the read socket on the server won't crash (i.e. because of timeout)
-            # when there hasn't been a new event for a long time
-            time.sleep(30)
-
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
-
-            self.assertTrue(q.isActive)
-            time.sleep(10)
-            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch is ran
-            q.stop()
-            self.assertFalse(q.isActive)
-
-            start_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_start_events").collect()[0][0]
-            )
-
-            progress_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_progress_events").collect()[0][0]
-            )
-
-            terminated_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_terminated_events").collect()[0][0]
-            )
-
-            self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
+            with self.table(
+                "listener_start_events",
+                "listener_progress_events",
+                "listener_terminated_events",
+            ):
+                self.spark.streams.addListener(test_listener)
+
+                # This ensures the read socket on the server won't crash (i.e. because of timeout)
+                # when there hasn't been a new event for a long time
+                time.sleep(30)
+
+                df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+                df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+                df_stateful = df_observe.groupBy().count()  # make query stateful
+                q = (
+                    df_stateful.writeStream.format("noop")
+                    .queryName("test")
+                    .outputMode("complete")
+                    .start()
+                )
+
+                self.assertTrue(q.isActive)
+                time.sleep(10)
+                self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
+                q.stop()
+                self.assertFalse(q.isActive)
+
+                start_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_start_events").collect()[0][0]
+                )
+
+                progress_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_progress_events").collect()[0][0]
+                )
+
+                terminated_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_terminated_events").collect()[0][0]
+                )
+
+                self.check_start_event(start_event)
+                self.check_progress_event(progress_event)
+                self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
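
For context: the self.table(...) context manager that the new test body relies on comes from PySpark's shared test utilities (the SQLTestUtils mixin that ReusedConnectTestCase pulls in). A minimal sketch of what such a helper does, assuming roughly this shape (the exact implementation in python/pyspark/testing/sqlutils.py may differ):

    from contextlib import contextmanager

    @contextmanager
    def table(self, *tables):
        # Run the enclosed test body first, then drop every named table,
        # so the listener event tables don't leak into later test runs
        # even when an assertion inside the block fails.
        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

Wrapping the whole query lifecycle in a single with self.table(...) block means the three tables written by TestListenerSpark are dropped no matter how the test exits.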