Commit 54d2359

[SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods
## What changes were proposed in this pull request?

We only notify `QueryExecutionListener` for several `Dataset` operations, e.g. collect, take, etc. We should also do the notification for `DataFrameWriter` operations.

## How was this patch tested?

new regression test

close #16664

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16962 from cloud-fan/insert.
1 parent 21fde57 commit 54d2359
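In practice, this means a registered `QueryExecutionListener` now also fires for writer actions such as `save`, `insertInto`, and `saveAsTable`. A minimal sketch of how user code would observe that (the app name and output path are illustrative, not from this commit):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("listener-demo").getOrCreate()

    // Record the action name and its logical plan on success, or the failure cause.
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"$funcName took ${durationNs / 1e6} ms, plan: ${qe.logical.getClass.getSimpleName}")
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    }
    spark.listenerManager.register(listener)

    // With this patch, the write below notifies the listener with funcName = "save"
    // and a SaveIntoDataSourceCommand as the logical plan.
    spark.range(10).write.format("json").save("/tmp/listener-demo") // path is illustrative

    spark.stop()
  }
}
```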

3 files changed: +142, -16 lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 35 additions & 14 deletions
@@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 
@@ -211,13 +211,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     assertNotBucketed("save")
-    val dataSource = DataSource(
-      df.sparkSession,
-      className = source,
-      partitionColumns = partitioningColumns.getOrElse(Nil),
-      options = extraOptions.toMap)
 
-    dataSource.write(mode, df)
+    runCommand(df.sparkSession, "save") {
+      SaveIntoDataSourceCommand(
+        query = df.logicalPlan,
+        provider = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap,
+        mode = mode)
+    }
   }
 
   /**
@@ -260,13 +262,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       )
     }
 
-    df.sparkSession.sessionState.executePlan(
+    runCommand(df.sparkSession, "insertInto") {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
         query = df.logicalPlan,
         overwrite = mode == SaveMode.Overwrite,
-        ifNotExists = false)).toRdd
+        ifNotExists = false)
+    }
   }
 
   private def getBucketSpec: Option[BucketSpec] = {
@@ -389,10 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       schema = new StructType,
       provider = Some(source),
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
-      bucketSpec = getBucketSpec
-    )
-    df.sparkSession.sessionState.executePlan(
-      CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+      bucketSpec = getBucketSpec)
+
+    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
   }
 
   /**
@@ -573,6 +575,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     format("csv").save(path)
   }
 
+  /**
+   * Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
+   * user-registered callback functions.
+   */
+  private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
+    val qe = session.sessionState.executePlan(command)
+    try {
+      val start = System.nanoTime()
+      // call `QueryExecution.toRdd` to trigger the execution of commands.
+      qe.toRdd
+      val end = System.nanoTime()
+      session.listenerManager.onSuccess(name, qe, end - start)
+    } catch {
+      case e: Exception =>
+        session.listenerManager.onFailure(name, qe, e)
+        throw e
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+/**
+ * Saves the results of `query` in to a data source.
+ *
+ * Note that this command is different from [[InsertIntoDataSourceCommand]]. This command will call
+ * `CreatableRelationProvider.createRelation` to write out the data, while
+ * [[InsertIntoDataSourceCommand]] calls `InsertableRelation.insert`. Ideally these 2 data source
+ * interfaces should do the same thing, but as we've already published these 2 interfaces and the
+ * implementations may have different logic, we have to keep these 2 different commands.
+ */
+case class SaveIntoDataSourceCommand(
+    query: LogicalPlan,
+    provider: String,
+    partitionColumns: Seq[String],
+    options: Map[String, String],
+    mode: SaveMode) extends RunnableCommand {
+
+  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    DataSource(
+      sparkSession,
+      className = provider,
+      partitionColumns = partitionColumns,
+      options = options).write(mode, Dataset.ofRows(sparkSession, query))
+
+    Seq.empty[Row]
+  }
+}
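
For reference, the two sources-API hooks contrasted in the scaladoc above have roughly the following shape in Spark 2.x. This is a simplified sketch of the existing `CreatableRelationProvider` and `InsertableRelation` traits, not code from this commit:

```scala
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.BaseRelation

// SaveIntoDataSourceCommand writes via CreatableRelationProvider.createRelation:
// the provider decides how to create or overwrite the destination for the given mode and options.
trait CreatableRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation
}

// InsertIntoDataSourceCommand instead calls insert on an already-resolved relation.
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}
```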

sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala

Lines changed: 55 additions & 2 deletions
@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     spark.listenerManager.unregister(listener)
   }
+
+  test("execute callback functions for DataFrameWriter") {
+    val commands = ArrayBuffer.empty[(String, LogicalPlan)]
+    val exceptions = ArrayBuffer.empty[(String, Exception)]
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        exceptions += funcName -> exception
+      }
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        commands += funcName -> qe.logical
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    withTempPath { path =>
+      spark.range(10).write.format("json").save(path.getCanonicalPath)
+      assert(commands.length == 1)
+      assert(commands.head._1 == "save")
+      assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
+      assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      spark.range(10).write.insertInto("tab")
+      assert(commands.length == 2)
+      assert(commands(1)._1 == "insertInto")
+      assert(commands(1)._2.isInstanceOf[InsertIntoTable])
+      assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+        .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
+    }
+
+    withTable("tab") {
+      spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      assert(commands.length == 3)
+      assert(commands(2)._1 == "saveAsTable")
+      assert(commands(2)._2.isInstanceOf[CreateTable])
+      assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      val e = intercept[AnalysisException] {
+        spark.range(10).select($"id", $"id").write.insertInto("tab")
+      }
+      assert(exceptions.length == 1)
+      assert(exceptions.head._1 == "insertInto")
+      assert(exceptions.head._2 == e)
+    }
+  }
 }
