From 701a45506d75169455384ec8eebd30e509591c30 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 31 May 2018 15:45:10 -0700 Subject: [PATCH 1/8] First draft --- dev/sparktestsupport/modules.py | 14 +- python/pyspark/sql/streaming.py | 82 +++++++++ python/pyspark/sql/tests.py | 160 ++++++++++++++++- .../python/PythonForeachWriter.scala | 161 ++++++++++++++++++ .../sources/ForeachWriterProvider.scala | 52 ++++-- .../sql/streaming/DataStreamWriter.scala | 5 +- .../python/PythonForeachWriterSuite.scala | 137 +++++++++++++++ 7 files changed, 584 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c..aacde5a957fe 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -389,19 +389,7 @@ def __hash__(self): "python/pyspark/sql" ], python_test_goals=[ - "pyspark.sql.types", - "pyspark.sql.context", - "pyspark.sql.session", - "pyspark.sql.conf", - "pyspark.sql.catalog", - "pyspark.sql.column", - "pyspark.sql.dataframe", - "pyspark.sql.group", - "pyspark.sql.functions", - "pyspark.sql.readwriter", - "pyspark.sql.streaming", - "pyspark.sql.udf", - "pyspark.sql.window", + "pyspark.sql.tests", ] ) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f940738986..148632400721 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -30,6 +30,7 @@ from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * from pyspark.sql.utils import StreamingQueryException +from abc import ABCMeta, abstractmethod __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -843,6 +844,87 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + def foreach(self, f): + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + """ + The provided object is a callable function that is supposed to be called on each row. + Construct a function that takes an iterator and calls the provided function on each row. + """ + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + """ + The provided object is not a callable function. Then it is expected to have a + 'process(row)' method, and optional 'open(partitionId, epochOrBatchId)' and + 'close(error)' methods. 
+ """ + + if not hasattr(f, 'process'): + raise Exception( + "Provided object is neither callable nor does it have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + open_exists = False + if hasattr(f, 'open'): + if not callable(getattr(f, 'open')): + raise Exception("Attribute 'open' in provided object is not callable") + else: + open_exists = True + + close_exists = False + if hasattr(f, "close"): + if not callable(getattr(f, 'close')): + raise Exception("Attribute 'close' in provided object is not callable") + else: + close_exists = True + + def func_with_open_process_close(partitionId, iterator): + version = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if version: + version = int(version) + else: + raise Exception("Could not get batch id from TaskContext") + + should_process = True + if open_exists: + should_process = f.open(partitionId, version) + + def call_close_if_needed(error): + if open_exists and close_exists: + f.close(error) + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + call_close_if_needed(ex) + raise ex + + call_close_if_needed(None) + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a2450932e303..7884f929cdd9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -296,6 +296,7 @@ def tearDown(self): # tear down test_bucketed_write state self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") + ''' def test_row_should_be_read_only(self): row = Row(a=1, b=2) self.assertEqual(1, row.a) @@ -1884,7 +1885,164 @@ def test_query_manager_await_termination(self): finally: q.stop() shutil.rmtree(tmpPath) + ''' + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + sq.stop() + finally: + 
self.stop_all() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import random + file = open(os.path.join(dir, str(random.randint(0, 100000))), 'w') + file.write("%s\n" % str(event)) + file.close() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + sdf = self.spark.readStream.format('text').load(tester.input_dir) + sq = sdf.writeStream.foreach(ForeachWriter()).start() + tester.write_input_file() + sq.processAllAvailable() + self.fail("bad writer should fail the query") # this is not expected + except StreamingQueryException as e: + # self.assertTrue("test error" in e.desc) # this is expected + pass + finally: + tester.stop_all() + + self.assertEqual(len(tester.open_events()), 1) + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # self.assertTrue("test error" in e[0]['error']) + + ''' def test_help_command(self): # Regression test for 
SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -5391,7 +5549,7 @@ def test_invalid_args(self): AnalysisException, 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - + ''' if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 000000000000..47e290a44b2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for 
the sole purpose of buffering UnsafeRows in PythonForeahWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, s"HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36..f677f25f116a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, attemptNumber: Int, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. 
- * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e1..05ed8d93b188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -23,9 +23,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} @@ -269,7 +272,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 000000000000..07e603477012 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} From 0920260158ee85c6b1437378b505f74797a61ec9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 1 Jun 2018 02:01:36 -0700 Subject: [PATCH 2/8] Added docs --- python/pyspark/sql/streaming.py | 81 +++++++++++++++++++ python/pyspark/tests.py | 4 +- .../org/apache/spark/sql/ForeachWriter.scala | 61 ++++++++++---- .../sql/streaming/DataStreamWriter.scala | 48 +---------- 4 files changed, 132 insertions(+), 62 
deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 148632400721..fcf8f585ed53 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -845,6 +845,87 @@ def trigger(self, processingTime=None, once=None, continuous=None): return self def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserializedcopy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) be done open after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method is will be called if `open()` method returns successfully + (irrespective of the return value), except if the JVM crashes in the middle. + + .. note:: Evolving. 
+ + >>> # Print every row using a function + >>> writer = sdf.writeStream.foreach(lambda x: print(x)) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ from pyspark.rdd import _wrap_function from pyspark.serializers import PickleSerializer, AutoBatchedSerializer diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 30723b8e15b3..51649bb40706 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f..6243884921d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + * + * + * Important points to note: + * * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,17 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. 
* @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 05ed8d93b188..096b57802ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -23,12 +23,9 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} @@ -310,49 +307,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[ForeachWriter]] for more details on the lifecycle and semantics. 
* @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { From f40dff64d968a8102d4e061602d007b9aaa63abd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 1 Jun 2018 02:07:22 -0700 Subject: [PATCH 3/8] Small changes --- python/pyspark/sql/streaming.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fcf8f585ed53..d55f28ac4a5b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -907,8 +907,9 @@ def foreach(self, f): mode, then this guarantee does not hold and therefore should not be used for deduplication. - * The ``close()`` method is will be called if `open()` method returns successfully - (irrespective of the return value), except if the JVM crashes in the middle. + * The ``close()`` method (if exists) is will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. .. note:: Evolving. @@ -946,7 +947,7 @@ def func_without_process(_, iterator): else: """ The provided object is not a callable function. Then it is expected to have a - 'process(row)' method, and optional 'open(partitionId, epochOrBatchId)' and + 'process(row)' method, and optional 'open(partition_id, epoch_id)' and 'close(error)' methods. """ @@ -971,16 +972,16 @@ def func_without_process(_, iterator): else: close_exists = True - def func_with_open_process_close(partitionId, iterator): - version = TaskContext.get().getLocalProperty('streaming.sql.batchId') - if version: - version = int(version) + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) else: raise Exception("Could not get batch id from TaskContext") should_process = True if open_exists: - should_process = f.open(partitionId, version) + should_process = f.open(partition_id, epoch_id) def call_close_if_needed(error): if open_exists and close_exists: From d1cd9338b64c70bb4372517c8cbbfab918e9d604 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 1 Jun 2018 02:43:52 -0700 Subject: [PATCH 4/8] maybe fix docs --- .../org/apache/spark/sql/streaming/DataStreamWriter.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 096b57802ab0..e035c9cdc379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -308,7 +308,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[ForeachWriter]] for more details on the lifecycle and semantics. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. 
* @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { From 8e30e8d434459f2c480aeec4109965d7c95d4069 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 7 Jun 2018 16:57:53 -0700 Subject: [PATCH 5/8] Addressed comments --- python/pyspark/sql/streaming.py | 63 ++++--- python/pyspark/sql/tests.py | 163 ++++++++++++++---- .../org/apache/spark/sql/ForeachWriter.scala | 7 +- .../python/PythonForeachWriter.scala | 2 +- 4 files changed, 168 insertions(+), 67 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d55f28ac4a5b..bef30f2ee257 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -30,7 +30,6 @@ from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * from pyspark.sql.utils import StreamingQueryException -from abc import ABCMeta, abstractmethod __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -844,6 +843,7 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) def foreach(self, f): """ Sets the output of the streaming query to be processed using the provided writer ``f``. @@ -876,9 +876,9 @@ def foreach(self, f): processing one partition of the data generated in a distributed manner. * This object must be serializable because each task will get a fresh - serialized-deserializedcopy of the provided object. Hence, it is strongly + serialized-deserialized copy of the provided object. Hence, it is strongly recommended that any initialization for writing data (e.g. opening a - connection or starting a transaction) be done open after the `open(...)` + connection or starting a transaction) is done after the `open(...)` method has been called, which signifies that the task is ready to generate data. * The lifecycle of the methods are as follows. @@ -907,7 +907,7 @@ def foreach(self, f): mode, then this guarantee does not hold and therefore should not be used for deduplication. - * The ``close()`` method (if exists) is will be called if `open()` method exists and + * The ``close()`` method (if exists) will be called if `open()` method exists and returns successfully (irrespective of the return value), except if the Python crashes in the middle. @@ -933,10 +933,9 @@ def foreach(self, f): from pyspark.taskcontext import TaskContext if callable(f): - """ - The provided object is a callable function that is supposed to be called on each row. - Construct a function that takes an iterator and calls the provided function on each row. - """ + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. def func_without_process(_, iterator): for x in iterator: f(x) @@ -945,32 +944,25 @@ def func_without_process(_, iterator): func = func_without_process else: - """ - The provided object is not a callable function. Then it is expected to have a - 'process(row)' method, and optional 'open(partition_id, epoch_id)' and - 'close(error)' methods. - """ + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. 
if not hasattr(f, 'process'): - raise Exception( - "Provided object is neither callable nor does it have a 'process' method") + raise Exception("Provided object does not have a 'process' method") if not callable(getattr(f, 'process')): raise Exception("Attribute 'process' in provided object is not callable") - open_exists = False - if hasattr(f, 'open'): - if not callable(getattr(f, 'open')): - raise Exception("Attribute 'open' in provided object is not callable") - else: - open_exists = True + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists - close_exists = False - if hasattr(f, "close"): - if not callable(getattr(f, 'close')): - raise Exception("Attribute 'close' in provided object is not callable") - else: - close_exists = True + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') def func_with_open_process_close(partition_id, iterator): epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') @@ -979,22 +971,27 @@ def func_with_open_process_close(partition_id, iterator): else: raise Exception("Could not get batch id from TaskContext") + # Check if the data should be processed should_process = True if open_exists: should_process = f.open(partition_id, epoch_id) - def call_close_if_needed(error): - if open_exists and close_exists: - f.close(error) + error = None + try: if should_process: for x in iterator: f.process(x) + except Exception as ex: - call_close_if_needed(ex) - raise ex + error = ex + + finally: + if close_exists: + f.close(error) + if error: + raise error - call_close_if_needed(None) return iter([]) func = func_with_open_process_close diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7884f929cdd9..1cac892ca34b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1891,10 +1891,6 @@ class ForeachWriterTester: def __init__(self, spark): self.spark = spark - self.input_dir = tempfile.mkdtemp() - self.open_events_dir = tempfile.mkdtemp() - self.process_events_dir = tempfile.mkdtemp() - self.close_events_dir = tempfile.mkdtemp() def write_open_event(self, partitionId, epochId): self._write_event( @@ -1920,30 +1916,50 @@ def close_events(self): return self._read_events(self.close_events_dir, 'error STRING') def run_streaming_query_on_writer(self, writer, num_files): + self._reset() try: sdf = self.spark.readStream.format('text').load(self.input_dir) sq = sdf.writeStream.foreach(writer).start() for i in range(num_files): self.write_input_file() sq.processAllAvailable() - sq.stop() finally: self.stop_all() + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert(msg in str(e), "%s not in %s" % (msg, str(e))) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + def _read_events(self, dir, json): rows = self.spark.read.schema(json).json(dir).collect() 
dicts = [row.asDict() for row in rows] return dicts def _write_event(self, dir, event): - import random - file = open(os.path.join(dir, str(random.randint(0, 100000))), 'w') - file.write("%s\n" % str(event)) - file.close() - - def stop_all(self): - for q in self.spark._wrapped.streams.active: - q.stop() + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) def __getstate__(self): return (self.open_events_dir, self.process_events_dir, self.close_events_dir) @@ -1990,8 +2006,8 @@ def test_streaming_foreach_with_open_returning_false(self): tester = self.ForeachWriterTester(self.spark) class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) return False def process(self, row): @@ -2003,21 +2019,62 @@ def close(self, error): tester.run_streaming_query_on_writer(ForeachWriter(), 2) self.assertEqual(len(tester.open_events()), 2) + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() self.assertEqual(len(close_events), 2) self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) - def test_streaming_foreach_with_process_throwing_error(self): - from pyspark.sql.utils import StreamingQueryException + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): tester = self.ForeachWriterTester(self.spark) class ForeachWriter: - def open(self, partitionId, epochId): - tester.write_open_event(partitionId, epochId) + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) return True + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: def process(self, row): raise Exception("test error") @@ -2025,24 +2082,70 @@ def close(self, error): tester.write_close_event(error) try: - sdf = self.spark.readStream.format('text').load(tester.input_dir) - sq = sdf.writeStream.foreach(ForeachWriter()).start() - tester.write_input_file() - sq.processAllAvailable() - self.fail("bad writer should fail the query") # this is not expected + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad 
writer did not fail the query") # this is not expected except StreamingQueryException as e: - # self.assertTrue("test error" in e.desc) # this is expected + # TODO: Verify whether original error message is inside the exception pass - finally: - tester.stop_all() - self.assertEqual(len(tester.open_events()), 1) self.assertEqual(len(tester.process_events()), 0) # no row was processed close_events = tester.close_events() self.assertEqual(len(close_events), 1) - # self.assertTrue("test error" in e[0]['error']) + # TODO: Verify whether original error message is inside the exception - ''' + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + +''' def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 6243884921d0..b21c50af1843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.InterfaceStability *
  • Any implementation of this class must be serializable because each task will get a fresh * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that * any initialization for writing data (e.g. opening a connection or starting a transaction) - * be done open after the `open(...)` method has been called, which signifies that the task is + * is done after the `open(...)` method has been called, which signifies that the task is * ready to generate data. * *
  • The lifecycle of the methods are as follows. @@ -59,7 +59,7 @@ import org.apache.spark.annotation.InterfaceStability * continuous mode, then this guarantee does not hold and therefore should not be used for * deduplication. * - *
  • The `close()` method is will be called if `open()` method returns successfully (irrespective + *
  • The `close()` method will be called if `open()` method returns successfully (irrespective * of the return value), except if the JVM crashes in the middle. * * @@ -110,7 +110,8 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. * @param epochId a unique id for data deduplication. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 47e290a44b2b..34ef0052196f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -73,7 +73,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) object PythonForeachWriter { /** - * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeahWriter. + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator From ecf3d88954afdab16c492fd479c1051aff8a3b95 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 7 Jun 2018 17:00:27 -0700 Subject: [PATCH 6/8] Undo temp changes --- dev/sparktestsupport/modules.py | 14 +++++++++++++- python/pyspark/sql/tests.py | 6 +----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index aacde5a957fe..dfea762db98c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -389,7 +389,19 @@ def __hash__(self): "python/pyspark/sql" ], python_test_goals=[ - + "pyspark.sql.types", + "pyspark.sql.context", + "pyspark.sql.session", + "pyspark.sql.conf", + "pyspark.sql.catalog", + "pyspark.sql.column", + "pyspark.sql.dataframe", + "pyspark.sql.group", + "pyspark.sql.functions", + "pyspark.sql.readwriter", + "pyspark.sql.streaming", + "pyspark.sql.udf", + "pyspark.sql.window", "pyspark.sql.tests", ] ) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1cac892ca34b..cecc159f7d36 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -296,7 +296,6 @@ def tearDown(self): # tear down test_bucketed_write state self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") - ''' def test_row_should_be_read_only(self): row = Row(a=1, b=2) self.assertEqual(1, row.a) @@ -1885,7 +1884,6 @@ def test_query_manager_await_termination(self): finally: q.stop() shutil.rmtree(tmpPath) - ''' class ForeachWriterTester: @@ -2144,8 +2142,6 @@ class WriterWithNonCallableClose(WithProcess): tester.assert_invalid_writer(WriterWithNonCallableClose(), "'close' in provided object is not callable") - -''' def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -5652,7 +5648,7 @@ def test_invalid_args(self): AnalysisException, 
'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - ''' + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: From d081110210c4c883a82169c1c6687369469189af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 14 Jun 2018 15:47:46 -0700 Subject: [PATCH 7/8] Fixed pypy syntax error --- python/pyspark/sql/streaming.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bef30f2ee257..8785c422b840 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -914,7 +914,10 @@ def foreach(self, f): .. note:: Evolving. >>> # Print every row using a function - >>> writer = sdf.writeStream.foreach(lambda x: print(x)) + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) >>> # Print every row using a object with process() method >>> class RowPrinter: ... def open(self, partition_id, epoch_id): From 1ab612f3d965cc6cad2b24c309bb7f11f931ea1e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 15 Jun 2018 00:37:04 -0700 Subject: [PATCH 8/8] Addressed nits --- python/pyspark/sql/streaming.py | 2 -- python/pyspark/sql/tests.py | 8 ++++---- .../spark/sql/execution/python/PythonForeachWriter.scala | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 8785c422b840..d1e016e05693 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -985,10 +985,8 @@ def func_with_open_process_close(partition_id, iterator): if should_process: for x in iterator: f.process(x) - except Exception as ex: error = ex - finally: if close_exists: f.close(error) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cecc159f7d36..e0f9cbee516c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2018,7 +2018,7 @@ def close(self, error): self.assertEqual(len(tester.open_events()), 2) - self.assertEqual(len(tester.process_events()), 0) # no row was processed + self.assertEqual(len(tester.process_events()), 0) # no row was processed close_events = tester.close_events() self.assertEqual(len(close_events), 2) @@ -2035,7 +2035,7 @@ def close(self, error): tester.write_close_event(error) tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.open_events()), 0) # no open events self.assertEqual(len(tester.process_events()), 2) self.assertEqual(len(tester.close_events()), 2) @@ -2063,7 +2063,7 @@ def process(self, row): tester.write_process_event(row) tester.run_streaming_query_on_writer(ForeachWriter(), 2) - self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.open_events()), 0) # no open events self.assertEqual(len(tester.process_events()), 2) self.assertEqual(len(tester.close_events()), 0) @@ -2081,7 +2081,7 @@ def close(self, error): try: tester.run_streaming_query_on_writer(ForeachWriter(), 1) - self.fail("bad writer did not fail the query") # this is not expected + self.fail("bad writer did not fail the query") # this is not expected except StreamingQueryException as e: # TODO: Verify whether original error message is inside the exception pass diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 34ef0052196f..a58773122922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -131,7 +131,7 @@ object PythonForeachWriter { if (count > 0) { val row = queue.remove() - assert(row != null, s"HybridRowQueue.remove() returned null " + + assert(row != null, "HybridRowQueue.remove() returned null " + s"[count = $count, allAdded = $allAdded, exception = $exception]") count -= 1 logTrace(s"Removed $row, $count left")
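For reviewers who want to try the API out, below is a minimal end-to-end sketch of the Python foreach sink added by this series. It is an illustration only: it assumes a local Spark 2.4+ build that contains these patches, uses the built-in `rate` source purely as a convenient test stream, and otherwise mirrors the `RowPrinter` doctest added to streaming.py; the names `print_row`, `RowPrinter`, and the app name are placeholders.

    # Sketch only: assumes a Spark 2.4+ build containing this patch series,
    # run in local mode so the print() calls made inside the foreach
    # callbacks are easy to observe.
    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .master("local[2]")
             .appName("foreach-sink-sketch")
             .getOrCreate())
    sdf = spark.readStream.format("rate").option("rowsPerSecond", 5).load()

    # Form 1: a plain function, called once per row. There is no
    # open/close lifecycle, so replayed data cannot be deduplicated.
    def print_row(row):
        print(row)

    q1 = sdf.writeStream.foreach(print_row).start()

    # Form 2: a writer object with the full open/process/close lifecycle.
    # In micro-batch mode the (partition_id, epoch_id) pair passed to
    # open() identifies a batch that is replayed with the same data on
    # retry, so it can drive exactly-once or transactional writes.
    class RowPrinter(object):
        def open(self, partition_id, epoch_id):
            print("Opened partition %d, epoch %d" % (partition_id, epoch_id))
            return True  # returning False skips process() for this batch

        def process(self, row):
            print(row)

        def close(self, error):
            print("Closed with error: %s" % str(error))

    q2 = sdf.writeStream.foreach(RowPrinter()).start()

    q1.awaitTermination(10)  # let a few micro-batches run, then stop
    q2.awaitTermination(10)
    q1.stop()
    q2.stop()
    spark.stop()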
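Separately, for readers of PythonForeachWriter.scala: the UnsafeRowBuffer described in its scaladoc is, at heart, a single-producer/single-consumer handoff in which add() never blocks while the consuming iterator blocks until a row arrives or allRowsAdded() is signalled. The following Python analogue is an illustration of that concurrency contract only, not of the JVM code (which additionally spills rows to disk via HybridRowQueue and rethrows the producer's exception on the consumer side); the class and method names are made up for this sketch.

    # Illustrative analogue of the UnsafeRowBuffer contract, not the actual
    # implementation: non-blocking add(), a blocking consumer iterator, and
    # an all_rows_added() completion signal.
    import collections
    import threading

    class BlockingRowBuffer(object):
        def __init__(self):
            self._rows = collections.deque()
            self._cond = threading.Condition()
            self._all_added = False

        def add(self, row):
            # Producer side (the JVM task thread in the real code): never blocks.
            with self._cond:
                self._rows.append(row)
                self._cond.notify()

        def all_rows_added(self):
            # Producer signals that no more rows will arrive.
            with self._cond:
                self._all_added = True
                self._cond.notify()

        def __iter__(self):
            # Consumer side (the PythonRunner writer thread in the real code):
            # blocks until a row is available or the producer is done.
            while True:
                with self._cond:
                    while not self._rows and not self._all_added:
                        self._cond.wait(0.1)
                    if self._rows:
                        row = self._rows.popleft()
                    else:
                        return  # drained and producer finished
                yield row

A quick smoke test of this pattern is to start a consumer thread iterating the buffer, add rows from the main thread, call all_rows_added(), and join the consumer: the consumer should observe every row and then terminate, which is essentially what PythonForeachWriterSuite asserts about the real buffer.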