diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 249e2675b76e..08924a86fd7c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -495,6 +495,8 @@ def __hash__(self): "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", "pyspark.sql.tests.streaming.test_streaming", + "pyspark.sql.tests.streaming.test_streaming_foreach", + "pyspark.sql.tests.streaming.test_streaming_foreachBatch", "pyspark.sql.tests.streaming.test_streaming_listener", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", @@ -749,6 +751,8 @@ def __hash__(self): "pyspark.sql.connect.dataframe", "pyspark.sql.connect.functions", "pyspark.sql.connect.avro.functions", + "pyspark.sql.connect.streaming.readwriter", + "pyspark.sql.connect.streaming.query", # sql unittests "pyspark.sql.tests.connect.test_client", "pyspark.sql.tests.connect.test_connect_plan", @@ -773,6 +777,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_map", "pyspark.sql.tests.connect.test_parity_pandas_grouped_map", "pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map", + "pyspark.sql.tests.connect.streaming.test_parity_streaming", # ml doctests "pyspark.ml.connect.functions", # ml unittests diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 2866945d161f..aebab9fc69fd 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -16,6 +16,7 @@ # import json +import sys from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional from pyspark.errors import StreamingQueryException @@ -65,10 +66,11 @@ def isActive(self) -> bool: isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__ + # TODO (SPARK-42960): Implement and uncomment the doc def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: raise NotImplementedError() - awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__ + # awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__ @property def status(self) -> Dict[str, Any]: @@ -114,7 +116,8 @@ def stop(self) -> None: cmd.stop = True self._execute_streaming_query_cmd(cmd) - stop.__doc__ = PySparkStreamingQuery.stop.__doc__ + # TODO (SPARK-42962): uncomment below + # stop.__doc__ = PySparkStreamingQuery.stop.__doc__ def explain(self, extended: bool = False) -> None: cmd = pb2.StreamingQueryCommand() @@ -124,6 +127,7 @@ def explain(self, extended: bool = False) -> None: explain.__doc__ = PySparkStreamingQuery.explain.__doc__ + # TODO (SPARK-42960): Implement and uncomment the doc def exception(self) -> Optional[StreamingQueryException]: raise NotImplementedError() @@ -149,10 +153,31 @@ def _execute_streaming_query_cmd( def _test() -> None: - # TODO(SPARK-43031): port _test() from legacy query.py. 
-    pass
+    import doctest
+    import os
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.streaming.query
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.connect.streaming.query.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.streaming.query tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.streaming.query,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
+    )
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
 
 
 if __name__ == "__main__":
-    # TODO(SPARK-43031): Add this file dev/sparktestsupport/modules.py to enable testing in CI.
     _test()
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index b266f485c96c..e702b3523a4a 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -168,9 +168,75 @@ def json(
 
     json.__doc__ = PySparkDataStreamReader.json.__doc__
 
-    # def orc() TODO
-    # def parquet() TODO
-    # def text() TODO
+    def orc(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
+        self._set_opts(
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+        )
+        if isinstance(path, str):
+            return self.load(path=path, format="orc")
+        else:
+            raise TypeError("path can be only a single string")
+
+    orc.__doc__ = PySparkDataStreamReader.orc.__doc__
+
+    def parquet(
+        self,
+        path: str,
+        mergeSchema: Optional[bool] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+        datetimeRebaseMode: Optional[Union[bool, str]] = None,
+        int96RebaseMode: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
+        self._set_opts(
+            mergeSchema=mergeSchema,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+            datetimeRebaseMode=datetimeRebaseMode,
+            int96RebaseMode=int96RebaseMode,
+        )
+        if isinstance(path, str):
+            return self.load(path=path, format="parquet")
+        else:
+            raise TypeError("path can be only a single string")
+
+    parquet.__doc__ = PySparkDataStreamReader.parquet.__doc__
+
+    def text(
+        self,
+        path: str,
+        wholetext: bool = False,
+        lineSep: Optional[str] = None,
+        pathGlobFilter: Optional[Union[bool, str]] = None,
+        recursiveFileLookup: Optional[Union[bool, str]] = None,
+    ) -> "DataFrame":
+        self._set_opts(
+            wholetext=wholetext,
+            lineSep=lineSep,
+            pathGlobFilter=pathGlobFilter,
+            recursiveFileLookup=recursiveFileLookup,
+        )
+        if isinstance(path, str):
+            return self.load(path=path, format="text")
+        else:
+            raise TypeError("path can be only a single string")
+
+    text.__doc__ = PySparkDataStreamReader.text.__doc__
 
     def csv(
         self,
@@ -245,7 +311,7 @@ def csv(
 
     csv.__doc__ = PySparkDataStreamReader.csv.__doc__
 
-    # def table() TODO. Use Read(table_name) relation.
+    # def table() TODO(SPARK-43042). Use Read(table_name) relation.
DataStreamReader.__doc__ = PySparkDataStreamReader.__doc__ @@ -366,6 +432,7 @@ def trigger( trigger.__doc__ = PySparkDataStreamWriter.trigger.__doc__ + # TODO (SPARK-43054): Implement and uncomment the doc @overload def foreach(self, f: Callable[[Row], None]) -> "DataStreamWriter": ... @@ -377,7 +444,13 @@ def foreach(self, f: "SupportsProcess") -> "DataStreamWriter": def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter": raise NotImplementedError("foreach() is not implemented.") - foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ + # foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ + + # TODO (SPARK-42944): Implement and uncomment the doc + def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": + raise NotImplementedError("foreachBatch() is not implemented.") + + # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ def _start_internal( self, @@ -435,7 +508,8 @@ def start( **options, ) - start.__doc__ = PySparkDataStreamWriter.start.__doc__ + # TODO (SPARK-42962): uncomment below + # start.__doc__ = PySparkDataStreamWriter.start.__doc__ def toTable( self, @@ -460,10 +534,32 @@ def toTable( def _test() -> None: - # TODO(SPARK-43031): port _test() from legacy query.py. - pass + import sys + import doctest + from pyspark.sql import SparkSession as PySparkSession + import pyspark.sql.connect.streaming.readwriter + + globs = pyspark.sql.connect.readwriter.__dict__.copy() + + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.streaming.readwriter tests") + .remote("local[4]") + .getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.connect.streaming.readwriter, + globs=globs, + optionflags=doctest.ELLIPSIS + | doctest.NORMALIZE_WHITESPACE + | doctest.IGNORE_EXCEPTION_DETAIL, + ) + + globs["spark"].stop() + + if failure_count: + sys.exit(-1) if __name__ == "__main__": - # TODO(SPARK-43031): Add this file dev/sparktestsupport/modules.py to enable testing in CI. _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e7df25d20fcb..542b898015b4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -535,7 +535,7 @@ def writeStream(self) -> DataStreamWriter: >>> with tempfile.TemporaryDirectory() as d: ... # Create a table with Rate source. ... df.writeStream.toTable( - ... "my_table", checkpointLocation=d) # doctest: +ELLIPSIS + ... "my_table", checkpointLocation=d) <...streaming.query.StreamingQuery object at 0x...> """ return DataStreamWriter(self) diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 3c43628bf378..b902f0514fce 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -16,7 +16,6 @@ # import json -import sys from typing import Any, Dict, List, Optional from py4j.java_gateway import JavaObject, java_import @@ -37,6 +36,9 @@ class StreamingQuery: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -68,7 +70,7 @@ def id(self) -> str: Get the unique id of this query that persists across restarts from checkpoint data - >>> sq.id # doctest: +ELLIPSIS + >>> sq.id '...' >>> sq.stop() @@ -95,7 +97,7 @@ def runId(self) -> str: Get the unique id of this query that does not persist across restarts - >>> sq.runId # doctest: +ELLIPSIS + >>> sq.runId '...' 
>>> sq.stop() @@ -219,7 +221,7 @@ def status(self) -> Dict[str, Any]: Get the current status of the query - >>> sq.status # doctest: +ELLIPSIS + >>> sq.status {'message': '...', 'isDataAvailable': ..., 'isTriggerActive': ...} >>> sq.stop() @@ -248,7 +250,7 @@ def recentProgress(self) -> List[Dict[str, Any]]: Get an array of the most recent query progress updates for this query - >>> sq.recentProgress # doctest: +ELLIPSIS + >>> sq.recentProgress [...] >>> sq.stop() @@ -330,6 +332,7 @@ def stop(self) -> None: Stop streaming query >>> sq.stop() + >>> sq.isActive False """ @@ -632,6 +635,7 @@ def removeListener(self, listener: StreamingQueryListener) -> None: def _test() -> None: import doctest import os + import sys from pyspark.sql import SparkSession import pyspark.sql.streaming.query from py4j.protocol import Py4JError diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index c58848dc5085..529e3aeb60d9 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -43,6 +43,9 @@ class DataStreamReader(OptionUtils): .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -50,7 +53,7 @@ class DataStreamReader(OptionUtils): Examples -------- >>> spark.readStream - + <...streaming.readwriter.DataStreamReader object ...> The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then writes the stream out to the console. @@ -90,7 +93,7 @@ def format(self, source: str) -> "DataStreamReader": Examples -------- >>> spark.readStream.format("text") - + <...streaming.readwriter.DataStreamReader object ...> This API allows to configure other sources to read. The example below writes a small text file, and reads it back via Text source. @@ -133,9 +136,9 @@ def schema(self, schema: Union[StructType, str]) -> "DataStreamReader": -------- >>> from pyspark.sql.types import StructField, StructType, StringType >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)])) - + <...streaming.readwriter.DataStreamReader object ...> >>> spark.readStream.schema("col0 INT, col1 DOUBLE") - + <...streaming.readwriter.DataStreamReader object ...> The example below specifies a different schema to CSV file. @@ -172,7 +175,7 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader" Examples -------- >>> spark.readStream.option("x", 1) - + <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' option to Rate source in order to generate 10 rows every second. @@ -198,7 +201,7 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader": Examples -------- >>> spark.readStream.options(x="1", y=2) - + <...streaming.readwriter.DataStreamReader object ...> The example below specifies 'rowsPerSecond' and 'numPartitions' options to Rate source in order to generate 10 rows with 10 partitions every second. @@ -764,7 +767,7 @@ def outputMode(self, outputMode: str) -> "DataStreamWriter": -------- >>> df = spark.readStream.format("rate").load() >>> df.writeStream.outputMode('append') - + <...streaming.readwriter.DataStreamWriter object ...> The example below uses Complete mode that the entire aggregated counts are printed out. 
@@ -798,7 +801,7 @@ def format(self, source: str) -> "DataStreamWriter": -------- >>> df = spark.readStream.format("rate").load() >>> df.writeStream.format("text") - + <...streaming.readwriter.DataStreamWriter object ...> This API allows to configure the source to write. The example below writes a CSV file from Rate source in a streaming manner. @@ -832,7 +835,7 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter" -------- >>> df = spark.readStream.format("rate").load() >>> df.writeStream.option("x", 1) - + <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' option to Console source in order to print 3 rows for every batch. @@ -860,7 +863,7 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter": -------- >>> df = spark.readStream.format("rate").load() >>> df.writeStream.option("x", 1) - + <...streaming.readwriter.DataStreamWriter object ...> The example below specifies 'numRows' and 'truncate' options to Console source in order to print 3 rows for every batch without truncating the results. @@ -905,7 +908,7 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] -------- >>> df = spark.readStream.format("rate").load() >>> df.writeStream.partitionBy("value") - + <...streaming.readwriter.DataStreamWriter object ...> Partition-by timestamp column from Rate source. @@ -1015,17 +1018,17 @@ def trigger( Trigger the query for execution every 5 seconds >>> df.writeStream.trigger(processingTime='5 seconds') - + <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for execution every 5 seconds >>> df.writeStream.trigger(continuous='5 seconds') - + <...streaming.readwriter.DataStreamWriter object ...> Trigger the query for reading all available data with multiple batches >>> df.writeStream.trigger(availableNow=True) - + <...streaming.readwriter.DataStreamWriter object ...> """ params = [processingTime, once, continuous, availableNow] diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py new file mode 100644 index 000000000000..6b4460bab521 --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from pyspark.sql.tests.streaming.test_streaming import StreamingTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class StreamingParityTests(StreamingTestsMixin, ReusedConnectTestCase): + @unittest.skip("Will be supported with SPARK-42960.") + def test_stream_await_termination(self): + super().test_stream_await_termination() + + @unittest.skip("Will be supported with SPARK-42960.") + def test_stream_exception(self): + super().test_stream_exception() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_status_and_progress(self): + super().test_stream_status_and_progress() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_query_manager_await_termination(self): + super().test_query_manager_await_termination() + + @unittest.skip("table API will be supported later with SPARK-43042.") + def test_streaming_read_from_table(self): + super().test_streaming_read_from_table() + + @unittest.skip("table API will be supported later with SPARK-43042.") + def test_streaming_write_to_table(self): + super().test_streaming_write_to_table() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_save_options(self): + super().test_stream_save_options() + + @unittest.skip("Query manager API will be supported later with SPARK-43032.") + def test_stream_save_options_overwrite(self): + super().test_stream_save_options_overwrite() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_streaming import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 9f02ae848bf6..838d413a0cc3 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -26,7 +26,39 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTests(ReusedSQLTestCase): +class StreamingTestsMixin: + def test_streaming_query_functions_basic(self): + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + query = ( + df.writeStream.format("memory") + .queryName("test_streaming_query_functions_basic") + .start() + ) + try: + self.assertEquals(query.name, "test_streaming_query_functions_basic") + self.assertTrue(isinstance(query.id, str)) + self.assertTrue(isinstance(query.runId, str)) + self.assertTrue(query.isActive) + # TODO: Will be uncommented with [SPARK-42960] + # self.assertEqual(query.exception(), None) + # self.assertFalse(query.awaitTermination(1)) + query.processAllAvailable() + recentProgress = query.recentProgress + lastProgress = query.lastProgress + self.assertEqual(lastProgress["name"], query.name) + self.assertEqual(lastProgress["id"], query.id) + self.assertTrue(any(p == lastProgress for p in recentProgress)) + query.explain() + + except Exception as e: + self.fail( + "Streaming query functions sanity check shouldn't throw any error. 
" + "Error message: " + str(e) + ) + + finally: + query.stop() + def test_stream_trigger(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") @@ -77,8 +109,8 @@ def test_stream_read_options_overwrite(self): .schema(bad_schema) .load(path="python/test_support/sql/streaming", schema=schema, format="text") ) - self.assertTrue(df.isStreaming) - self.assertEqual(df.schema.simpleString(), "struct") + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct") def test_stream_save_options(self): df = ( @@ -295,334 +327,6 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) - class ForeachWriterTester: - def __init__(self, spark): - self.spark = spark - - def write_open_event(self, partitionId, epochId): - self._write_event(self.open_events_dir, {"partition": partitionId, "epoch": epochId}) - - def write_process_event(self, row): - self._write_event(self.process_events_dir, {"value": "text"}) - - def write_close_event(self, error): - self._write_event(self.close_events_dir, {"error": str(error)}) - - def write_input_file(self): - self._write_event(self.input_dir, "text") - - def open_events(self): - return self._read_events(self.open_events_dir, "partition INT, epoch INT") - - def process_events(self): - return self._read_events(self.process_events_dir, "value STRING") - - def close_events(self): - return self._read_events(self.close_events_dir, "error STRING") - - def run_streaming_query_on_writer(self, writer, num_files): - self._reset() - try: - sdf = self.spark.readStream.format("text").load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - for i in range(num_files): - self.write_input_file() - sq.processAllAvailable() - finally: - self.stop_all() - - def assert_invalid_writer(self, writer, msg=None): - self._reset() - try: - sdf = self.spark.readStream.format("text").load(self.input_dir) - sq = sdf.writeStream.foreach(writer).start() - self.write_input_file() - sq.processAllAvailable() - self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected - except Exception as e: - if msg: - assert msg in str(e), "%s not in %s" % (msg, str(e)) - - finally: - self.stop_all() - - def stop_all(self): - for q in self.spark.streams.active: - q.stop() - - def _reset(self): - self.input_dir = tempfile.mkdtemp() - self.open_events_dir = tempfile.mkdtemp() - self.process_events_dir = tempfile.mkdtemp() - self.close_events_dir = tempfile.mkdtemp() - - def _read_events(self, dir, json): - rows = self.spark.read.schema(json).json(dir).collect() - dicts = [row.asDict() for row in rows] - return dicts - - def _write_event(self, dir, event): - import uuid - - with open(os.path.join(dir, str(uuid.uuid4())), "w") as f: - f.write("%s\n" % str(event)) - - def __getstate__(self): - return (self.open_events_dir, self.process_events_dir, self.close_events_dir) - - def __setstate__(self, state): - self.open_events_dir, self.process_events_dir, self.close_events_dir = state - - # Those foreach tests are failed in macOS High Sierra by defined rules - # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html - # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
- def test_streaming_foreach_with_simple_function(self): - tester = self.ForeachWriterTester(self.spark) - - def foreach_func(row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(foreach_func, 2) - self.assertEqual(len(tester.process_events()), 2) - - def test_streaming_foreach_with_basic_open_process_close(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) - return True - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - open_events = tester.open_events() - self.assertEqual(len(open_events), 2) - self.assertSetEqual(set([e["epoch"] for e in open_events]), {0, 1}) - - self.assertEqual(len(tester.process_events()), 2) - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) - - def test_streaming_foreach_with_open_returning_false(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return False - - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - - self.assertEqual(len(tester.open_events()), 2) - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - - close_events = tester.close_events() - self.assertEqual(len(close_events), 2) - self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) - - def test_streaming_foreach_without_open_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - def close(self, error): - tester.write_close_event(error) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 2) - - def test_streaming_foreach_without_close_method(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def open(self, partition_id, epoch_id): - tester.write_open_event(partition_id, epoch_id) - return True - - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 2) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_without_open_and_close_methods(self): - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - tester.write_process_event(row) - - tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events - self.assertEqual(len(tester.process_events()), 2) - self.assertEqual(len(tester.close_events()), 0) - - def test_streaming_foreach_with_process_throwing_error(self): - from pyspark.errors import StreamingQueryException - - tester = self.ForeachWriterTester(self.spark) - - class ForeachWriter: - def process(self, row): - raise RuntimeError("test error") - - def close(self, error): - tester.write_close_event(error) - - try: - 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) - self.fail("bad writer did not fail the query") # this is not expected - except StreamingQueryException: - # TODO: Verify whether original error message is inside the exception - pass - - self.assertEqual(len(tester.process_events()), 0) # no row was processed - close_events = tester.close_events() - self.assertEqual(len(close_events), 1) - # TODO: Verify whether original error message is inside the exception - - def test_streaming_foreach_with_invalid_writers(self): - - tester = self.ForeachWriterTester(self.spark) - - def func_with_iterator_input(iter): - for x in iter: - print(x) - - tester.assert_invalid_writer(func_with_iterator_input) - - class WriterWithoutProcess: - def open(self, partition): - pass - - tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") - - class WriterWithNonCallableProcess: - process = True - - tester.assert_invalid_writer( - WriterWithNonCallableProcess(), "'process' in provided object is not callable" - ) - - class WriterWithNoParamProcess: - def process(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamProcess()) - - # Abstract class for tests below - class WithProcess: - def process(self, row): - pass - - class WriterWithNonCallableOpen(WithProcess): - open = True - - tester.assert_invalid_writer( - WriterWithNonCallableOpen(), "'open' in provided object is not callable" - ) - - class WriterWithNoParamOpen(WithProcess): - def open(self): - pass - - tester.assert_invalid_writer(WriterWithNoParamOpen()) - - class WriterWithNonCallableClose(WithProcess): - close = True - - tester.assert_invalid_writer( - WriterWithNonCallableClose(), "'close' in provided object is not callable" - ) - - def test_streaming_foreachBatch(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_tempview(self): - q = None - collected = dict() - - def collectBatch(batch_df, batch_id): - batch_df.createOrReplaceTempView("updates") - # it should use the spark session within given DataFrame, as microbatch execution will - # clone the session which is no longer same with the session used to start the - # streaming query - collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_propagates_python_errors(self): - from pyspark.errors import StreamingQueryException - - q = None - - def collectBatch(df, id): - raise RuntimeError("this should fail the query") - - try: - df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") - q = df.writeStream.foreachBatch(collectBatch).start() - q.processAllAvailable() - self.fail("Expected a failure") - except StreamingQueryException as e: - self.assertTrue("this should fail" in str(e)) - finally: - if q: - q.stop() - - def test_streaming_foreachBatch_graceful_stop(self): - # SPARK-39218: Make foreachBatch streaming 
query stop gracefully - def func(batch_df, _): - batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) - - q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() - time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. - q.stop() - self.assertIsNone(q.exception(), "No exception has to be propagated.") - def test_streaming_read_from_table(self): with self.table("input_table", "this_query"): self.spark.sql("CREATE TABLE input_table (value string) USING parquet") @@ -648,6 +352,10 @@ def test_streaming_write_to_table(self): self.assertTrue(len(result) > 0) +class StreamingTests(StreamingTestsMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": import unittest from pyspark.sql.tests.streaming.test_streaming import * # noqa: F401 diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py new file mode 100644 index 000000000000..8bd36020c9ad --- /dev/null +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py @@ -0,0 +1,297 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import tempfile + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTestsForeach(ReusedSQLTestCase): + class ForeachWriterTester: + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event(self.open_events_dir, {"partition": partitionId, "epoch": epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {"value": "text"}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {"error": str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, "partition INT, epoch INT") + + def process_events(self): + return self._read_events(self.process_events_dir, "value STRING") + + def close_events(self): + return self._read_events(self.close_events_dir, "error STRING") + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format("text").load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format("text").load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + + with open(os.path.join(dir, str(uuid.uuid4())), "w") as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + # Those foreach tests are failed in macOS High Sierra by defined rules + # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html + # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES. 
+ def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e["epoch"] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e["error"] for e in close_events]), {"None"}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.errors import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise RuntimeError("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + 
tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess: + process = True + + tester.assert_invalid_writer( + WriterWithNonCallableProcess(), "'process' in provided object is not callable" + ) + + class WriterWithNoParamProcess: + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess: + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer( + WriterWithNonCallableOpen(), "'open' in provided object is not callable" + ) + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer( + WriterWithNonCallableClose(), "'close' in provided object is not callable" + ) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.streaming.test_streaming_foreach import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py new file mode 100644 index 000000000000..7e5720e42999 --- /dev/null +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import time + +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class StreamingTestsForeachBatch(ReusedSQLTestCase): + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_tempview(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + batch_df.createOrReplaceTempView("updates") + # it should use the spark session within given DataFrame, as microbatch execution will + # clone the session which is no longer same with the session used to start the + # streaming query + collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.errors import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise RuntimeError("this should fail the query") + + try: + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_graceful_stop(self): + # SPARK-39218: Make foreachBatch streaming query stop gracefully + def func(batch_df, _): + batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000) + + q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start() + time.sleep(3) # 'rowsPerSecond' defaults to 1. Waits 3 secs out for the input. + q.stop() + self.assertIsNone(q.exception(), "No exception has to be propagated.") + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.streaming.test_streaming_foreachBatch import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2)
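
A minimal usage sketch of the surface this patch enables, written against the Spark Connect streaming path. It is an illustration, not part of the diff: the `local[4]` remote mirrors the `_test()` helpers above, the query name `connect_streaming_demo` is made up, and the input path assumes the working directory is a Spark source checkout (as the doctest setup's `os.chdir(os.environ["SPARK_HOME"])` does).

```python
# Illustrative only (not part of the patch). Exercises DataStreamReader.text(),
# which this change adds for Spark Connect, plus the query methods covered by
# test_streaming_query_functions_basic. Assumes a Spark source checkout as the
# working directory so the test data path resolves.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("local[4]").getOrCreate()

# text()/orc()/parquet() validate that `path` is a single string and then
# delegate to load(path=..., format=...), mirroring the classic PySpark API.
df = spark.readStream.text("python/test_support/sql/streaming")

query = df.writeStream.format("memory").queryName("connect_streaming_demo").start()
try:
    assert query.isActive
    query.processAllAvailable()
    # lastProgress is parsed from the JSON progress returned by the server.
    print(query.lastProgress["name"])  # connect_streaming_demo
finally:
    query.stop()

spark.stop()
```

Because the same calls work against a classic session, this is essentially what `StreamingTestsMixin` runs under both `StreamingTests` (classic) and `StreamingParityTests` (Connect), with the pieces not yet supported over Connect (`awaitTermination`, `exception`, the query manager, and the table read/write APIs) skipped via `unittest.skip`; `foreach` and `foreachBatch` move into their own classic-only test modules since the Connect writer still raises `NotImplementedError` for them.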