From db1987f63370c6c2f9434aea76da7d326565be5a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 9 Apr 2018 17:54:44 +0800
Subject: [PATCH 1/6] Makes collect in PySpark as action for a query executor listener

---
 .../scala/org/apache/spark/sql/Dataset.scala | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0aee1d7be578..82aa92dd8048 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Int = {
     EvaluatePython.registerPicklers()
-    withNewExecutionId {
+    withAction("collect", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+        plan.executeCollect().iterator.map(toJava))
       PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
@@ -3201,8 +3201,9 @@ class Dataset[T] private[sql](
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    withNewExecutionId {
-      val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+    withAction("collect", queryExecution) { plan =>
+      val iter: Iterator[Array[Byte]] =
+        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
       PythonRDD.serveIterator(iter, "serve-Arrow")
     }
   }
@@ -3312,10 +3313,15 @@ class Dataset[T] private[sql](
 
   /** Convert to an RDD of ArrowPayload byte arrays */
   private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+    // This is only used in tests, for now.
+    toArrowPayload(queryExecution.executedPlan)
+  }
+
+  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    queryExecution.toRdd.mapPartitionsInternal { iter =>
+    plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toPayloadIterator(
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)

From edb5eea8501c8348d037b3328229f0cdc078441a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 10 Apr 2018 19:14:22 +0800
Subject: [PATCH 2/6] Add a test and address comments

---
 python/pyspark/sql/tests.py | 66 +++++++++++++++++++
 .../scala/org/apache/spark/sql/Dataset.scala | 14 ++--
 .../sql/TestQueryExecutionListener.scala | 45 +++++++++++++
 3 files changed, 118 insertions(+), 7 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index dd04ffb4ed39..896506f89374 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3062,6 +3062,72 @@ def test_sparksession_with_stopped_sparkcontext(self):
         sc.stop()
 
 
+class SQLTests3(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "TestQueryExecutionListener.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+                "available. Skipping the related tests.")
+
+        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.queryExecutionListeners",
+                "org.apache.spark.sql.TestQueryExecutionListener") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def tearDown(self):
+        self.spark._jvm.OnSuccessCall.clear()
+
+    # This test is separate because it uses 'spark.sql.queryExecutionListeners' which is
+    # static and immutable. This can't be set or unset when we already have the session.
+    def test_query_execution_listener_on_collect(self):
+        self.assertFalse(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should not be called before 'collect'")
+        self.spark.sql("SELECT * FROM range(1)").collect()
+        self.assertTrue(
+            self.spark._jvm.OnSuccessCall.isCalled(),
+            "The callback from the query execution listener should be called after 'collect'")
+
+    @unittest.skipIf(
+        not _have_pandas or not _have_pyarrow,
+        _pandas_requirement_message or _pyarrow_requirement_message)
+    def test_query_execution_listener_on_collect_with_arrow(self):
+        # Here, it deplicates codes in ReusedSQLTestCase.sql_conf context manager.
+        # Refactor and deduplicate it if there is another case like this.
+        old_value = self.spark.conf.get("spark.sql.execution.arrow.enabled", None)
+        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        try:
+            self.assertFalse(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should not be "
+                "called before 'collect'")
+            self.spark.sql("SELECT * FROM range(1)").toPandas()
+            self.assertTrue(
+                self.spark._jvm.OnSuccessCall.isCalled(),
+                "The callback from the query execution listener should be called after 'collect'")
+        finally:
+            if old_value is None:
+                self.spark.conf.unset("spark.sql.execution.arrow.enabled")
+            else:
+                self.spark.conf.set("spark.sql.execution.arrow.enabled", old_value)
+
+
 class SparkSessionTests(PySparkTestCase):
 
     # This test is separate because it's closely related with session's start and stop.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 82aa92dd8048..917168162b23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3189,7 +3189,7 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Int = {
     EvaluatePython.registerPicklers()
-    withAction("collect", queryExecution) { plan =>
+    withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
       val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
         plan.executeCollect().iterator.map(toJava))
@@ -3201,7 +3201,7 @@ class Dataset[T] private[sql](
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    withAction("collect", queryExecution) { plan =>
+    withAction("collectAsArrowToPython", queryExecution) { plan =>
       val iter: Iterator[Array[Byte]] =
         toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
       PythonRDD.serveIterator(iter, "serve-Arrow")
@@ -3312,11 +3312,6 @@ class Dataset[T] private[sql](
   }
 
   /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
-    // This is only used in tests, for now.
-    toArrowPayload(queryExecution.executedPlan)
-  }
-
   private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
@@ -3327,4 +3322,9 @@ class Dataset[T] private[sql](
         iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
     }
   }
+
+  // This is only used in tests, for now.
+  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+    toArrowPayload(queryExecution.executedPlan)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
new file mode 100644
index 000000000000..39e5e014a8d0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
+
+class TestQueryExecutionListener extends QueryExecutionListener with Logging {
+  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+    OnSuccessCall.isOnSuccessCalled.set(true)
+  }
+
+  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
+}
+
+/**
+ * This has variables to check if `onSuccess` is actually called or not. Currently, this is for
+ * the test case in PySpark. See SPARK-23942.
+ */
+object OnSuccessCall {
+  val isOnSuccessCalled = new AtomicBoolean(false)
+
+  def isCalled(): Boolean = isOnSuccessCalled.get()
+
+  def clear(): Unit = isOnSuccessCalled.set(false)
+}

From e865c883abd1f1e340ef50d149e2defc5636610e Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 10 Apr 2018 19:20:46 +0800
Subject: [PATCH 3/6] Fix nits

---
 python/pyspark/sql/tests.py | 11 ++++++-----
 .../apache/spark/sql/TestQueryExecutionListener.scala | 2 +-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 896506f89374..d33aca6a24e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3063,6 +3063,9 @@ def test_sparksession_with_stopped_sparkcontext(self):
 
 
 class SQLTests3(unittest.TestCase):
+    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
+    # static and immutable. This can't be set or unset.
+
     @classmethod
     def setUpClass(cls):
         import glob
@@ -3075,7 +3078,7 @@ def setUpClass(cls):
         if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
             raise unittest.SkipTest(
                 "'org.apache.spark.sql.TestQueryExecutionListener' is not "
-                "available. Skipping the related tests.")
+                "available. Will skip the related tests.")
 
         # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
         cls.spark = SparkSession.builder \
@@ -3093,8 +3096,6 @@ def tearDownClass(cls):
     def tearDown(self):
         self.spark._jvm.OnSuccessCall.clear()
 
-    # This test is separate because it uses 'spark.sql.queryExecutionListeners' which is
-    # static and immutable. This can't be set or unset when we already have the session.
     def test_query_execution_listener_on_collect(self):
         self.assertFalse(
             self.spark._jvm.OnSuccessCall.isCalled(),
@@ -3116,11 +3117,11 @@ def test_query_execution_listener_on_collect_with_arrow(self):
             self.assertFalse(
                 self.spark._jvm.OnSuccessCall.isCalled(),
                 "The callback from the query execution listener should not be "
-                "called before 'collect'")
+                "called before 'toPandas'")
             self.spark.sql("SELECT * FROM range(1)").toPandas()
             self.assertTrue(
                 self.spark._jvm.OnSuccessCall.isCalled(),
-                "The callback from the query execution listener should be called after 'collect'")
+                "The callback from the query execution listener should be called after 'toPandas'")
         finally:
             if old_value is None:
                 self.spark.conf.unset("spark.sql.execution.arrow.enabled")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
index 39e5e014a8d0..46c2459f1451 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
@@ -33,7 +33,7 @@ class TestQueryExecutionListener extends QueryExecutionListener with Logging {
 }
 
 /**
- * This has variables to check if `onSuccess` is actually called or not. Currently, this is for
+ * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for
  * the test case in PySpark. See SPARK-23942.
  */
 object OnSuccessCall {

From deacb17816c21811efd630e91cee4c30d421eb36 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 11 Apr 2018 09:44:24 +0800
Subject: [PATCH 4/6] Address comments

---
 python/pyspark/sql/tests.py | 42 ++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d33aca6a24e7..59f9b14ccba1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -186,16 +186,12 @@ def __init__(self, key, value):
         self.value = value
 
 
-class ReusedSQLTestCase(ReusedPySparkTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.spark = SparkSession(cls.sc)
-
-    @classmethod
-    def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-        cls.spark.stop()
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this to have 'spark' attribute, having a spark session.
+    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
+    the implementation of this class has 'spark' attribute.
+    """
 
     @contextmanager
     def sql_conf(self, pairs):
@@ -204,6 +200,7 @@ def sql_conf(self, pairs):
         `value` to the configuration `key` and then restores it back when it exits.
         """
         assert isinstance(pairs, dict), "pairs should be a dictionary."
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
 
         keys = pairs.keys()
         new_values = pairs.values()
@@ -219,6 +216,18 @@ def sql_conf(self, pairs):
                 else:
                     self.spark.conf.set(key, old_value)
 
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
     def assertPandasEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
                "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3062,7 +3071,7 @@ def test_sparksession_with_stopped_sparkcontext(self):
         sc.stop()
 
 
-class SQLTests3(unittest.TestCase):
+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
     # static and immutable. This can't be set or unset.
 
@@ -3109,11 +3118,7 @@ def test_query_execution_listener_on_collect(self):
         not _have_pandas or not _have_pyarrow,
         _pandas_requirement_message or _pyarrow_requirement_message)
     def test_query_execution_listener_on_collect_with_arrow(self):
-        # Here, it deplicates codes in ReusedSQLTestCase.sql_conf context manager.
-        # Refactor and deduplicate it if there is another case like this.
-        old_value = self.spark.conf.get("spark.sql.execution.arrow.enabled", None)
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
             self.assertFalse(
                 self.spark._jvm.OnSuccessCall.isCalled(),
                 "The callback from the query execution listener should not be "
                 "called before 'toPandas'")
             self.spark.sql("SELECT * FROM range(1)").toPandas()
             self.assertTrue(
                 self.spark._jvm.OnSuccessCall.isCalled(),
                 "The callback from the query execution listener should be called after 'toPandas'")
-        finally:
-            if old_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.enabled")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.enabled", old_value)
 
 
 class SparkSessionTests(PySparkTestCase):
 
     # This test is separate because it's closely related with session's start and stop.

From 1a52dbe6275b0f7316e1b1928e1ec19e5f405028 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Wed, 11 Apr 2018 19:50:35 +0800
Subject: [PATCH 5/6] Address a comment and add few more words in comments

---
 python/pyspark/sql/tests.py | 2 +-
 .../scala/org/apache/spark/sql/TestQueryExecutionListener.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 59f9b14ccba1..873b6c187f9f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3073,7 +3073,7 @@ def test_sparksession_with_stopped_sparkcontext(self):
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
-    # static and immutable. This can't be set or unset.
+    # static and immutable. This can't be set or unset, for example, via `spark.conf`.
 
     @classmethod
     def setUpClass(cls):
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
index 46c2459f1451..5a0fbc6efb9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
 
-class TestQueryExecutionListener extends QueryExecutionListener with Logging {
+class TestQueryExecutionListener extends QueryExecutionListener {
   override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
     OnSuccessCall.isOnSuccessCalled.set(true)
   }

From 7c1b3c606d9b90dbffbaa7d7442bbd609940b98f Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 12 Apr 2018 21:28:04 +0800
Subject: [PATCH 6/6] D'oh

---
 .../scala/org/apache/spark/sql/TestQueryExecutionListener.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
index 5a0fbc6efb9a..d2a6358ee822 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
 
 import java.util.concurrent.atomic.AtomicBoolean
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.QueryExecutionListener
 
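
Usage sketch (illustrative only, not part of the patches above): with the test listener class compiled onto the driver classpath, a PySpark collect() or toPandas() now goes through Dataset.withAction, so a registered QueryExecutionListener fires and can be observed through the same JVM helper the tests use. The session setup below mirrors setUpClass in the tests; the classpath assumption is the same one the tests' skip-check guards.

    # Illustrative sketch only; assumes a Spark build where
    # org.apache.spark.sql.TestQueryExecutionListener is available on the
    # classpath (it lives under sql/core test classes, as setUpClass checks).
    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .master("local[4]") \
        .config("spark.sql.queryExecutionListeners",
                "org.apache.spark.sql.TestQueryExecutionListener") \
        .getOrCreate()

    # Before this series, DataFrame.collect() in PySpark used withNewExecutionId
    # directly and never triggered the listener's onSuccess callback.
    spark.sql("SELECT * FROM range(1)").collect()
    print(spark._jvm.OnSuccessCall.isCalled())  # True once onSuccess has fired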