diff --git a/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java b/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java
index bd6875e183..12d1389d66 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java
@@ -312,7 +312,12 @@ private static TypeDescription buildOrcProjectForStructType(Integer fieldId, Type type,
     TypeDescription orcType;
     OrcField orcField = mapping.getOrDefault(fieldId, null);
     if (orcField != null && orcField.type.getCategory().equals(TypeDescription.Category.UNION)) {
-      orcType = orcField.type;
+      orcType = TypeDescription.createUnion();
+      for (Types.NestedField nestedField : type.asStructType().fields()) {
+        TypeDescription childType = buildOrcProjection(nestedField.fieldId(), nestedField.type(),
+            isRequired && nestedField.isRequired(), mapping);
+        orcType.addUnionChild(childType);
+      }
     } else {
       orcType = TypeDescription.createStruct();
       for (Types.NestedField nestedField : type.asStructType().fields()) {
diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java
index 4be48f9fa5..2ff967c9be 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java
@@ -49,15 +49,14 @@ public static <T> T visit(TypeDescription schema, OrcSchemaVisitor<T> visitor) {
       case UNION:
        List<TypeDescription> types = schema.getChildren();
        List<T> options = Lists.newArrayListWithExpectedSize(types.size());
-        for (TypeDescription type : types) {
-          visitor.beforeUnionOption(type);
+        for (int i = 0; i < types.size(); i++) {
+          visitor.beforeUnionOption(types.get(i), i);
           try {
-            options.add(visit(type, visitor));
+            options.add(visit(types.get(i), visitor));
           } finally {
-            visitor.afterUnionOption(type);
+            visitor.afterUnionOption(types.get(i), i);
           }
         }
-
         return visitor.union(schema, options);

       case LIST:
@@ -123,8 +122,8 @@ private static <T> T visitRecord(TypeDescription record, OrcSchemaVisitor<T> visitor) {
     return visitor.record(record, names, visitFields(fields, names, visitor));
   }

-  public String optionName() {
-    return "_option";
+  public String optionName(int ordinal) {
+    return "tag_" + ordinal;
   }

   public String elementName() {
@@ -151,12 +150,12 @@ public void afterField(String name, TypeDescription type) {
     fieldNames.pop();
   }

-  public void beforeUnionOption(TypeDescription option) {
-    beforeField(optionName(), option);
+  public void beforeUnionOption(TypeDescription option, int ordinal) {
+    beforeField(optionName(ordinal), option);
   }

-  public void afterUnionOption(TypeDescription option) {
-    afterField(optionName(), option);
+  public void afterUnionOption(TypeDescription option, int ordinal) {
+    afterField(optionName(ordinal), option);
   }

   public void beforeElementField(TypeDescription element) {
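
With optionName(int ordinal) returning "tag_" + ordinal, each union option is pushed onto the visitor's field-name stack under a positional name (tag_0, tag_1, ...) instead of a shared "_option" placeholder, which is exactly the naming the new test below asserts against. A minimal sketch of a visitor that renders those names, as a hypothetical demo class not part of this patch, assuming the visitor's record/union/primitive callbacks default to no-ops as in Iceberg's other schema visitors:

    import java.util.List;

    import org.apache.iceberg.orc.OrcSchemaVisitor;
    import org.apache.orc.TypeDescription;

    // Hypothetical demo: renders an ORC schema using the same
    // optionName(ordinal) hook the projection and test rely on.
    class UnionNameRenderer extends OrcSchemaVisitor<String> {
      @Override
      public String record(TypeDescription record, List<String> names, List<String> fields) {
        StringBuilder sb = new StringBuilder("struct<");
        for (int i = 0; i < fields.size(); i++) {
          sb.append(i > 0 ? ", " : "").append(names.get(i)).append(": ").append(fields.get(i));
        }
        return sb.append(">").toString();
      }

      @Override
      public String union(TypeDescription union, List<String> options) {
        StringBuilder sb = new StringBuilder("union<");
        for (int i = 0; i < options.size(); i++) {
          // optionName(i) yields tag_0, tag_1, ... after this change
          sb.append(i > 0 ? ", " : "").append(optionName(i)).append(": ").append(options.get(i));
        }
        return sb.append(">").toString();
      }

      @Override
      public String primitive(TypeDescription primitive) {
        return primitive.toString();
      }

      public static void main(String[] args) {
        TypeDescription orc = TypeDescription.fromString("struct<u:uniontype<int,string>>");
        // Expected output: struct<u: union<tag_0: int, tag_1: string>>
        System.out.println(OrcSchemaVisitor.visit(orc, new UnionNameRenderer()));
      }
    }

These tag_<ordinal> names are what the projection change in ORCSchemaUtil pairs union options with, so readers can address each option as an ordinary struct field.
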
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java
index 3be41c41e8..7fcdbad273 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java
@@ -31,6 +31,7 @@
 import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.orc.ORC;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterators;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders;
 import org.apache.iceberg.types.Types;
@@ -39,6 +40,7 @@
 import org.apache.orc.Writer;
 import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
 import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
+import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
 import org.apache.orc.storage.ql.exec.vector.UnionColumnVector;
 import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -215,4 +217,100 @@ public void testSingleComponentUnion() throws IOException {
       assertEquals(expectedSchema, expectedSecondRow, rowIterator.next());
     }
   }
+
+  @Test
+  public void testDeeplyNestedUnion() throws IOException {
+    TypeDescription orcSchema = TypeDescription.fromString(
+        "struct<c1:uniontype<int,struct<c2:string,c3:uniontype<int,string>>>>");
+
+    Schema expectedSchema = new Schema(
+        Types.NestedField.optional(0, "c1", Types.StructType.of(
+            Types.NestedField.optional(1, "tag_0", Types.IntegerType.get()),
+            Types.NestedField.optional(2, "tag_1",
+                Types.StructType.of(Types.NestedField.optional(3, "c2", Types.StringType.get()),
+                    Types.NestedField.optional(4, "c3", Types.StructType.of(
+                        Types.NestedField.optional(5, "tag_0", Types.IntegerType.get()),
+                        Types.NestedField.optional(6, "tag_1", Types.StringType.get()))))))));
+
+    final InternalRow expectedFirstRow = new GenericInternalRow(1);
+    final InternalRow inner1 = new GenericInternalRow(2);
+    inner1.update(0, null);
+    final InternalRow inner2 = new GenericInternalRow(2);
+    inner2.update(0, UTF8String.fromString("foo0"));
+    final InternalRow inner3 = new GenericInternalRow(2);
+    inner3.update(0, 0);
+    inner3.update(1, null);
+    inner2.update(1, inner3);
+    inner1.update(1, inner2);
+    expectedFirstRow.update(0, inner1);
+
+    Configuration conf = new Configuration();
+
+    File orcFile = temp.newFile();
+    Path orcFilePath = new Path(orcFile.getPath());
+
+    Writer writer = OrcFile.createWriter(orcFilePath,
+        OrcFile.writerOptions(conf)
+            .setSchema(orcSchema).overwrite(true));
+
+    VectorizedRowBatch batch = orcSchema.createRowBatch();
+    UnionColumnVector innerUnion1 = (UnionColumnVector) batch.cols[0];
+    LongColumnVector innerInt1 = (LongColumnVector) innerUnion1.fields[0];
+    innerInt1.fillWithNulls();
+    StructColumnVector innerStruct2 = (StructColumnVector) innerUnion1.fields[1];
+    BytesColumnVector innerString2 = (BytesColumnVector) innerStruct2.fields[0];
+    UnionColumnVector innerUnion3 = (UnionColumnVector) innerStruct2.fields[1];
+    LongColumnVector innerInt3 = (LongColumnVector) innerUnion3.fields[0];
+    BytesColumnVector innerString3 = (BytesColumnVector) innerUnion3.fields[1];
+    innerString3.fillWithNulls();
+
+    for (int r = 0; r < NUM_OF_ROWS; ++r) {
+      int row = batch.size++;
+      innerUnion1.tags[row] = 1;
+      innerString2.setVal(row, ("foo" + row).getBytes(StandardCharsets.UTF_8));
+      innerUnion3.tags[row] = 0;
+      innerInt3.vector[row] = r;
+      // If the batch is full, write it out and start over.
+      if (batch.size == batch.getMaxSize()) {
+        writer.addRowBatch(batch);
+        batch.reset();
+        innerInt1.fillWithNulls();
+        innerString3.fillWithNulls();
+      }
+    }
+    if (batch.size != 0) {
+      writer.addRowBatch(batch);
+      batch.reset();
+    }
+    writer.close();
+
+    // test non-vectorized reader
+    List<InternalRow> results = Lists.newArrayList();
+    try (CloseableIterable<InternalRow> reader = ORC.read(Files.localInput(orcFile))
+        .project(expectedSchema)
+        .createReaderFunc(readOrcSchema -> new SparkOrcReader(expectedSchema, readOrcSchema))
+        .build()) {
+      reader.forEach(results::add);
+      final InternalRow actualFirstRow = results.get(0);
+
+      Assert.assertEquals(NUM_OF_ROWS, results.size());
+      assertEquals(expectedSchema, expectedFirstRow, actualFirstRow);
+    }
+
+    // test vectorized reader
+    try (CloseableIterable<ColumnarBatch> reader = ORC.read(Files.localInput(orcFile))
+        .project(expectedSchema)
+        .createBatchedReaderFunc(readOrcSchema ->
+            VectorizedSparkOrcReaders.buildReader(expectedSchema, readOrcSchema, ImmutableMap.of()))
+        .build()) {
+      final Iterator<InternalRow> actualRows = batchesToRows(reader.iterator());
+      final InternalRow actualFirstRow = actualRows.next();
+
+      assertEquals(expectedSchema, expectedFirstRow, actualFirstRow);
+    }
+  }
+
+  private Iterator<InternalRow> batchesToRows(Iterator<ColumnarBatch> batches) {
+    return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator));
+  }
 }
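
For context on the writer loop in testDeeplyNestedUnion: ORC's UnionColumnVector holds one child vector per option plus a tags array, and tags[row] selects which child is live for each row; children that never receive a value are pre-filled with nulls (the fillWithNulls() calls). A self-contained sketch of that pattern, as a hypothetical demo class over a single uniontype<int,string> column:

    import java.nio.charset.StandardCharsets;

    import org.apache.orc.TypeDescription;
    import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
    import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
    import org.apache.orc.storage.ql.exec.vector.UnionColumnVector;
    import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;

    // Hypothetical demo, not part of this patch.
    class UnionBatchDemo {
      public static void main(String[] args) {
        TypeDescription schema = TypeDescription.fromString("struct<u:uniontype<int,string>>");
        VectorizedRowBatch batch = schema.createRowBatch();
        UnionColumnVector union = (UnionColumnVector) batch.cols[0];
        LongColumnVector ints = (LongColumnVector) union.fields[0];
        BytesColumnVector strings = (BytesColumnVector) union.fields[1];

        ints.fillWithNulls(); // option 0 is never selected in this batch

        int row = batch.size++;
        union.tags[row] = 1; // option 1 (the string) is live for this row
        strings.setVal(row, "foo".getBytes(StandardCharsets.UTF_8));
      }
    }

A writer.addRowBatch(batch) call would then flush the batch, exactly as the test does once the batch fills up or the row loop ends.
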