89 changes: 44 additions & 45 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
-import unittest
 import time
 
 import pyspark.cloudpickle
@@ -25,7 +23,7 @@
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class TestListener(StreamingQueryListener):
+class TestListenerSpark(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
@@ -45,52 +43,53 @@ def onQueryTerminated(self, event):
         df.write.mode("append").saveAsTable("listener_terminated_events")
 
 
-# TODO(SPARK-48089): Reenable this test case
-@unittest.skipIf(
-    "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server"
-)
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
     def test_listener_events(self):
-        test_listener = TestListener()
+        test_listener = TestListenerSpark()
 
         try:
-            self.spark.streams.addListener(test_listener)
-
-            # This ensures the read socket on the server won't crash (i.e. because of timeout)
-            # when there hasn't been a new event for a long time
-            time.sleep(30)
-
-            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
-            df_stateful = df_observe.groupBy().count()  # make query stateful
-            q = (
-                df_stateful.writeStream.format("noop")
-                .queryName("test")
-                .outputMode("complete")
-                .start()
-            )
-
-            self.assertTrue(q.isActive)
-            time.sleep(10)
-            self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
-            q.stop()
-            self.assertFalse(q.isActive)
-
-            start_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_start_events").collect()[0][0]
-            )
-
-            progress_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_progress_events").collect()[0][0]
-            )
-
-            terminated_event = pyspark.cloudpickle.loads(
-                self.spark.read.table("listener_terminated_events").collect()[0][0]
-            )
-
-            self.check_start_event(start_event)
-            self.check_progress_event(progress_event)
-            self.check_terminated_event(terminated_event)
+            with self.table(
+                "listener_start_events",
+                "listener_progress_events",
+                "listener_terminated_events",
+            ):
+                self.spark.streams.addListener(test_listener)
+
+                # This ensures the read socket on the server won't crash (i.e. because of timeout)
+                # when there hasn't been a new event for a long time
+                time.sleep(30)
+
+                df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+                df_observe = df.observe("my_event", count(lit(1)).alias("rc"))
+                df_stateful = df_observe.groupBy().count()  # make query stateful
+                q = (
+                    df_stateful.writeStream.format("noop")
+                    .queryName("test")
+                    .outputMode("complete")
+                    .start()
+                )
+
+                self.assertTrue(q.isActive)
+                time.sleep(10)
+                self.assertTrue(q.lastProgress["batchId"] > 0)  # ensure at least one batch has run
+                q.stop()
+                self.assertFalse(q.isActive)
+
+                start_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_start_events").collect()[0][0]
+                )
+
+                progress_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_progress_events").collect()[0][0]
+                )
+
+                terminated_event = pyspark.cloudpickle.loads(
+                    self.spark.read.table("listener_terminated_events").collect()[0][0]
+                )
+
+                self.check_start_event(start_event)
+                self.check_progress_event(progress_event)
+                self.check_terminated_event(terminated_event)
 
         finally:
             self.spark.streams.removeListener(test_listener)
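
Note on the cleanup pattern in the new test body: self.table(...) is presumably a context-manager helper mixed in via PySpark's test utilities that scopes the three listener event tables to the test, dropping them on exit even when an assertion fails (the old try/finally only removed the listener). A minimal sketch of such a helper, under that assumption; TableCleanupMixin and its body are illustrative, not PySpark's actual implementation, and self.spark is assumed to be a SparkSession:

from contextlib import contextmanager

class TableCleanupMixin:
    # Hypothetical mixin sketching a self.table(...)-style helper.
    @contextmanager
    def table(self, *names):
        try:
            # Hand control to the test body that creates/reads the tables.
            yield
        finally:
            # Drop the named tables even if the body raised, so reruns
            # don't see stale listener_* tables.
            for name in names:
                self.spark.sql(f"DROP TABLE IF EXISTS {name}")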