1 change: 0 additions & 1 deletion python/pyspark/sql/tests.py
@@ -1263,7 +1263,6 @@ def test_access_column(self):
        self.assertTrue(isinstance(df['key'], Column))
        self.assertTrue(isinstance(df[0], Column))
        self.assertRaises(IndexError, lambda: df[2])
-       self.assertRaises(AnalysisException, lambda: df["bad_key"])
        self.assertRaises(TypeError, lambda: df[{}])

    def test_column_name_with_non_ascii(self):
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -556,31 +556,31 @@ class Analyzer(
right.collect {
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
(oldVersion, newVersion)

case oldVersion: SerializeFromObject
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))

// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

case oldVersion: Generate
if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
(oldVersion, oldVersion.copy(generatorOutput = newOutput))

case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
(oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
}
// Only handle first case, others will be fixed on the next pass.
@@ -597,11 +597,16 @@ class Analyzer(
val newRight = right transformUp {
case r if r == oldRelation => newRelation
} transformUp {
case other => other transformExpressions {
case a: Attribute =>
attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
}
case other =>
val transformed = other transformExpressions {
case a: Attribute =>
attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
}

transformed.setPlanId(other.planId)
transformed
}
newRight.setPlanId(right.planId)
Contributor:
I don't quite understand this. If it's a self-join and left and right are the same plan (same plan id), then after dedupRight, left and right are no longer the same plan but still have the same plan id, right? So how do we resolve an UnresolvedAttribute by plan id? I think it's still ambiguous.

Contributor:
Ah, I understand now. Different DataFrames must have different logical plans; the problem we are trying to fix is the indirect self-join. For an indirect self-join, the left and right here must be different plans with different plan ids.
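For reference, a minimal sketch of such an indirect self-join (hypothetical data and names, assuming spark.implicits._ is in scope; not taken from this PR):

val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("col1", "col2")
val filtered = df.filter("col1 != 3")  // derived Dataset: a new logical plan with a new plan id
val joined = filtered.join(df, filtered("col1") === df("col1"))
// filtered and df carry different plan ids, so df("col1") can be bound
// to the right side unambiguously.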

newRight
}
}
@@ -664,11 +669,18 @@ class Analyzer(

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) =>
case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
withPosition(u) {
targetPlanIdOpt match {
case Some(targetPlanId) =>
resolveExpressionFromSpecificLogicalPlan(nameParts, q, targetPlanId)
case None =>
q.resolveChildren(nameParts, resolver).getOrElse(u)
}
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -746,6 +758,19 @@ class Analyzer(
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}

private[sql] def resolveExpressionFromSpecificLogicalPlan(
nameParts: Seq[String],
planToSearchFrom: LogicalPlan,
targetPlanId: Long): Expression = {
lazy val name = UnresolvedAttribute(nameParts).name
planToSearchFrom.findByBreadthFirst(_.planId == targetPlanId) match {
case Some(foundPlan) =>
foundPlan.resolve(nameParts, resolver).get
case None =>
failAnalysis(s"Could not find $name in any logical plan.")
}
}

protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -757,8 +782,14 @@
try {
expr transformUp {
case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
withPosition(u) {
targetPlanIdOpt match {
case Some(targetPlanId) =>
resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
Contributor:
The resolved attribute must be among the output of the children — is that guaranteed here? resolveExpressionFromSpecificLogicalPlan picks a sub-tree, and its output may not be propagated all the way up.

Member Author:
Thank you for taking a look.
If a resolved attribute is not in the output of a child logical plan, even though it is in the output of some sub-tree, CheckAnalysis catches it and raises an AnalysisException, just as it did before this patch.
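For reference, a minimal sketch of that safety net (hypothetical names, assuming a DataFrame df with columns col1 and col2):

val pruned = df.select("col1")  // the projection drops col2
pruned.select(df("col2"))
// df("col2") resolves inside the df sub-tree, but col2 is not in the output
// of pruned, so CheckAnalysis still raises an AnalysisException.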

case None => plan.resolve(nameParts, resolver).getOrElse(u)
}
}
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -986,12 +1017,17 @@ class Analyzer(
plan transformDown {
case q: LogicalPlan if q.childrenResolved && !q.resolved =>
q transformExpressions {
case u @ UnresolvedAttribute(nameParts) =>
case u @ UnresolvedAttribute(nameParts, targetPlanIdOpt) =>
withPosition(u) {
try {
outer.resolve(nameParts, resolver) match {
case Some(outerAttr) => OuterReference(outerAttr)
case None => u
targetPlanIdOpt match {
case Some(targetPlanId) =>
resolveExpressionFromSpecificLogicalPlan(nameParts, outer, targetPlanId)
case None =>
outer.resolve(nameParts, resolver) match {
case Some(outerAttr) => OuterReference(outerAttr)
case None => u
}
}
} catch {
case _: AnalysisException => u
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -83,7 +83,9 @@ case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq
/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {
case class UnresolvedAttribute(
nameParts: Seq[String],
targetPlanIdOpt: Option[Long] = None) extends Attribute with Unevaluable {

def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
@@ -162,7 +164,7 @@ object UnresolvedAttribute {
}
if (inBacktick) throw e
nameParts += tmp.mkString
nameParts.toSeq
nameParts
}
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -156,7 +156,7 @@ object ExpressionEncoder {
} else {
val input = GetColumnByOrdinal(index, enc.schema)
val deserialized = enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
case UnresolvedAttribute(nameParts, _) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -220,7 +220,7 @@ case class AttributeReference(
extends Attribute with Unevaluable {

/**
* Returns true iff the expression id is the same for both attributes.
* Returns true if the expression id is the same for both attributes.
*/
def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1156,8 +1156,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
val attr = ctx.fieldName.getText
expression(ctx.base) match {
case UnresolvedAttribute(nameParts) =>
UnresolvedAttribute(nameParts :+ attr)
case UnresolvedAttribute(nameParts, targetPlanId) =>
UnresolvedAttribute(nameParts :+ attr, targetPlanId)
case e =>
UnresolvedExtractValue(e, Literal(attr))
}
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -26,11 +26,20 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.types.StructType

object LogicalPlan {
private val curId = new java.util.concurrent.atomic.AtomicLong()
def newPlanId: Long = curId.getAndIncrement()
}

abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

private var _analyzed: Boolean = false

// A logical plan keeps its planId even when the analyzer replaces the plan
// (e.g. while deduplicating expressions that share the same exprId), so the
// id can still be used to identify the plan after analysis.
private var _planId: Long = LogicalPlan.newPlanId

/**
* Marks this plan as already analyzed. This should only be called by CheckAnalysis.
*/
@@ -43,6 +52,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def analyzed: Boolean = _analyzed

private[catalyst] def setPlanId(planId: Long): Unit = { _planId = planId }

def planId: Long = _planId

/** Returns true if this subtree contains any streaming data sources. */
def isStreaming: Boolean = children.exists(_.isStreaming == true)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
import java.util.UUID

import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.commons.lang3.ClassUtils
@@ -110,6 +111,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
}

def findByBreadthFirst(f: BaseType => Boolean): Option[BaseType] = {
  val queue = new ArrayBuffer[BaseType]
  var foundOpt: Option[BaseType] = None
  queue.append(this)

  // Breadth-first search, so the shallowest node satisfying `f` is returned.
  while (queue.nonEmpty && foundOpt.isEmpty) {
    val currentNode = queue.remove(0)
    if (f(currentNode)) {
      foundOpt = Some(currentNode)
    } else {
      // Children are reversed before being enqueued, so within one level the
      // right-most child is examined first.
      currentNode.children.reverse.foreach(queue.append(_))
    }
  }
  foundOpt
}
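// Hypothetical usage, mirroring the analyzer changes above: find the
// shallowest sub-plan that carries a given plan id, e.g.
//   plan.findByBreadthFirst(_.planId == targetPlanId)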

/**
* Runs the given function on this node and then recursively on [[children]].
* @param f the function to be applied to each node in the tree.
sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -421,7 +421,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
val keyExpr = df.col(col.name).expr
def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
def buildExpr(v: Any) = Cast(Literal(v), col.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(buildExpr(source), buildExpr(target))
}.toSeq
38 changes: 29 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1067,12 +1067,11 @@ class Dataset[T] private[sql](
* @group untypedrel
* @since 2.0.0
*/
def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
val expr = resolve(colName)
Column(expr)
def col(colName: String): Column = withStarResolved(colName) {
val expr = UnresolvedAttribute(
UnresolvedAttribute.parseAttributeName(colName),
Some(queryExecution.analyzed.planId))
Column(expr)
}
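// Note: df("c") now returns an UnresolvedAttribute tagged with the plan id of
// this Dataset's analyzed plan instead of resolving eagerly; the analyzer
// later binds it to the matching sub-plan via
// resolveExpressionFromSpecificLogicalPlan.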

/**
@@ -1949,9 +1948,17 @@
*/
def drop(col: Column): DataFrame = {
val expression = col match {
case Column(u: UnresolvedAttribute) =>
queryExecution.analyzed.resolveQuoted(
u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
case Column(u @ UnresolvedAttribute(nameParts, targetPlanIdOpt)) =>
val plan = queryExecution.analyzed
val analyzer = sparkSession.sessionState.analyzer
val resolver = analyzer.resolver

targetPlanIdOpt match {
case Some(targetPlanId) =>
analyzer.resolveExpressionFromSpecificLogicalPlan(nameParts, plan, targetPlanId)
case None =>
plan.resolveQuoted(u.name, resolver).getOrElse(u)
}
case Column(expr: Expression) => expr
}
val attrs = this.logicalPlan.output
@@ -2786,6 +2793,19 @@ class Dataset[T] private[sql](
}
}

/** Another version of `col` that resolves the expression immediately.
* Mainly intended for tests, for example when passing columns to a SparkPlan.
*/
private[sql] def colInternal(colName: String): Column = withStarResolved(colName) {
val expr = resolve(colName)
Column(expr)
}

private def withStarResolved(colName: String)(f: => Column): Column = colName match {
case "*" => Column(ResolvedStar(queryExecution.analyzed.output))
case _ => f
}

/** A convenient function to wrap a logical plan and produce a DataFrame. */
@inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
Dataset.ofRows(sparkSession, logicalPlan)
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -411,7 +411,9 @@ class RelationalGroupedDataset protected[sql](
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
outputSchema: StructType): DataFrame = {
val groupingNamedExpressions = groupingExprs.map(alias)
val groupingNamedExpressions = groupingExprs
.map(df.sparkSession.sessionState.analyzer.resolveExpression(_, df.logicalPlan))
.map(alias)
val groupingCols = groupingNamedExpressions.map(Column(_))
val groupingDataFrame = df.select(groupingCols : _*)
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -627,7 +627,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(id, name, age, salary)
}.toSeq)
assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
assert(df("id") == person("id"))
val dfAnalyzer = df.sparkSession.sessionState.analyzer
val personAnalyzer = person.sparkSession.sessionState.analyzer
assert(dfAnalyzer.resolveExpression(df("id").expr, df.queryExecution.analyzed) ==
personAnalyzer.resolveExpression(person("id").expr, person.queryExecution.analyzed))
}

test("drop top level columns that contains dot") {
@@ -1601,6 +1604,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}

test("""SPARK-17154: df("column_name") should return correct result when we do self-join""") {
Member:
What happens when:

val joined = df.join(df, "inner")  // columns: col1, col2, col3, col1, col2, col3
val selected = joined.select(df("col1"))

As there are two plans with the same plan id, the breadth-first search will pick one of them, so df("col1") will be resolved. However, I think in this case we should raise an error about the ambiguity instead.

Contributor:
This is a good question!

I'm also thinking about this. If a plan id matches more than one sub-tree in the logical plan, should we just fail the query instead of using BFS to pick the first one?

Member Author:
Yeah, a direct self-join (where both child Datasets are the same) is still ambiguous.
In that case, df("column_name") refers to the right-side Dataset in the proposed implementation.

I wonder whether a direct self-join like df.join(df, ..., ...) is analogous to a query like the following:

SELECT ... FROM my_table df JOIN my_table df ON ...;

Such queries should not be valid, so I also think we shouldn't allow users to join two identical Datasets, and should warn them to duplicate the Dataset if they intend a direct self-join.

Member Author:
I'm also thinking about this. If a plan id matches more than one sub-tree in the logical plan, should we just fail the query instead of using BFS to pick the first one?

If the logical plan on the right side is copied by dedupRight, there will be multiple logical plans with the same planId, so it may be better to fail the query in the case of a direct self-join.

Member:
Although I can't immediately think of an actual use case for self-joining two identical Datasets, I'm still wondering whether we want to disallow it. Conceptually it should work: even if you can't select columns from it due to the ambiguity, you can still save it or apply other operators to it.

Contributor:
I don't think we should support a self-join of the same Dataset/DataFrame under the same name. That is,

df.join(df)

should be blocked. We can ask the user to express it as df.join(df.as("df2")), which is clearer. We certainly must not support

df.join(df, df("col1") === df("col2"))

which blindly binds "col1" and "col2" to the first df. @sarutak's solution does change that behaviour to an error.
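For reference, a minimal sketch of the aliasing workaround suggested above (hypothetical, assuming import org.apache.spark.sql.functions.col):

val left = df.as("l")
val right = df.as("r")
// The alias nodes are distinct plans, so each side's columns can be
// referenced unambiguously through their qualifiers:
left.join(right, col("l.col1") === col("r.col1"))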

val df = Seq(
(1, "a", "A"),
(2, "b", "B"),
(3, "c", "C"),
(4, "d", "D"),
(5, "e", "E")).toDF("col1", "col2", "col3")
val filtered = df.filter("col1 != 3").select("col1", "col2")
val joined = filtered.join(df, filtered("col1") === df("col1"), "inner")
val selected1 = joined.select(df("col3"))

checkAnswer(selected1, Row("A") :: Row("B") :: Row("D") :: Row("E") :: Nil)

val rightOuterJoined = filtered.join(df, filtered("col1") === df("col1"), "right")
val selected2 = rightOuterJoined.select(df("col1"))

checkAnswer(selected2, Row(1) :: Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)

val selected3 = rightOuterJoined.select(filtered("col1"))
checkAnswer(selected3, Row(1) :: Row(2) :: Row(null) :: Row(4) :: Row(5) :: Nil)
}

test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") {
withTable("bar") {
withTempView("foo") {
sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -242,7 +242,7 @@ object SparkPlanTest {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
case UnresolvedAttribute(Seq(u), _) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}