diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6f422c30cb97..f88d308dcbce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3977,6 +3977,8 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => Seq(Literal(e.name), e) + case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined => + Seq(Literal(e.name.get), e) case kv => kv } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 139d9a584ccb..4413a3deaa64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -90,7 +90,10 @@ object ExtractValue { } } -trait ExtractValue extends Expression +trait ExtractValue extends Expression { + // The name that is used to extract the value. + def name: Option[String] +} /** * Returns the value of fields in the Struct `child`. 
@@ -156,6 +159,7 @@ case class GetArrayStructFields( override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" + override def name: Option[String] = Some(field.name) protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -233,6 +237,7 @@ case class GetArrayItem( override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" + override def name: Option[String] = None override def left: Expression = child override def right: Expression = ordinal @@ -448,6 +453,10 @@ case class GetMapValue( override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" + override def name: Option[String] = key match { + case NonNullLiteral(s, StringType) => Some(s.toString) + case _ => None + } override def left: Expression = child override def right: Expression = key diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 07e6a40bc07e..3e137d49e64c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1091,6 +1091,25 @@ class DataFrameAggregateSuite extends QueryTest val df = spark.sql(query) checkAnswer(df, Row(0, "0", 0, 0) :: Row(-1, "1", 1, 1) :: Row(-2, "2", 2, 2) :: Nil) } + + test("SPARK-34713: group by CreateStruct with ExtractValue") { + val structDF = Seq(Tuple1(1 -> 1)).toDF("col") + checkAnswer(structDF.groupBy(struct($"col._1")).count().select("count"), Row(1)) + + val arrayOfStructDF = Seq(Tuple1(Seq(1 -> 1))).toDF("col") + checkAnswer(arrayOfStructDF.groupBy(struct($"col._1")).count().select("count"), Row(1)) + + val mapDF = Seq(Tuple1(Map("a" -> 
"a"))).toDF("col") + checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + + val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") + // Spark implicitly casts the string literal "a" to int to match the key type. + checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + + val arrayDF = Seq(Tuple1(Seq(1))).toDF("col") + val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count()) + assert(e.message.contains("requires integral type")) + } } case class B(c: Option[Double])