Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-43333][SQL] Allow Avro to convert union type to SQL with field name stable with type #41263

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[sql] case class AvroDataToCatalyst(
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)

override lazy val dataType: DataType = {
val dt = SchemaConverters.toSqlType(expectedSchema).dataType
val dt = SchemaConverters.toSqlType(expectedSchema, options).dataType
parseMode match {
// With PermissiveMode, the output Catalyst row might contain columns of null values for
// corrupt records, even if some of the columns are not nullable in the user-provided schema.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ private[sql] class AvroOptions(
val datetimeRebaseModeInRead: String = parameters
.get(DATETIME_REBASE_MODE)
.getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ))

val useStableIdForUnionType: Boolean =
parameters.get(STABLE_ID_FOR_UNION_TYPE).map(_.toBoolean).getOrElse(false)
}

private[sql] object AvroOptions extends DataSourceOptions {
Expand All @@ -154,4 +157,11 @@ private[sql] object AvroOptions extends DataSourceOptions {
// datasource similarly to the SQL config `spark.sql.avro.datetimeRebaseModeInRead`,
// and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`.
val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode")
// If it is set to true, Avro schema is deserialized into Spark SQL schema, and the Avro Union
// type is transformed into a structure where the field names remain consistent with their
// respective types. The resulting field names are converted to lowercase, e.g. member_int or
// member_string. If two user-defined type names or a user-defined type name and a built-in
// type name are identical regardless of case, an exception will be raised. However, in other
// cases, the field names can be uniquely identified.
val STABLE_ID_FOR_UNION_TYPE = newOption("enableStableIdentifiersForUnionType")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add documentation for this? I think Spark conf version had long doc comment. We can reuse that here.

}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private[sql] object AvroUtils extends Logging {
new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles)
}

SchemaConverters.toSqlType(avroSchema).dataType match {
SchemaConverters.toSqlType(avroSchema, options).dataType match {
case t: StructType => Some(t)
case _ => throw new RuntimeException(
s"""Avro schema cannot be converted to a Spark SQL StructType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.sql.avro

import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
Expand Down Expand Up @@ -49,13 +52,19 @@ object SchemaConverters {
* @since 2.4.0
*/
def toSqlType(avroSchema: Schema): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty)
toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(Map()))
}
def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options))
}

// The property specifies Catalyst type of the given field
private val CATALYST_TYPE_PROP_NAME = "spark.sql.catalyst.type"

private def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
private def toSqlTypeHelper(
avroSchema: Schema,
existingRecordNames: Set[String],
avroOptions: AvroOptions): SchemaType = {
avroSchema.getType match {
case INT => avroSchema.getLogicalType match {
case _: Date => SchemaType(DateType, nullable = false)
Expand Down Expand Up @@ -106,20 +115,23 @@ object SchemaConverters {
}
val newRecordNames = existingRecordNames + avroSchema.getFullName
val fields = avroSchema.getFields.asScala.map { f =>
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, avroOptions)
StructField(f.name, schemaType.dataType, schemaType.nullable)
}

SchemaType(StructType(fields.toArray), nullable = false)

case ARRAY =>
val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
val schemaType = toSqlTypeHelper(
avroSchema.getElementType,
existingRecordNames,
avroOptions)
SchemaType(
ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
nullable = false)

case MAP =>
val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames, avroOptions)
SchemaType(
MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
nullable = false)
Expand All @@ -129,26 +141,50 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
if (remainingUnionTypes.size == 1) {
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
} else {
toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, avroOptions)
.copy(nullable = true)
} else {
toSqlTypeHelper(
Schema.createUnion(remainingUnionTypes.asJava),
existingRecordNames,
avroOptions).copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
case Seq(t1) =>
toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames, avroOptions)
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
SchemaType(DoubleType, nullable = false)
case _ =>
// Convert complex unions to struct types where field names are member0, member1, etc.
// This is consistent with the behavior when converting between Avro and Parquet.
// When avroOptions.useStableIdForUnionType is false, convert complex unions to struct
// types where field names are member0, member1, etc. This is consistent with the
// behavior when converting between Avro and Parquet.
Comment on lines +161 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is Parquet connection here? Should this say "consistent with default behavior before adding support for stable names"?.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the existing comment. I just got in different lines after adding "When avroOptions.useStableIdForUnionType is false" in the beginning. I don't know what it is and I have no reason to doubt it is wrong.

// If avroOptions.useStableIdForUnionType is true, include type name in field names
// so that users can drop or add fields and keep field name stable.
val fieldNameSet : mutable.Set[String] = mutable.Set()
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlTypeHelper(s, existingRecordNames)
val schemaType = toSqlTypeHelper(s, existingRecordNames, avroOptions)

val fieldName = if (avroOptions.useStableIdForUnionType) {
// Avro's field names may be case sensitive, so the field names for two named
// types could be "a" and "A" and we need to distinguish them. In this case, we
// throw an exception.
val temp_name = s"member_${s.getName.toLowerCase(Locale.ROOT)}"
if (fieldNameSet.contains(temp_name)) {
throw new IncompatibleSchemaException(
"Cannot generate stable indentifier for Avro union type due to name " +
s"conflict of type name ${s.getName}")
}
fieldNameSet.add(temp_name)
temp_name
} else {
s"member$i"
}

// All fields are nullable because only one of them is set at a time
StructField(s"member$i", schemaType.dataType, nullable = true)
StructField(fieldName, schemaType.dataType, nullable = true)
}

SchemaType(StructType(fields.toArray), nullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,13 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
| ]
|}
""".stripMargin
val avroSchema = AvroOptions(Map("avroSchema" -> avroTypeStruct)).schema.get
val sparkSchema = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
val options = Map("avroSchema" -> avroTypeStruct)
val avroOptions = AvroOptions(options)
val avroSchema = avroOptions.schema.get
val sparkSchema = SchemaConverters
.toSqlType(avroSchema, options)
.dataType
.asInstanceOf[StructType]

val df = spark.range(5).select($"id")
val structDf = df.select(struct($"id").as("struct"))
Expand Down
Loading