@@ -150,8 +150,8 @@ class Analyzer(
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
   private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
-    val set = new OpenHashSet[Expression](2)
+    : ExpressionSet = {
+    val set = new ExpressionSet()

     var bit = exprs.length - 1
     while (bit >= 0) {
@@ -84,16 +84,17 @@ trait CheckAnalysis {
             s"of type ${f.condition.dataType.simpleString} is not a boolean.")

           case Aggregate(groupingExprs, aggregateExprs, child) =>
+            val normalizedGroupingExprs = ExpressionSet(groupingExprs)
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case _: AggregateExpression => // OK
-              case e: Attribute if !groupingExprs.contains(e) =>
+              case e if normalizedGroupingExprs.contains(e) => // OK
+              case e if e.children.size > 0 => e.children.foreach(checkValidAggregateExpression)
+              case e: NamedExpression =>
                 failAnalysis(
-                  s"expression '${e.prettyString}' is neither present in the group by, " +
-                    s"nor is it an aggregate function. " +
-                    "Add to group by or wrap in first() if you don't care which value you get.")
-              case e if groupingExprs.contains(e) => // OK
-              case e if e.references.isEmpty => // OK
-              case e => e.children.foreach(checkValidAggregateExpression)
+                  s"""expression '${e.prettyString}' is neither present in the group by,
+                    nor is it an aggregate function.
+                    Add to group by or wrap in first() if you don't care which value you get.""")
+              case _ => // OK, e.g. Literal
             }

             val cleaned = aggregateExprs.map(_.transform {
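
[Editor's note] The new check normalizes the grouping expressions once into an ExpressionSet, then walks each aggregate expression recursively: aggregate functions pass, anything that matches a normalized grouping expression passes, composite expressions recurse into their children, and only a named leaf expression that matches nothing fails. A minimal standalone sketch of that recursion follows; it is plain Scala, not Catalyst, and the names Expr, Attr, Lit, Add, and AggCall are hypothetical stand-ins for the real Expression hierarchy.

// Toy expression tree; Catalyst's Expression exposes `children` the same way.
sealed trait Expr { def children: Seq[Expr] }
case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }
case class Lit(value: Int) extends Expr { val children: Seq[Expr] = Nil }
case class Add(left: Expr, right: Expr) extends Expr { val children: Seq[Expr] = Seq(left, right) }
case class AggCall(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

// Analogous to ExpressionEquals.normalize: erase cosmetic differences
// (here, only attribute-name capitalization) before comparing.
def normalize(e: Expr): Expr = e match {
  case Attr(n)    => Attr(n.toLowerCase)
  case Add(l, r)  => Add(normalize(l), normalize(r))
  case AggCall(c) => AggCall(normalize(c))
  case other      => other
}

// Mirrors checkValidAggregateExpression: aggregates pass, grouping matches pass,
// composites recurse, and an unmatched attribute fails the analysis.
def checkValid(expr: Expr, grouping: Set[Expr]): Unit = expr match {
  case _: AggCall                           => // OK
  case e if grouping.contains(normalize(e)) => // OK
  case e if e.children.nonEmpty             => e.children.foreach(checkValid(_, grouping))
  case a: Attr                              => sys.error(s"'${a.name}' is neither grouped nor aggregated")
  case _                                    => // OK, e.g. literals
}

val grouping = Set(normalize(Add(Attr("key"), Lit(3))))
checkValid(Add(Attr("kEy"), Lit(3)), grouping)    // passes: kEy + 3 matches key + 3
// checkValid(Add(Attr("kEy"), Lit(1)), grouping) // would fail: kEy + 1 matches nothing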
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

/**
 * Builds a map that is keyed by a normalized expression. Normalizing the key allows values
 * to be looked up even when the attributes used differ cosmetically (e.g., in the
 * capitalization of the name, or in the expected nullability).
*/
sealed class ExpressionMap[A] extends Serializable {
private val baseMap = new collection.mutable.HashMap[Expression, A]()
def get(k: Expression): Option[A] = baseMap.get(ExpressionEquals.normalize(k))

def add(k: Expression, value: A): Unit = {
baseMap.put(ExpressionEquals.normalize(k), value)
}

def values: Iterable[A] = baseMap.values
}
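
[Editor's note] A hedged usage sketch of ExpressionMap, assuming the era's Catalyst API (the AttributeReference constructor shape is taken from ExpressionEquals below). The two references share an exprId, as they would after resolution, and differ only in capitalization:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()
// Same underlying attribute (same exprId), cosmetically different name.
val kEy = AttributeReference("kEy", IntegerType)(key.exprId, key.qualifiers)

val groupingMap = new ExpressionMap[String]()
groupingMap.add(key, "grouping column")
assert(groupingMap.get(kEy) == Some("grouping column")) // found despite the case difference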
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

private[expressions] object ExpressionEquals {
def normalize(expr: Expression): Expression = expr.transformUp {
case n: AttributeReference =>
// We don't care about the name of AttributeReference in its semantic equality check
new AttributeReference(null, n.dataType, n.nullable, n.metadata)(n.exprId, n.qualifiers)
}
}

object ExpressionSet {
def apply(exprs: Iterable[Expression]): ExpressionSet = {
val set = new ExpressionSet()
exprs.foreach(e => set.add(e))

set
}
}

/**
 * Builds an expression set in which membership checks succeed even when the attributes
 * used differ cosmetically (e.g., in the capitalization of the name, or in the expected
 * nullability).
*/
sealed class ExpressionSet extends Serializable {
private val baseSet: java.util.Set[Expression] = new java.util.HashSet[Expression]()
def contains(expr: Expression): Boolean = {
baseSet.contains(ExpressionEquals.normalize(expr))
}
def add(expr: Expression): Unit = baseSet.add(ExpressionEquals.normalize(expr))
}
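
[Editor's note] The set behaves the same way for compound expressions, which is what the CheckAnalysis change relies on. A sketch under the same assumptions as the ExpressionMap example above (shared exprId, era-appropriate Catalyst API):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()
val kEy = AttributeReference("kEy", IntegerType)(key.exprId, key.qualifiers)

// GROUP BY key + 3, probed with the cosmetically different kEy + 3.
val groupingSet = ExpressionSet(Seq(Add(key, Literal(3))))
assert(groupingSet.contains(Add(kEy, Literal(3))))
assert(!groupingSet.contains(Add(key, Literal(1)))) // a genuinely different expression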
@@ -143,11 +143,22 @@ object PartialAggregation {
       // We need to pass all grouping expressions though so the grouping can happen a second
       // time. However some of them might be unnamed so we alias them allowing them to be
       // referenced in the second aggregation.
-      val namedGroupingExpressions: Map[Expression, NamedExpression] =
-        groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
-          case n: NamedExpression => (n, n)
-          case other => (other, Alias(other, "PartialGroup")())
-        }.toMap
+      val namedGroupingMap = new ExpressionMap[NamedExpression]()
+      // Output: (raw expression, a named expression, its associated attribute)
+      val namedGroupingTuples = groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
+        case n: NamedExpression =>
+          (n: Expression, n, n.toAttribute)
+        case other =>
+          val v = Alias(other, "PartialGroup")()
+          (other, v, v.toAttribute)
+      }
+
+      val partialGroupingExprs = namedGroupingTuples.map(_._2)
+      val namedGroupingAttributes = namedGroupingTuples.map(_._3)
+      // Construct the expression map for substitution in the final aggregate expressions.
+      namedGroupingTuples.foreach { case (expr, namedExpr, attr) =>
+        namedGroupingMap.add(expr, namedExpr)
+      }

       // Replace aggregations with a new expression that computes the result from the already
       // computed partial evaluations and grouping values.
@@ -159,18 +170,16 @@ object PartialAggregation {
           // Should trim aliases around `GetField`s. These aliases are introduced while
           // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
           // (Should we just turn `GetField` into a `NamedExpression`?)

[Inline review comment from the contributor author: namedGroupingExpressions.values probably comes in arbitrary order, which is not right compared to the groupingExpressions.]

-          namedGroupingExpressions
+          namedGroupingMap
             .get(e.transform { case Alias(g: ExtractValue, _) => g })
             .map(_.toAttribute)
             .getOrElse(e)
       }).asInstanceOf[Seq[NamedExpression]]

       val partialComputation =
-        (namedGroupingExpressions.values ++
+        (partialGroupingExprs ++
           partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq

-      val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
-
       Some(
         (namedGroupingAttributes,
          rewrittenAggregateExpressions,
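
[Editor's note] The inline review comment above is the motivation for threading namedGroupingTuples through instead of reading namedGroupingExpressions.values: a Scala immutable Map makes no ordering guarantee once it grows past the small specialized cases, while mapping over the original sequence preserves the grouping order. A plain-Scala illustration (the column names here are made up):

val grouping = Seq("c1", "c2", "c3", "c4", "c5")

// Through a Map: the order of `values` is unspecified (an immutable HashMap at 5+ entries).
val asMap = grouping.map(g => g -> s"alias_$g").toMap
println(asMap.values.toSeq) // e.g. List(alias_c5, alias_c1, ...) -- arbitrary

// Through tuples, as the new code does: the original grouping order is preserved.
val tuples = grouping.map(g => (g, s"alias_$g"))
println(tuples.map(_._2))   // List(alias_c1, alias_c2, alias_c3, alias_c4, alias_c5)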
@@ -442,6 +442,24 @@ class SQLQuerySuite extends QueryTest {
       sql("SELECT `key` FROM src").collect().toSeq)
   }

+  test("SPARK-7269 Check analysis failed in case insensitive") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
+
   test("SPARK-3834 Backticks not correctly handled in subquery aliases") {
     checkAnswer(
       sql("SELECT a.key FROM (SELECT key FROM src) `a`"),