diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 12c7905f62d1a..6d783be112777 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -223,7 +223,7 @@ object DataType { ("elementType", t: JValue), ("type", JString("array"))) => assertValidTypeForCollations(fieldPath, "array", collationsMap) - val elementType = parseDataType(t, fieldPath + ".element", collationsMap) + val elementType = parseDataType(t, appendFieldToPath(fieldPath, "element"), collationsMap) ArrayType(elementType, n) case JSortedObject( @@ -232,8 +232,8 @@ object DataType { ("valueContainsNull", JBool(n)), ("valueType", v: JValue)) => assertValidTypeForCollations(fieldPath, "map", collationsMap) - val keyType = parseDataType(k, fieldPath + ".key", collationsMap) - val valueType = parseDataType(v, fieldPath + ".value", collationsMap) + val keyType = parseDataType(k, appendFieldToPath(fieldPath, "key"), collationsMap) + val valueType = parseDataType(v, appendFieldToPath(fieldPath, "value"), collationsMap) MapType(keyType, valueType, n) case JSortedObject( @@ -304,6 +304,13 @@ object DataType { } } + /** + * Appends a field name to a given path, using a dot separator if the path is not empty. + */ + private def appendFieldToPath(basePath: String, fieldName: String): String = { + if (basePath.isEmpty) fieldName else s"$basePath.$fieldName" + } + /** * Returns a map of field path to collation name. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8fd9b7c43a659..4343e464b2c80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import com.fasterxml.jackson.core.JsonParseException +import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException} import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} @@ -1001,6 +1002,50 @@ class DataTypeSuite extends SparkFunSuite { ) } + test("parse array type with collation metadata") { + val unicodeCollationId = CollationFactory.collationNameToId("UNICODE") + val arrayJson = + """ + |{ + | "type": "array", + | "elementType": "string", + | "containsNull": true + |} + |""".stripMargin + + val collationsMap = Map("element" -> "UNICODE") + + // Parse without collations map + assert(DataType.parseDataType(JsonMethods.parse(arrayJson)) === ArrayType(StringType)) + + val parsedWithCollations = DataType.parseDataType( + JsonMethods.parse(arrayJson), collationsMap = collationsMap) + assert(parsedWithCollations === ArrayType(StringType(unicodeCollationId))) + } + + test("parse map type with collation metadata") { + val unicodeCollationId = CollationFactory.collationNameToId("UNICODE") + val mapJson = + """ + |{ + | "type": "map", + | "keyType": "string", + | "valueType": "string", + | "valueContainsNull": true + |} + |""".stripMargin + + val collationsMap = Map("key" -> "UNICODE", "value" -> "UNICODE") + + // Parse without collations map + assert(DataType.parseDataType(JsonMethods.parse(mapJson)) === MapType(StringType, StringType)) + + val parsedWithCollations = DataType.parseDataType( + JsonMethods.parse(mapJson), collationsMap = 
collationsMap) + assert(parsedWithCollations === + MapType(StringType(unicodeCollationId), StringType(unicodeCollationId))) + } + test("SPARK-48680: Add CharType and VarcharType to DataTypes JAVA API") { assert(DataTypes.createCharType(1) === CharType(1)) assert(DataTypes.createVarcharType(100) === VarcharType(100))