diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9c600c9d39cf7..69e5885db405e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -59,6 +59,23 @@ object ExtractValue {
       GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
         ordinal, fields.length, containsNull)
 
+      case (ExtractNestedArray(StructType(fields), containsNull, containsNullSeq),
+          NonNullLiteral(v, StringType)) =>
+        child match {
+          case ExtractGetArrayStructField(_, num) if num == containsNullSeq.size =>
+            val fieldName = v.toString
+            val ordinal = findField(fields, fieldName, resolver)
+            val row = (0 until num).foldRight(child) { (_, e) =>
+              GetArrayItem(e, Literal(0))
+            }
+            val innerArray = GetArrayStructFields(row, fields(ordinal).copy(name = fieldName),
+              ordinal, fields.length, containsNull)
+            containsNullSeq.foldRight(innerArray: Expression) { (_, expr) =>
+              new CreateArray(Seq(expr))
+            }
+          case _ => GetArrayItem(child, extraction)
+        }
+
     case (_: ArrayType, _) => GetArrayItem(child, extraction)
 
     case (MapType(kt, _, _), _) => GetMapValue(child, extraction)
@@ -95,6 +112,58 @@ object ExtractValue {
 
 trait ExtractValue extends Expression
 
+/**
+ * Matches a (possibly nested) ArrayType. Returns the innermost element type, the
+ * innermost `containsNull` flag, and the `containsNull` flags of the enclosing
+ * arrays ordered from outermost to innermost.
+ */
+object ExtractNestedArray {
+
+  type ReturnType = Option[(DataType, Boolean, Seq[Boolean])]
+
+  def unapply(dataType: DataType): ReturnType = {
+    extractArrayType(dataType)
+  }
+
+  def extractArrayType(dataType: DataType): ReturnType = {
+    dataType match {
+      case ArrayType(dt, containsNull) =>
+        extractArrayType(dt) match {
+          // `dt` is itself an array: prepend this level's containsNull flag.
+          case Some((d, cn, seq)) => Some((d, cn, containsNull +: seq))
+          // `dt` is the innermost element type.
+          case None => Some((dt, containsNull, Seq.empty[Boolean]))
+        }
+      case _ => None
+    }
+  }
+}
+
+/**
+ * Matches a chain of [[GetArrayStructFields]] expressions. Returns the child of the
+ * innermost one together with the depth of the chain.
+ */
+object ExtractGetArrayStructField {
+
+  type ReturnType = Option[(Expression, Int)]
+
+  def unapply(expr: Expression): ReturnType = {
+    extractArrayStruct(expr)
+  }
+
+  def extractArrayStruct(expr: Expression): ReturnType = {
+    expr match {
+      case GetArrayStructFields(child, _, _, _, _) =>
+        extractArrayStruct(child) match {
+          case Some((e, deep)) => Some((e, deep + 1))
+          case None => Some((child, 1))
+        }
+      case _ => None
+    }
+  }
+}
+
+
 /**
  * Returns the value of fields in the Struct `child`.
  *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
index 3c826e812b5cc..084259c92a409 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
@@ -378,12 +378,6 @@ class SelectedFieldSuite extends AnalysisTest {
       StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil)))
   }
 
-  testSelect(arrayWithMultipleFields, "col7.field3.subfield1") {
-    StructField("col7", ArrayType(StructType(
-      StructField("field3", ArrayType(StructType(
-        StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil)))
-  }
-
   // Array with a nested int array
   // |-- col1: string (nullable = false)
   // |-- col8: array (nullable = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a219b91627b2b..5c62880f57b61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,6 +22,8 @@ import java.net.{MalformedURLException, URL}
 import java.sql.{Date, Timestamp}
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -3521,6 +3523,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         |""".stripMargin), Row(1))
     }
   }
+
+  test("SPARK-32002: Support Extract value from nested ArrayStruct") {
+    withTempView("nest") {
+      val df = spark.read
+        .json(Seq(
+          """{"a": [{"b": [{"c": [1,2]}]}]}""",
+          """{"a": [{"b": [{"c": [1]}, {"c": [2]}]}]}""",
+          """{"a":[{}]}""").toDS())
+      df.createOrReplaceTempView("nest")
+
+      checkAnswer(sql(
+        """
+          |SELECT a.b.c FROM nest
+        """.stripMargin),
+        Row(ArrayBuffer(ArrayBuffer(ArrayBuffer(1, 2)))) ::
+          Row(ArrayBuffer(ArrayBuffer(ArrayBuffer(1), ArrayBuffer(2)))) ::
+          Row(ArrayBuffer(null)) :: Nil)
+    }
+  }
 }
 
 case class Foo(bar: Option[String])