sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
       if (writer.isPresent) {
         runCommand(df.sparkSession, "save") {
-          WriteToDataSourceV2(writer.get(), df.logicalPlan)
+          WriteToDataSourceV2(writer.get(), df.planWithBarrier)
Member Author: This change is not needed, but it is safe to have.

         }
       }
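
For context, df.planWithBarrier in this version of Spark is the Dataset's analyzed logical plan wrapped in an AnalysisBarrier, so the analyzer does not re-resolve an already-analyzed subtree. A minimal sketch of the field, assuming Spark 2.3-style Dataset internals rather than verbatim source:

    // Inside Dataset[T] (sketch): wrap the already-analyzed plan so the
    // analyzer skips re-resolution of this subtree when the Dataset is reused.
    @transient private[sql] val planWithBarrier: LogicalPlan = AnalysisBarrier(logicalPlan)

The remaining hunks in this file apply the same logicalPlan-to-planWithBarrier substitution to the insertInto and saveAsTable paths.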

@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       sparkSession = df.sparkSession,
       className = source,
       partitionColumns = partitioningColumns.getOrElse(Nil),
-      options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
+      options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
     }
   }

@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     InsertIntoTable(
       table = UnresolvedRelation(tableIdent),
       partition = Map.empty[String, Option[String]],
-      query = df.logicalPlan,
+      query = df.planWithBarrier,
       overwrite = mode == SaveMode.Overwrite,
       ifPartitionNotExists = false)
   }
@@ -459,7 +459,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
       bucketSpec = getBucketSpec)

-    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+    runCommand(df.sparkSession, "saveAsTable") {
+      CreateTable(tableDesc, mode, Some(df.planWithBarrier))
+    }
   }

   /**
sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -891,8 +891,9 @@ object DDLUtils {
    * Throws an exception if outputPath tries to overwrite inputPath.
    */
   def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
-    val inputPaths = query.collect {
-      case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
+    val inputPaths = EliminateBarriers(query).collect {
Member Author: AnalysisBarrier is a leaf node. That is one of the reasons why it could easily break other code: traversals such as collect do not look inside it, so the barrier has to be eliminated first (see the sketch after this hunk).

+      case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
+        r.location.rootPaths
     }.flatten

     if (inputPaths.contains(outputPath)) {
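
To make the leaf-node point concrete, here is a minimal sketch of the two classes involved, assuming Spark 2.3-era catalyst internals rather than verbatim source:

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
    import org.apache.spark.sql.catalyst.rules.Rule

    // AnalysisBarrier extends LeafNode, so children is Nil and traversals such
    // as collect and transform never descend into the wrapped child plan.
    case class AnalysisBarrier(child: LogicalPlan) extends LeafNode {
      override def output: Seq[Attribute] = child.output
      override def isStreaming: Boolean = child.isStreaming
    }

    // EliminateBarriers unwraps the barriers, which is why verifyNotReadPath
    // above applies it before collecting LogicalRelations from the query plan.
    object EliminateBarriers extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
        case AnalysisBarrier(child) => child
      }
    }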
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala (41 additions, 1 deletion)

@@ -19,11 +19,16 @@ package org.apache.spark.sql

 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.util.QueryExecutionListener


 private case class FunctionResult(f1: String, f2: String)

@@ -325,6 +330,41 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     }
   }

test("cached Data should be used in the write path") {
withTable("t") {
withTempPath { path =>
var numTotalCachedHit = 0
val listener = new QueryExecutionListener {
override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}

override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
qe.withCachedData match {
case c: CreateDataSourceTableAsSelectCommand
if c.query.isInstanceOf[InMemoryRelation] =>
numTotalCachedHit += 1
case i: InsertIntoHadoopFsRelationCommand
if i.query.isInstanceOf[InMemoryRelation] =>
numTotalCachedHit += 1
case _ =>
}
}
}
spark.listenerManager.register(listener)

val udf1 = udf({ (x: Int, y: Int) => x + y })
val df = spark.range(0, 3).toDF("a")
.withColumn("b", udf1($"a", lit(10)))
df.cache()
df.write.saveAsTable("t")
assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
df.write.insertInto("t")
assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
df.write.save(path.getCanonicalPath)
assert(numTotalCachedHit == 3, "expected to be cached in save for native")
}
}
}

test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
val udf1 = udf({(x: Int, y: Int) => x + y})
val df = spark.range(0, 3).toDF("a")
Expand Down