
[SPARK-43333][SQL] Allow Avro to convert union type to SQL with field name stable with type #41263

Closed
wants to merge 13 commits
@@ -17,14 +17,18 @@

package org.apache.spark.sql.avro

import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.Decimal.minBytesForPrecision

@@ -144,11 +148,31 @@ object SchemaConverters {
case _ =>
// Convert complex unions to struct types where field names are member0, member1, etc.
// This is consistent with the behavior when converting between Avro and Parquet.
val useSchemaId = SQLConf.get.getConf(SQLConf.AVRO_STABLE_ID_FOR_UNION_TYPE)
Contributor: Normally these configs are provided as options for functions (e.g. for from_avro()). For the file source, it should be an option on the source. Let's not use a Spark conf.
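For illustration, a sketch of what the reviewer's suggestion could look like. The option name enableStableIdentifiersForUnionType is hypothetical here (mirroring the conf key introduced below); spark, inputDf, and avroSchemaJson are assumed to be in scope.

import scala.collection.JavaConverters._
import org.apache.spark.sql.avro.functions.from_avro

// File source: pass the setting as a per-read option instead of a session conf.
val df = spark.read
  .format("avro")
  .option("enableStableIdentifiersForUnionType", "true") // hypothetical option name
  .load("/path/to/data.avro")

// from_avro(): pass the setting through the function's options map.
val parsed = inputDf.select(
  from_avro(inputDf("value"), avroSchemaJson,
    Map("enableStableIdentifiersForUnionType" -> "true").asJava).as("record"))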


val fieldNameSet: mutable.Set[String] = mutable.Set()
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlTypeHelper(s, existingRecordNames)

val fieldName = if (useSchemaId) {
// Avro field names may be case sensitive, so the type names of two union members
// could be "a" and "A" and we need to distinguish them. In that case, we throw
// an exception.
val temp_name = s"member_${s.getName.toLowerCase(Locale.ROOT)}"
if (fieldNameSet.contains(temp_name)) {
throw new IncompatibleSchemaException(
"Cannot generate stable indentifier for Avro union type due to name " +
s"conflict of type name ${s.getName}")
}
fieldNameSet.add(temp_name)
temp_name
} else {
s"member$i"
}

// All fields are nullable because only one of them is set at a time
StructField(s"member$i", schemaType.dataType, nullable = true)
StructField(fieldName, schemaType.dataType, nullable = true)
}

SchemaType(StructType(fields.toArray), nullable = false)
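To make the two naming schemes concrete, here is a minimal sketch (assuming an active SparkSession named spark, so that SQLConf.get on the driver reflects the session conf) of converting the same union under each setting via the public SchemaConverters.toSqlType helper:

import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters

// A complex union with more than one non-null member type.
val union = Schema.createUnion(
  java.util.Arrays.asList(
    Schema.create(Schema.Type.INT),
    Schema.create(Schema.Type.STRING)))

// Default positional naming: struct<member0:int,member1:string>
spark.conf.set("spark.sql.avro.enableStableIdentifiersForUnionType", "false")
println(SchemaConverters.toSqlType(union).dataType.catalogString)

// Stable naming from this PR: struct<member_int:int,member_string:string>
spark.conf.set("spark.sql.avro.enableStableIdentifiersForUnionType", "true")
println(SchemaConverters.toSqlType(union).dataType.catalogString)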
212 changes: 190 additions & 22 deletions connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -98,6 +98,51 @@ abstract class AvroSuite
}, new GenericDatumReader[Any]()).getSchema.toString(false)
}

// Check whether an Avro schema of union type is converted to SQL in an expected way, when the
// stable ID option is on.
//
// @param types Avro types contained in an Avro union type
// @param expectedSchema expected SQL schema, provided in DDL string form
// @param fieldsAndRow A list of field values to be appended to the Avro file, each paired
// with the expected converted SQL row
private def checkUnionStableId(
types: List[Schema],
expectedSchema: String,
fieldsAndRow: Seq[(Any, Row)]): Unit = {
withSQLConf(SQLConf.AVRO_STABLE_ID_FOR_UNION_TYPE.key -> "true") {
withTempDir { dir =>
val unionType = Schema.createUnion(
types.asJava
)
val fields =
Seq(new Field("field1", unionType, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
dataFileWriter.create(schema, new File(s"$dir.avro"))

fieldsAndRow.map(_._1).foreach { f =>
val avroRec = new GenericData.Record(schema)
f match {
case a: Array[Byte] =>
val fixedSchema = SchemaBuilder.fixed("fixed_name").size(4)
avroRec.put("field1", new Fixed(fixedSchema, a))
case other =>
avroRec.put("field1", other)
}
dataFileWriter.append(avroRec)
}
dataFileWriter.flush()
dataFileWriter.close()

val df = spark.read.format("avro").load(s"$dir.avro")
assert(df.schema === StructType.fromDDL("field1 " + expectedSchema))
assert(df.collect().toSet == fieldsAndRow.map(fr => Row(fr._2)).toSet)
}
}
}

private def getResourceAvroFilePath(name: String): String = {
Thread.currentThread().getContextClassLoader.getResource(name).toString
}
@@ -271,29 +316,152 @@ abstract class AvroSuite
}
}

test("SPARK-27858 Union type: More than one non-null type") {
withTempDir { dir =>
val complexNullUnionType = Schema.createUnion(
List(Schema.create(Type.INT), Schema.create(Type.NULL), Schema.create(Type.STRING)).asJava)
val fields = Seq(
new Field("field1", complexNullUnionType, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
dataFileWriter.create(schema, new File(s"$dir.avro"))
val avroRec = new GenericData.Record(schema)
avroRec.put("field1", 42)
dataFileWriter.append(avroRec)
val avroRec2 = new GenericData.Record(schema)
avroRec2.put("field1", "Alice")
dataFileWriter.append(avroRec2)
dataFileWriter.flush()
dataFileWriter.close()
test("SPARK-43333: union stable id") {
Contributor: We can remove the SPARK Jira id here. Can we also include a user-defined Avro struct in addition to primitive types? Say 'CustomerInfo'.

Contributor: Tests need to include Spark Jira ids unless the test suite is new.

rangadi (Contributor), May 24, 2023: Does that mean we need to read the Spark Jira to understand the test? I would be surprised if there is such a policy. Do you have a link? It is a test for a new feature. Ideally it should be understandable by itself and should not require going to the Jira ticket. I have added many new tests without adding a Jira id. I am OK if we want to include it here; I just don't see the use of doing so.

Contributor: Yes, the Jira number needs to be included; however, the test name should be descriptive enough to understand what the test does. The Jira number is added for reference: if the test breaks, it is much easier to track down the original change and understand the motivation behind it.

Contributor: You can find a note on this at https://spark.apache.org/contributing.html (Pull request section).

Contributor: Thanks for the link. Sure.

Contributor: Could you update the test name, e.g. "Stable field names when converting Union type" or "Union type: stable field ids/names"? So other contributors can understand what is being tested here.

checkUnionStableId(
List(Type.INT, Type.NULL, Type.STRING).map(Schema.create(_)),
"struct<member_int: int, member_string: string>",
Seq(
(42, Row(42, null)),
("Alice", Row(null, "Alice"))))

val df = spark.read.format("avro").load(s"$dir.avro")
assert(df.schema === StructType.fromDDL("field1 struct<member0: int, member1: string>"))
assert(df.collect().toSet == Set(Row(Row(42, null)), Row(Row(null, "Alice"))))
checkUnionStableId(
List(Type.FLOAT, Type.BOOLEAN, Type.BYTES, Type.DOUBLE, Type.LONG).map(Schema.create(_)),
"struct<member_float: float, member_boolean: boolean, " +
"member_bytes: binary, member_double: double, member_long: long>",
Seq(
(true, Row(null, true, null, null, null)),
(42L, Row(null, null, null, null, 42L)),
(42F, Row(42.0, null, null, null, null)),
(42D, Row(null, null, null, 42D, null))))

checkUnionStableId(
List(
Schema.createArray(Schema.create(Type.FLOAT)),
Schema.createMap(Schema.create(Schema.Type.INT))),
"struct<member_array: array<float>, member_map: map<string, int>>",
Seq())

checkUnionStableId(
List(
Schema.createEnum("myenum", "", null, List[String]("e1", "e2").asJava),
Schema.createRecord("myrecord", "", null, false,
List[Schema.Field](new Schema.Field("f", Schema.createFixed("myfield", "", null, 6)))
.asJava),
Schema.createRecord("myrecord2", "", null, false,
List[Schema.Field](new Schema.Field("f", Schema.create(Type.FLOAT)))
.asJava)),
"struct<member_myenum: string, member_myrecord: struct<f: binary>, " +
"member_myrecord2: struct<f: float>>",
Seq())

{
val e = intercept[Exception] {
checkUnionStableId(
List(
Schema.createFixed("MYFIELD2", "", null, 6),
Schema.createFixed("myfield1", "", null, 6),
Schema.createFixed("myfield2", "", null, 9)),
"",
Seq())
}
assert(e.getMessage.contains("Cannot generate stable indentifier"))
}
{
val e = intercept[Exception] {
checkUnionStableId(
List(
Schema.createFixed("ARRAY", "", null, 6),
Schema.createArray(Schema.create(Type.STRING))),
"",
Seq())
}
assert(e.getMessage.contains("Cannot generate stable indentifier"))
}
// Two array types or two map types are not allowed in union.
{
val e = intercept[Exception] {
Schema.createUnion(
List(
Schema.createArray(Schema.create(Type.FLOAT)),
Schema.createArray(Schema.create(Type.STRING))).asJava)
}
assert(e.getMessage.contains("Duplicate in union"))
}
{
val e = intercept[Exception] {
Schema.createUnion(
List(
Schema.createMap(Schema.create(Type.FLOAT)),
Schema.createMap(Schema.create(Type.STRING))).asJava)
}
assert(e.getMessage.contains("Duplicate in union"))
}

// Curiously, Avro allows a named type called "array", but doesn't allow it in the same
// union as an actual array type.
{
val e = intercept[Exception] {
Schema.createUnion(
List(
Schema.createArray(Schema.create(Type.FLOAT)),
Schema.createFixed("array", "", null, 6)
).asJava
)
}
assert(e.getMessage.contains("Duplicate in union"))
}
{
val e = intercept[Exception] {
Schema.createUnion(
List(Schema.createFixed("long", "", null, 6)).asJava
)
}
assert(e.getMessage.contains("Schemas may not be named after primitives"))
}

{
val e = intercept[Exception] {
Schema.createUnion(
List(Schema.createFixed("bytes", "", null, 6)).asJava
)
}
assert(e.getMessage.contains("Schemas may not be named after primitives"))
}
}

test("SPARK-27858 Union type: More than one non-null type") {
Contributor: Could you add a short description of the test in a comment at the top? This helps in understanding the test.

Contributor (Author): This is not a new test; it's an existing test. I just added the stable-ID scenario.

Seq(true, false).foreach { isStableUnionMember =>
withSQLConf(SQLConf.AVRO_STABLE_ID_FOR_UNION_TYPE.key -> isStableUnionMember.toString) {
withTempDir { dir =>
val complexNullUnionType = Schema.createUnion(
List(Schema.create(Type.INT), Schema.create(Type.NULL), Schema.create(Type.STRING))
.asJava
)
val fields =
Seq(new Field("field1", complexNullUnionType, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
dataFileWriter.create(schema, new File(s"$dir.avro"))
val avroRec = new GenericData.Record(schema)
avroRec.put("field1", 42)
dataFileWriter.append(avroRec)
val avroRec2 = new GenericData.Record(schema)
avroRec2.put("field1", "Alice")
dataFileWriter.append(avroRec2)
dataFileWriter.flush()
dataFileWriter.close()

val df = spark.read.format("avro").load(s"$dir.avro")
if (isStableUnionMember) {
assert(df.schema === StructType.fromDDL(
"field1 struct<member_int: int, member_string: string>"))
} else {
assert(df.schema === StructType.fromDDL("field1 struct<member0: int, member1: string>"))
}
assert(df.collect().toSet == Set(Row(Row(42, null)), Row(Row(null, "Alice"))))
}
}
}
}

@@ -3413,6 +3413,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val AVRO_STABLE_ID_FOR_UNION_TYPE = buildConf(
Contributor: Commented above. I think it should be an option for Avro functions and the Avro source, not a Spark conf.

"spark.sql.avro.enableStableIdentifiersForUnionType")
.doc("If it is set to true, Avro schema is deserialized into Spark SQL schema, and the Avro " +
"Union type is transformed into a structure where the field names remain consistent with " +
"their respective types. The resulting field names are converted to lowercase, " +
"e.g. member_int or member_string. If two user-defined type names or a user-defined type " +
"name and a built-in type name are identical regardless of case, an exception will be " +
"raised. However, in other cases, the field names can be uniquely identified.")
.version("3.5.0")
.booleanConf
.createWithDefault(false)

val LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL =
buildConf("spark.sql.legacy.parseNullPartitionSpecAsStringLiteral")
.internal()
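As the PR stands, the switch is the session conf above. A minimal end-to-end sketch (the file path is a placeholder, and field1 is assumed to be a union of null, int, and string as in the tests):

spark.conf.set("spark.sql.avro.enableStableIdentifiersForUnionType", "true")
val df = spark.read.format("avro").load("/path/to/union_data.avro")
df.printSchema()
// root
//  |-- field1: struct (nullable = true)
//  |    |-- member_int: integer (nullable = true)
//  |    |-- member_string: string (nullable = true)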