diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 15f940738986..d1e016e05693 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -843,6 +843,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
self._jwrite = self._jwrite.trigger(jTrigger)
return self
+ @since(2.4)
+ def foreach(self, f):
+ """
+ Sets the output of the streaming query to be processed using the provided writer ``f``.
+ This is often used to write the output of a streaming query to arbitrary storage systems.
+ The processing logic can be specified in two ways.
+
+ #. A **function** that takes a row as input.
+ This is a simple way to express your processing logic. Note that this does
+ not allow you to deduplicate generated data when failures cause reprocessing of
+ some input data. That would require you to specify the processing logic using the
+ second approach below.
+
+ #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+ The object can have the following methods.
+
+ * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+ (for example, open a connection, start a transaction, etc.). Additionally, you can
+ use the `partition_id` and `epoch_id` to deduplicate regenerated data
+ (discussed later).
+
+ * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+ * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+ close connection, commit transaction, etc.) after all rows have been processed.
+
+ The object will be used by Spark in the following way.
+
+ * A single copy of this object is responsible for all the data generated by a
+ single task in a query. In other words, one instance is responsible for
+ processing one partition of the data generated in a distributed manner.
+
+ * This object must be serializable because each task will get a fresh
+ serialized-deserialized copy of the provided object. Hence, it is strongly
+ recommended that any initialization for writing data (e.g. opening a
+ connection or starting a transaction) is done after the `open(...)`
+ method has been called, which signifies that the task is ready to generate data.
+
+ * The lifecycle of the methods is as follows.
+
+ For each partition with ``partition_id``:
+
+ ... For each batch/epoch of streaming data with ``epoch_id``:
+
+ ....... Method ``open(partition_id, epoch_id)`` is called.
+
+ ....... If ``open(...)`` returns true, for each row in the partition and
+ batch/epoch, method ``process(row)`` is called.
+
+ ....... Method ``close(error)`` is called with the error (if any) seen while
+ processing rows.
+
+ Important points to note:
+
+ * The `partition_id` and `epoch_id` can be used to deduplicate generated data when
+ failures cause reprocessing of some input data. This depends on the execution
+ mode of the query. If the streaming query is being executed in the micro-batch
+ mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+ is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+ to deduplicate and/or transactionally commit data and achieve exactly-once
+ guarantees. However, if the streaming query is being executed in the continuous
+ mode, then this guarantee does not hold and therefore should not be used for
+ deduplication.
+
+ * The ``close()`` method (if it exists) will be called if the ``open()`` method exists and
+ returns successfully (irrespective of the return value), except if the Python process
+ crashes in the middle.
+
+ .. note:: Evolving.
+
+ >>> # Print every row using a function
+ >>> def print_row(row):
+ ... print(row)
+ ...
+ >>> writer = sdf.writeStream.foreach(print_row)
+ >>> # Print every row using an object with a process() method
+ >>> class RowPrinter:
+ ... def open(self, partition_id, epoch_id):
+ ... print("Opened %d, %d" % (partition_id, epoch_id))
+ ... return True
+ ... def process(self, row):
+ ... print(row)
+ ... def close(self, error):
+ ... print("Closed with error: %s" % str(error))
+ ...
+ >>> writer = sdf.writeStream.foreach(RowPrinter())
+ """
+
+ from pyspark.rdd import _wrap_function
+ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+ from pyspark.taskcontext import TaskContext
+
+ if callable(f):
+ # The provided object is a callable function that is supposed to be called on each row.
+ # Construct a function that takes an iterator and calls the provided function on each
+ # row.
+ def func_without_process(_, iterator):
+ for x in iterator:
+ f(x)
+ return iter([])
+
+ func = func_without_process
+
+ else:
+ # The provided object is not a callable function. Then it is expected to have a
+ # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+ # 'close(error)' methods.
+
+ if not hasattr(f, 'process'):
+ raise Exception("Provided object does not have a 'process' method")
+
+ if not callable(getattr(f, 'process')):
+ raise Exception("Attribute 'process' in provided object is not callable")
+
+ def doesMethodExist(method_name):
+ exists = hasattr(f, method_name)
+ if exists and not callable(getattr(f, method_name)):
+ raise Exception(
+ "Attribute '%s' in provided object is not callable" % method_name)
+ return exists
+
+ open_exists = doesMethodExist('open')
+ close_exists = doesMethodExist('close')
+
+ def func_with_open_process_close(partition_id, iterator):
+ epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+ if epoch_id:
+ epoch_id = int(epoch_id)
+ else:
+ raise Exception("Could not get batch id from TaskContext")
+
+ # Check if the data should be processed
+ should_process = True
+ if open_exists:
+ should_process = f.open(partition_id, epoch_id)
+
+ error = None
+
+ try:
+ if should_process:
+ for x in iterator:
+ f.process(x)
+ except Exception as ex:
+ error = ex
+ finally:
+ if close_exists:
+ f.close(error)
+ if error:
+ raise error
+
+ return iter([])
+
+ func = func_with_open_process_close
+
+ serializer = AutoBatchedSerializer(PickleSerializer())
+ wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+ jForeachWriter = \
+ self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+ wrapped_func, self._df._jdf.schema())
+ self._jwrite.foreach(jForeachWriter)
+ return self
+
@ignore_unicode_prefix
@since(2.0)
def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
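The docstring added above describes the open/process/close lifecycle and how (partition_id, epoch_id) can deduplicate data that gets reprocessed after failures. The snippet below is a minimal sketch of that pattern, not part of the patch: it assumes micro-batch execution, a streaming DataFrame `sdf`, and two hypothetical helpers, `already_committed` and `commit_batch`, standing in for whatever transactional store the writer targets.

    # Sketch only: already_committed() and commit_batch() are hypothetical helpers
    # for the target store; they are not provided by Spark.
    class IdempotentWriter:
        def open(self, partition_id, epoch_id):
            self.partition_id = partition_id
            self.epoch_id = epoch_id
            self.rows = []
            # Returning False skips process() for a partition/epoch that was
            # already committed by an earlier (failed and retried) attempt.
            return not already_committed(partition_id, epoch_id)

        def process(self, row):
            self.rows.append(row.asDict())

        def close(self, error):
            if error is None:
                # Commit atomically, keyed by (partition_id, epoch_id), so retries
                # of the same epoch overwrite rather than duplicate the output.
                commit_batch(self.partition_id, self.epoch_id, self.rows)

    query = sdf.writeStream.foreach(IdempotentWriter()).start()

As the docstring notes, the (partition_id, epoch_id) guarantee does not hold in continuous mode, so this pattern applies to micro-batch queries only.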
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a2450932e303..e0f9cbee516c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1885,6 +1885,263 @@ def test_query_manager_await_termination(self):
q.stop()
shutil.rmtree(tmpPath)
+ class ForeachWriterTester:
+
+ def __init__(self, spark):
+ self.spark = spark
+
+ def write_open_event(self, partitionId, epochId):
+ self._write_event(
+ self.open_events_dir,
+ {'partition': partitionId, 'epoch': epochId})
+
+ def write_process_event(self, row):
+ self._write_event(self.process_events_dir, {'value': 'text'})
+
+ def write_close_event(self, error):
+ self._write_event(self.close_events_dir, {'error': str(error)})
+
+ def write_input_file(self):
+ self._write_event(self.input_dir, "text")
+
+ def open_events(self):
+ return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+ def process_events(self):
+ return self._read_events(self.process_events_dir, 'value STRING')
+
+ def close_events(self):
+ return self._read_events(self.close_events_dir, 'error STRING')
+
+ def run_streaming_query_on_writer(self, writer, num_files):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ for i in range(num_files):
+ self.write_input_file()
+ sq.processAllAvailable()
+ finally:
+ self.stop_all()
+
+ def assert_invalid_writer(self, writer, msg=None):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ self.write_input_file()
+ sq.processAllAvailable()
+ self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected
+ except Exception as e:
+ if msg:
+ assert msg in str(e), "%s not in %s" % (msg, str(e))
+
+ finally:
+ self.stop_all()
+
+ def stop_all(self):
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+
+ def _reset(self):
+ self.input_dir = tempfile.mkdtemp()
+ self.open_events_dir = tempfile.mkdtemp()
+ self.process_events_dir = tempfile.mkdtemp()
+ self.close_events_dir = tempfile.mkdtemp()
+
+ def _read_events(self, dir, json):
+ rows = self.spark.read.schema(json).json(dir).collect()
+ dicts = [row.asDict() for row in rows]
+ return dicts
+
+ def _write_event(self, dir, event):
+ import uuid
+ with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
+ f.write("%s\n" % str(event))
+
+ def __getstate__(self):
+ return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+ def __setstate__(self, state):
+ self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+ def test_streaming_foreach_with_simple_function(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ def foreach_func(row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(foreach_func, 2)
+ self.assertEqual(len(tester.process_events()), 2)
+
+ def test_streaming_foreach_with_basic_open_process_close(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partitionId, epochId):
+ tester.write_open_event(partitionId, epochId)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ open_events = tester.open_events()
+ self.assertEqual(len(open_events), 2)
+ self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+ self.assertEqual(len(tester.process_events()), 2)
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_with_open_returning_false(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return False
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ self.assertEqual(len(tester.open_events()), 2)
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_without_open_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 2)
+
+ def test_streaming_foreach_without_close_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 2) # open() was still called
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_without_open_and_close_methods(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_with_process_throwing_error(self):
+ from pyspark.sql.utils import StreamingQueryException
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ raise Exception("test error")
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ try:
+ tester.run_streaming_query_on_writer(ForeachWriter(), 1)
+ self.fail("bad writer did not fail the query") # this is not expected
+ except StreamingQueryException as e:
+ # TODO: Verify whether original error message is inside the exception
+ pass
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 1)
+ # TODO: Verify whether original error message is inside the exception
+
+ def test_streaming_foreach_with_invalid_writers(self):
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ def func_with_iterator_input(iter):
+ for x in iter:
+ print(x)
+
+ tester.assert_invalid_writer(func_with_iterator_input)
+
+ class WriterWithoutProcess:
+ def open(self, partition):
+ pass
+
+ tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")
+
+ class WriterWithNonCallableProcess():
+ process = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableProcess(),
+ "'process' in provided object is not callable")
+
+ class WriterWithNoParamProcess():
+ def process(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamProcess())
+
+ # Abstract class for tests below
+ class WithProcess():
+ def process(self, row):
+ pass
+
+ class WriterWithNonCallableOpen(WithProcess):
+ open = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableOpen(),
+ "'open' in provided object is not callable")
+
+ class WriterWithNoParamOpen(WithProcess):
+ def open(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamOpen())
+
+ class WriterWithNonCallableClose(WithProcess):
+ close = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableClose(),
+ "'close' in provided object is not callable")
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 30723b8e15b3..51649bb40706 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -581,9 +581,9 @@ def test_get_local_property(self):
self.sc.setLocalProperty(key, value)
try:
rdd = self.sc.parallelize(range(1), 1)
- prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0]
+ prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
self.assertEqual(prop1, value)
- prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+ prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
self.assertTrue(prop2 is None)
finally:
self.sc.setLocalProperty(key, None)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
index 86e02e98c01f..b21c50af1843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -20,10 +20,48 @@ package org.apache.spark.sql
import org.apache.spark.annotation.InterfaceStability
/**
- * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the
- * generated data to external systems. Each partition will use a new deserialized instance, so you
- * usually should do all the initialization (e.g. opening a connection or initiating a transaction)
- * in the `open` method.
+ * The abstract class for writing custom logic to process data generated by a query.
+ * This is often used to write the output of a streaming query to arbitrary storage systems.
+ * Any implementation of this base class will be used by Spark in the following way.
+ *
+ *
+ * - A single instance of this class is responsible for all the data generated by a single task
+ * in a query. In other words, one instance is responsible for processing one partition of the
+ * data generated in a distributed manner.
+ *
+ *
+ * - Any implementation of this class must be serializable because each task will get a fresh
+ * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that
+ * any initialization for writing data (e.g. opening a connection or starting a transaction)
+ * is done after the `open(...)` method has been called, which signifies that the task is
+ * ready to generate data.
+ *
+ *
+ * - The lifecycle of the methods is as follows.
+ *
+ *
+ * For each partition with `partitionId`:
+ * For each batch/epoch of streaming data (if it is a streaming query) with `epochId`:
+ * Method `open(partitionId, epochId)` is called.
+ * If `open` returns true:
+ * For each row in the partition and batch/epoch, method `process(row)` is called.
+ * Method `close(errorOrNull)` is called with the error (if any) seen while processing rows.
+ *
+ *
+ *
+ *
+ * Important points to note:
+ *
+ * - The `partitionId` and `epochId` can be used to deduplicate generated data when failures
+ * cause reprocessing of some input data. This depends on the execution mode of the query. If
+ * the streaming query is being executed in the micro-batch mode, then every partition
+ * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data.
+ * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data
+ * and achieve exactly-once guarantees. However, if the streaming query is being executed in the
+ * continuous mode, then this guarantee does not hold and therefore should not be used for
+ * deduplication.
+ *
+ *
+ * - The `close()` method will be called if the `open()` method returns successfully (irrespective
+ * of the return value), except if the JVM crashes in the middle.
+ *
*
* Scala example:
* {{{
@@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability
* }
* });
* }}}
+ *
* @since 2.0.0
*/
@InterfaceStability.Evolving
@@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable {
// TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API.
/**
- * Called when starting to process one partition of new data in the executor. The `version` is
- * for data deduplication when there are failures. When recovering from a failure, some data may
- * be generated multiple times but they will always have the same version.
- *
- * If this method finds using the `partitionId` and `version` that this partition has already been
- * processed, it can return `false` to skip the further data processing. However, `close` still
- * will be called for cleaning up resources.
+ * Called when starting to process one partition of new data in the executor. See the class
+ * docs for more information on how to use the `partitionId` and `epochId`.
*
* @param partitionId the partition id.
- * @param version a unique id for data deduplication.
+ * @param epochId a unique id for data deduplication.
* @return `true` if the corresponding partition and version id should be processed. `false`
* indicates the partition should be skipped.
*/
- def open(partitionId: Long, version: Long): Boolean
+ def open(partitionId: Long, epochId: Long): Boolean
/**
- * Called to process the data in the executor side. This method will be called only when `open`
+ * Called to process the data in the executor side. This method will be called only if `open`
* returns `true`.
*/
def process(value: T): Unit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
new file mode 100644
index 000000000000..a58773122922
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python._
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{NextIterator, Utils}
+
+class PythonForeachWriter(func: PythonFunction, schema: StructType)
+ extends ForeachWriter[UnsafeRow] {
+
+ private lazy val context = TaskContext.get()
+ private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer(
+ context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length)
+ private lazy val inputRowIterator = buffer.iterator
+
+ private lazy val inputByteIterator = {
+ EvaluatePython.registerPicklers()
+ val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) }
+ new SerDeUtil.AutoBatchedPickler(objIterator)
+ }
+
+ private lazy val pythonRunner = {
+ val conf = SparkEnv.get.conf
+ val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
+ PythonRunner(func, bufferSize, reuseWorker)
+ }
+
+ private lazy val outputIterator =
+ pythonRunner.compute(inputByteIterator, context.partitionId(), context)
+
+ override def open(partitionId: Long, epochId: Long): Boolean = {
+ outputIterator // initialize everything
+ TaskContext.get.addTaskCompletionListener { _ => buffer.close() }
+ true
+ }
+
+ override def process(value: UnsafeRow): Unit = {
+ buffer.add(value)
+ }
+
+ override def close(errorOrNull: Throwable): Unit = {
+ buffer.allRowsAdded()
+ if (outputIterator.hasNext) outputIterator.next() // to throw the Python exception if there was one
+ }
+}
+
+object PythonForeachWriter {
+
+ /**
+ * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter.
+ * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader
+ * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python
+ * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator
+ * are blocking, that is, they block until new data is available or all data has been added.
+ *
+ * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue
+ * across memory and local disk. However, HybridRowQueue is designed to be used only with
+ * EvalPythonExec where the reader is always behind the writer, that is, the reader does not
+ * try to read n+1 rows if the writer has only written n rows at any point of time. This
+ * assumption is not true for PythonForeachWriter where rows may be added at a different rate than
+ * they are consumed by the Python worker. Hence, to maintain the invariant of the reader being
+ * behind the writer while using HybridRowQueue, the buffer does the following:
+ * - Keeps a count of the rows in the HybridRowQueue
+ * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not
+ * try to read more rows than what has been written.
+ *
+ * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed
+ * from that of ArrayBlockingQueue.
+ */
+ class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int)
+ extends Logging {
+ private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields)
+ private val lock = new ReentrantLock()
+ private val unblockRemove = lock.newCondition()
+
+ // All of these are guarded by `lock`
+ private var count = 0L
+ private var allAdded = false
+ private var exception: Throwable = null
+
+ val iterator = new NextIterator[UnsafeRow] {
+ override protected def getNext(): UnsafeRow = {
+ val row = remove()
+ if (row == null) finished = true
+ row
+ }
+ override protected def close(): Unit = { }
+ }
+
+ def add(row: UnsafeRow): Unit = withLock {
+ assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" +
+ s"[count = $count, allAdded = $allAdded, exception = $exception]")
+ count += 1
+ unblockRemove.signal()
+ logTrace(s"Added $row, $count left")
+ }
+
+ private def remove(): UnsafeRow = withLock {
+ while (count == 0 && !allAdded && exception == null) {
+ unblockRemove.await(100, TimeUnit.MILLISECONDS)
+ }
+
+ // If there was any error in the adding thread, then rethrow it in the removing thread
+ if (exception != null) throw exception
+
+ if (count > 0) {
+ val row = queue.remove()
+ assert(row != null, "HybridRowQueue.remove() returned null " +
+ s"[count = $count, allAdded = $allAdded, exception = $exception]")
+ count -= 1
+ logTrace(s"Removed $row, $count left")
+ row
+ } else {
+ null
+ }
+ }
+
+ def allRowsAdded(): Unit = withLock {
+ allAdded = true
+ unblockRemove.signal()
+ }
+
+ def close(): Unit = { queue.close() }
+
+ private def withLock[T](f: => T): T = {
+ lock.lockInterruptibly()
+ try { f } catch {
+ case e: Throwable =>
+ if (exception == null) exception = e
+ throw e
+ } finally { lock.unlock() }
+ }
+ }
+}
+
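The UnsafeRowBuffer introduced above is a single-producer, single-consumer buffer: adds never block, while the consuming iterator blocks until a row is available or allRowsAdded() has been called, which keeps the reader behind the writer. Purely as an illustration of that blocking-iterator idea (and not of the real class, which wraps HybridRowQueue and UnsafeRow), here is a rough Python equivalent; the name RowBuffer and its API are invented for this sketch.

    import threading
    from collections import deque

    class RowBuffer:
        """Single-producer/single-consumer buffer with a blocking iterator (sketch)."""

        def __init__(self):
            self._rows = deque()
            self._cond = threading.Condition()
            self._all_added = False

        def add(self, row):
            # Producer side: never blocks, just enqueues and wakes the reader.
            with self._cond:
                self._rows.append(row)
                self._cond.notify()

        def all_rows_added(self):
            # Producer signals that no more rows will arrive.
            with self._cond:
                self._all_added = True
                self._cond.notify()

        def __iter__(self):
            # Consumer side: blocks while the queue is empty and more rows may come.
            while True:
                with self._cond:
                    while not self._rows and not self._all_added:
                        self._cond.wait(timeout=0.1)
                    if self._rows:
                        row = self._rows.popleft()
                    else:
                        return
                yield row

A consumer thread can simply iterate `for row in buf: ...`, mirroring how the PythonRunner writing thread drains the buffer's iterator while the task thread keeps calling add().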
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
index df5d69d57e36..f677f25f116a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.PythonForeachWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType
* [[ForeachWriter]].
*
* @param writer The [[ForeachWriter]] to process all data.
+ * @param converter An object to convert internal rows to the target type T. It can be either
+ * an [[ExpressionEncoder]] or a direct converter function.
* @tparam T The expected type of the sink.
*/
-case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
+case class ForeachWriterProvider[T](
+ writer: ForeachWriter[T],
+ converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport {
+
override def createStreamWriter(
queryId: String,
schema: StructType,
@@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
- val encoder = encoderFor[T].resolveAndBind(
- schema.toAttributes,
- SparkSession.getActiveSession.get.sessionState.analyzer)
- ForeachWriterFactory(writer, encoder)
+ val rowConverter: InternalRow => T = converter match {
+ case Left(enc) =>
+ val boundEnc = enc.resolveAndBind(
+ schema.toAttributes,
+ SparkSession.getActiveSession.get.sessionState.analyzer)
+ boundEnc.fromRow
+ case Right(func) =>
+ func
+ }
+ ForeachWriterFactory(writer, rowConverter)
}
override def toString: String = "ForeachSink"
@@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S
}
}
-case class ForeachWriterFactory[T: Encoder](
+object ForeachWriterProvider {
+ def apply[T](
+ writer: ForeachWriter[T],
+ encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = {
+ writer match {
+ case pythonWriter: PythonForeachWriter =>
+ new ForeachWriterProvider[UnsafeRow](
+ pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow]))
+ case _ =>
+ new ForeachWriterProvider[T](writer, Left(encoder))
+ }
+ }
+}
+
+case class ForeachWriterFactory[T](
writer: ForeachWriter[T],
- encoder: ExpressionEncoder[T])
+ rowConverter: InternalRow => T)
extends DataWriterFactory[InternalRow] {
override def createDataWriter(
partitionId: Int,
attemptNumber: Int,
epochId: Long): ForeachDataWriter[T] = {
- new ForeachDataWriter(writer, encoder, partitionId, epochId)
+ new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
}
}
/**
* A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
+ *
* @param writer The [[ForeachWriter]] to process all data.
- * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]]
+ * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]]
* @param partitionId
* @param epochId
* @tparam T The type expected by the writer.
*/
-class ForeachDataWriter[T : Encoder](
+class ForeachDataWriter[T](
writer: ForeachWriter[T],
- encoder: ExpressionEncoder[T],
+ rowConverter: InternalRow => T,
partitionId: Int,
epochId: Long)
extends DataWriter[InternalRow] {
@@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder](
if (!opened) return
try {
- writer.process(encoder.fromRow(record))
+ writer.process(rowConverter(record))
} catch {
case t: Throwable =>
writer.close(t)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index effc1471e8e1..e035c9cdc379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
- val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
+ val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
@@ -307,49 +307,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}
/**
- * Starts the execution of the streaming query, which will continually send results to the given
- * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data
- * generated by the `DataFrame`/`Dataset` to an external system.
- *
- * Scala example:
- * {{{
- * datasetOfString.writeStream.foreach(new ForeachWriter[String] {
- *
- * def open(partitionId: Long, version: Long): Boolean = {
- * // open connection
- * }
- *
- * def process(record: String) = {
- * // write string to connection
- * }
- *
- * def close(errorOrNull: Throwable): Unit = {
- * // close the connection
- * }
- * }).start()
- * }}}
- *
- * Java example:
- * {{{
- * datasetOfString.writeStream().foreach(new ForeachWriter() {
- *
- * @Override
- * public boolean open(long partitionId, long version) {
- * // open connection
- * }
- *
- * @Override
- * public void process(String value) {
- * // write string to connection
- * }
- *
- * @Override
- * public void close(Throwable errorOrNull) {
- * // close the connection
- * }
- * }).start();
- * }}}
- *
+ * Sets the output of the streaming query to be processed using the provided writer object.
+ * See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and
+ * semantics.
* @since 2.0.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala
new file mode 100644
index 000000000000..07e603477012
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer
+import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.util.Utils
+
+class PythonForeachWriterSuite extends SparkFunSuite with Eventually {
+
+ testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b =>
+ b.assertIteratorBlocked()
+
+ b.add(Seq(1))
+ b.assertOutput(Seq(1))
+ b.assertIteratorBlocked()
+
+ b.add(2 to 100)
+ b.assertOutput(1 to 100)
+ b.assertIteratorBlocked()
+ }
+
+ testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b =>
+ b.assertIteratorBlocked()
+ b.add(Seq(1))
+ b.assertIteratorBlocked()
+
+ b.allAdded()
+ b.assertThreadTerminated()
+ b.assertOutput(Seq(1))
+ }
+
+ testWithBuffer(
+ "UnsafeRowBuffer: handles more data than memory",
+ memBytes = 5,
+ sleepPerRowReadMs = 1) { b =>
+
+ b.assertIteratorBlocked()
+ b.add(1 to 2000)
+ b.assertOutput(1 to 2000)
+ }
+
+ def testWithBuffer(
+ name: String,
+ memBytes: Long = 4 << 10,
+ sleepPerRowReadMs: Int = 0
+ )(f: BufferTester => Unit): Unit = {
+
+ test(name) {
+ var tester: BufferTester = null
+ try {
+ tester = new BufferTester(memBytes, sleepPerRowReadMs)
+ f(tester)
+ } finally {
+ if (tester != null) tester.close()
+ }
+ }
+ }
+
+
+ class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) {
+ private val buffer = {
+ val mem = new TestMemoryManager(new SparkConf())
+ mem.limit(memBytes)
+ val taskM = new TaskMemoryManager(mem, 0)
+ new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1)
+ }
+ private val iterator = buffer.iterator
+ private val outputBuffer = new ArrayBuffer[Int]
+ private val testTimeout = timeout(20.seconds)
+ private val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+ private val thread = new Thread() {
+ override def run(): Unit = {
+ while (iterator.hasNext) {
+ outputBuffer.synchronized {
+ outputBuffer += iterator.next().getInt(0)
+ }
+ Thread.sleep(sleepPerRowReadMs)
+ }
+ }
+ }
+ thread.start()
+
+ def add(ints: Seq[Int]): Unit = {
+ ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) }
+ }
+
+ def allAdded(): Unit = { buffer.allRowsAdded() }
+
+ def assertOutput(expectedOutput: Seq[Int]): Unit = {
+ eventually(testTimeout) {
+ val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq
+ assert(output == expectedOutput)
+ }
+ }
+
+ def assertIteratorBlocked(): Unit = {
+ import Thread.State._
+ eventually(testTimeout) {
+ assert(thread.isAlive)
+ assert(thread.getState == TIMED_WAITING || thread.getState == WAITING)
+ }
+ }
+
+ def assertThreadTerminated(): Unit = {
+ eventually(testTimeout) { assert(!thread.isAlive) }
+ }
+
+ def close(): Unit = {
+ thread.interrupt()
+ thread.join()
+ }
+ }
+}