Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3977,6 +3977,8 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
val children = e.children.grouped(2).flatMap {
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
Seq(Literal(e.name), e)
case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined =>
Seq(Literal(e.name.get), e)
case kv =>
kv
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ object ExtractValue {
}
}

trait ExtractValue extends Expression
trait ExtractValue extends Expression {
// The name that is used to extract the value.
def name: Option[String]
}

/**
* Returns the value of fields in the Struct `child`.
Expand Down Expand Up @@ -156,6 +159,7 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
override def name: Option[String] = Some(field.name)

protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
Expand Down Expand Up @@ -233,6 +237,7 @@ case class GetArrayItem(

override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
override def name: Option[String] = None

override def left: Expression = child
override def right: Expression = ordinal
Expand Down Expand Up @@ -448,6 +453,10 @@ case class GetMapValue(

override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
override def name: Option[String] = key match {
case NonNullLiteral(s, StringType) => Some(s.toString)
case _ => None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, do we have a valid case for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. col[1] to get the map value. It's not a single UnresolvedAttribute (multi-part name like a.b.c) and is unrelated to this bug fix

}

override def left: Expression = child
override def right: Expression = key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,25 @@ class DataFrameAggregateSuite extends QueryTest
val df = spark.sql(query)
checkAnswer(df, Row(0, "0", 0, 0) :: Row(-1, "1", 1, 1) :: Row(-2, "2", 2, 2) :: Nil)
}

test("SPARK-34713: group by CreateStruct with ExtractValue") {
Copy link
Member

@dongjoon-hyun dongjoon-hyun Mar 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding this test case. BTW, do you think we can have a narrow-downed test case in catalyst module instead of this test in sql module?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have a UT for ResolveCreateNamedStruct yet. The bug is trivial so I didn't spend time building a new UT.

val structDF = Seq(Tuple1(1 -> 1)).toDF("col")
checkAnswer(structDF.groupBy(struct($"col._1")).count().select("count"), Row(1))

val arrayOfStructDF = Seq(Tuple1(Seq(1 -> 1))).toDF("col")
checkAnswer(arrayOfStructDF.groupBy(struct($"col._1")).count().select("count"), Row(1))

val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col")
checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))

val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col")
// Spark implicit casts string literal "a" to int to match the key type.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, since line 1102 and 1105 works and the following works, we don't care the field name at all?

scala> Seq(Tuple1(Map("b" -> "b"))).toDF("col").groupBy(struct($"col.a")).count().select("count").show
+-----+
|count|
+-----+
|    1|
+-----+

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @maropu , too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The col is a map-type column and the syntax a.b can get map values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was confused here. Thanks.

checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))

val arrayDF = Seq(Tuple1(Seq(1))).toDF("col")
val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
assert(e.message.contains("requires integral type"))
}
}

case class B(c: Option[Double])
Expand Down