Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ object SchemaConverters {
* This function takes an avro schema and returns a sql schema.
*/
def toSqlType(avroSchema: Schema): SchemaType = {
  // Top-level entry point: no enclosing record has been visited yet, so the
  // helper starts with an empty set of record names.
  val noRecordsSeen = Set.empty[String]
  toSqlTypeHelper(avroSchema, noRecordsSeen)
}

def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
avroSchema.getType match {
case INT => avroSchema.getLogicalType match {
case _: Date => SchemaType(DateType, nullable = false)
Expand All @@ -67,21 +71,28 @@ object SchemaConverters {
case ENUM => SchemaType(StringType, nullable = false)

case RECORD =>
if (existingRecordNames.contains(avroSchema.getFullName)) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another approach is to compare the whole JSON schema string (avroSchema.toString) here, but that seems like overkill: Avro requires the full name of a record to be unique.

throw new IncompatibleSchemaException(s"""
|Found recursive reference in Avro schema, which can not be processed by Spark:
|${avroSchema.toString(true)}
""".stripMargin)
}
val newRecordNames = existingRecordNames + avroSchema.getFullName
val fields = avroSchema.getFields.asScala.map { f =>
val schemaType = toSqlType(f.schema())
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
StructField(f.name, schemaType.dataType, schemaType.nullable)
}

SchemaType(StructType(fields), nullable = false)

case ARRAY =>
val schemaType = toSqlType(avroSchema.getElementType)
val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
SchemaType(
ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
nullable = false)

case MAP =>
val schemaType = toSqlType(avroSchema.getValueType)
val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
SchemaType(
MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
nullable = false)
Expand All @@ -91,13 +102,14 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
toSqlType(remainingUnionTypes.head).copy(nullable = true)
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
} else {
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
.copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType) match {
case Seq(t1) =>
toSqlType(avroSchema.getTypes.get(0))
toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
Expand All @@ -107,7 +119,7 @@ object SchemaConverters {
// This is consistent with the behavior when converting between Avro and Parquet.
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlType(s)
val schemaType = toSqlTypeHelper(s, existingRecordNames)
// All fields are nullable because only one of them is set at a time
StructField(s"member$i", schemaType.dataType, nullable = true)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1309,4 +1309,69 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkCodec(df, path, "xz")
}
}

private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
  // Parse the JSON schema text and attempt the conversion; the conversion is
  // expected to be rejected with an IncompatibleSchemaException.
  val thrown = intercept[IncompatibleSchemaException] {
    val parsedSchema = new Schema.Parser().parse(avroSchema)
    SchemaConverters.toSqlType(parsedSchema)
  }
  val errorText = thrown.getMessage

  // The failure must be the recursion-specific one, not some other incompatibility.
  assert(errorText.contains("Found recursive reference in Avro schema"))
}

// Each schema below contains a record ("LongList") that refers back to itself,
// reached through a different kind of Avro type. Every one of them must be
// rejected by checkSchemaWithRecursiveLoop.
test("Detect recursive loop") {
// Case 1: direct self-reference through a union with null (field "next").
checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"}, // each element has a long
| {"name": "next", "type": ["null", "LongList"]} // optional next element
| ]
|}
""".stripMargin)

// Case 2: indirect cycle - the nested record "foo" has a field "parent"
// whose type is the enclosing record "LongList".
checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields": [
| {
| "name": "value",
| "type": {
| "type": "record",
| "name": "foo",
| "fields": [
| {
| "name": "parent",
| "type": "LongList"
| }
| ]
| }
| }
| ]
|}
""".stripMargin)

// Case 3: self-reference as the element type of an array field.
checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"},
| {"name": "array", "type": {"type": "array", "items": "LongList"}}
| ]
|}
""".stripMargin)

// Case 4: self-reference as the value type of a map field.
checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"},
| {"name": "map", "type": {"type": "map", "values": "LongList"}}
| ]
|}
""".stripMargin)
}
}