Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/sql-migration-guide-upgrade.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ displayTitle: Spark SQL Upgrading Guide

- Since Spark 2.4.5, `TRUNCATE TABLE` command tries to set back original permission and ACLs during re-creating the table/partition paths. To restore the behaviour of earlier versions, set `spark.sql.truncateTable.ignorePermissionAcl.enabled` to `true`.

- Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`.
- Since Spark 2.4.5, `spark.sql.legacy.mssqlserver.numericMapping.enabled` configuration is added in order to support the legacy MsSQLServer dialect mapping behavior using IntegerType and DoubleType for SMALLINT and REAL JDBC types, respectively. To restore the behaviour of 2.4.3 and earlier versions, set `spark.sql.legacy.mssqlserver.numericMapping.enabled` to `true`.

- Since Spark 2.4.5, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.

## Upgrading from Spark SQL 2.4.3 to 2.4.4

- Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used.
- Since Spark 2.4.4, according to [MsSqlServer Guide](https://docs.microsoft.com/en-us/sql/connect/jdbc/using-basic-data-types?view=sql-server-2017), MsSQLServer JDBC Dialect uses ShortType and FloatType for SMALLINT and REAL, respectively. Previously, IntegerType and DoubleType is used.

## Upgrading from Spark SQL 2.4 to 2.4.1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ class Analyzer(
gid: Expression): Expression = {
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
if (e.groupByExprs.isEmpty ||
e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) {
Alias(gid, toPrettySQL(e))()
} else {
throw new AnalysisException(
Expand Down Expand Up @@ -936,6 +937,8 @@ class Analyzer(
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// intersect/except will be rewritten to join at the begininng of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ case class AttributeReference(
}
}

override def withMetadata(newMetadata: Metadata): Attribute = {
override def withMetadata(newMetadata: Metadata): AttributeReference = {
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,11 @@ case class OneRowRelation() extends LeafNode {
override def computeStats(): Statistics = Statistics(sizeInBytes = 1)

/** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
val newCopy = OneRowRelation()
newCopy.copyTagsFrom(this)
newCopy
}
}

/** A logical plan for `dropDuplicates`. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees

import java.util.UUID

import scala.collection.Map
import scala.collection.{mutable, Map}
import scala.reflect.ClassTag

import org.apache.commons.lang3.ClassUtils
Expand Down Expand Up @@ -71,13 +71,34 @@ object CurrentOrigin {
}
}

// A tag of a `TreeNode`, which defines name and type
case class TreeNodeTag[T](name: String)

// scalastyle:off
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// scalastyle:on
self: BaseType =>

val origin: Origin = CurrentOrigin.get

/**
* A mutable map for holding auxiliary information of this tree node. It will be carried over
* when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
*/
private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty

protected def copyTagsFrom(other: BaseType): Unit = {
tags ++= other.tags
}

def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
tags(tag) = value
}

def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
tags.get(tag).map(_.asInstanceOf[T])
}

/**
* Returns a Seq of the children of this node.
* Children should not change. Immutability required for containsChild optimization
Expand Down Expand Up @@ -262,6 +283,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
if (this fastEquals afterRule) {
mapChildren(_.transformDown(rule))
} else {
// If the transform function replaces this node with a new one, carry over the tags.
afterRule.tags ++= this.tags
afterRule.mapChildren(_.transformDown(rule))
}
}
Expand All @@ -275,7 +298,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = mapChildren(_.transformUp(rule))
if (this fastEquals afterRuleOnChildren) {
val newNode = if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
Expand All @@ -284,6 +307,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
}
// If the transform function replaces this node with a new one, carry over the tags.
newNode.tags ++= this.tags
newNode
}

/**
Expand Down Expand Up @@ -402,7 +428,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

try {
CurrentOrigin.withOrigin(origin) {
defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
res.copyTagsFrom(this)
res
}
} catch {
case e: java.lang.IllegalArgumentException =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val FAIL_AMBIGUOUS_SELF_JOIN =
buildConf("spark.sql.analyzer.failAmbiguousSelfJoin")
.doc("When true, fail the Dataset query if it contains ambiguous self-join.")
.internal()
.booleanConf
.createWithDefault(true)

// Whether to retain group by columns or not in GroupedData.agg.
val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,4 +617,57 @@ class TreeNodeSuite extends SparkFunSuite {
val expected = Coalesce(Stream(Literal(1), Literal(3)))
assert(result === expected)
}

test("tags will be carried over after copy & transform") {
val tag = TreeNodeTag[String]("test")

withClue("makeCopy") {
val node = Dummy(None)
node.setTagValue(tag, "a")
val copied = node.makeCopy(Array(Some(Literal(1))))
assert(copied.getTagValue(tag) == Some("a"))
}

def checkTransform(
sameTypeTransform: Expression => Expression,
differentTypeTransform: Expression => Expression): Unit = {
val child = Dummy(None)
child.setTagValue(tag, "child")
val node = Dummy(Some(child))
node.setTagValue(tag, "parent")

val transformed = sameTypeTransform(node)
// Both the child and parent keep the tags
assert(transformed.getTagValue(tag) == Some("parent"))
assert(transformed.children.head.getTagValue(tag) == Some("child"))

val transformed2 = differentTypeTransform(node)
// Both the child and parent keep the tags, even if we transform the node to a new one of
// different type.
assert(transformed2.getTagValue(tag) == Some("parent"))
assert(transformed2.children.head.getTagValue(tag) == Some("child"))
}

withClue("transformDown") {
checkTransform(
sameTypeTransform = _ transformDown {
case Dummy(None) => Dummy(Some(Literal(1)))
},
differentTypeTransform = _ transformDown {
case Dummy(None) => Literal(1)

})
}

withClue("transformUp") {
checkTransform(
sameTypeTransform = _ transformUp {
case Dummy(None) => Dummy(Some(Literal(1)))
},
differentTypeTransform = _ transformUp {
case Dummy(None) => Literal(1)

})
}
}
}
19 changes: 16 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ private[sql] object Column {
case expr => toPrettySQL(expr)
}
}

private[sql] def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = {
val metadataWithoutId = new MetadataBuilder()
.withMetadata(a.metadata)
.remove(Dataset.DATASET_ID_KEY)
.remove(Dataset.COL_POS_KEY)
.build()
a.withMetadata(metadataWithoutId)
}
}

/**
Expand Down Expand Up @@ -141,11 +150,15 @@ class Column(val expr: Expression) extends Logging {
override def toString: String = toPrettySQL(expr)

override def equals(that: Any): Boolean = that match {
case that: Column => that.expr.equals(this.expr)
case that: Column => that.normalizedExpr() == this.normalizedExpr()
case _ => false
}

override def hashCode: Int = this.expr.hashCode()
override def hashCode: Int = this.normalizedExpr().hashCode()

private def normalizedExpr(): Expression = expr transform {
case a: AttributeReference => Column.stripColumnReferenceMetadata(a)
}

/** Creates a column based on the given expression. */
private def withExpr(newExpr: Expression): Column = new Column(newExpr)
Expand Down Expand Up @@ -1023,7 +1036,7 @@ class Column(val expr: Expression) extends Logging {
* @since 2.0.0
*/
def name(alias: String): Column = withExpr {
expr match {
normalizedExpr() match {
case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
case other => Alias(other, alias)()
}
Expand Down
40 changes: 36 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
Expand All @@ -60,6 +62,11 @@ import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils

private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")

def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
// Eagerly bind the encoder so we verify that the encoder matches the underlying
Expand Down Expand Up @@ -173,6 +180,9 @@ class Dataset[T] private[sql](
encoder: Encoder[T])
extends Serializable {

// A globally unique id of this Dataset.
private val id = Dataset.curId.getAndIncrement()

queryExecution.assertAnalyzed()

// Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
Expand All @@ -189,14 +199,18 @@ class Dataset[T] private[sql](
@transient private[sql] val logicalPlan: LogicalPlan = {
// For various commands (like DDL) and queries with side effects, we force query execution
// to happen right away to let these side effects take place eagerly.
queryExecution.analyzed match {
val plan = queryExecution.analyzed match {
case c: Command =>
LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect()))
case u @ Union(children) if children.forall(_.isInstanceOf[Command]) =>
LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect()))
case _ =>
queryExecution.analyzed
}
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN)) {
plan.setTagValue(Dataset.DATASET_ID_TAG, id)
}
plan
}

/**
Expand Down Expand Up @@ -1271,11 +1285,29 @@ class Dataset[T] private[sql](
if (sqlContext.conf.supportQuotedRegexColumnName) {
colRegex(colName)
} else {
val expr = resolve(colName)
Column(expr)
Column(addDataFrameIdToCol(resolve(colName)))
}
}

// Attach the dataset id and column position to the column reference, so that we can detect
// ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
// This must be called before we return a `Column` that contains `AttributeReference`.
// Note that, the metadata added here are only avaiable in the analyzer, as the analyzer rule
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
case a: AttributeReference
if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, id)
.putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
.build()
a.withMetadata(metadata)
}
newExpr.asInstanceOf[NamedExpression]
}

/**
* Selects column based on the column name specified as a regex and returns it as [[Column]].
* @group untypedrel
Expand All @@ -1289,7 +1321,7 @@ class Dataset[T] private[sql](
case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
case _ =>
Column(resolve(colName))
Column(addDataFrameIdToCol(resolve(colName)))
}
}

Expand Down
Loading