
Commit 43441ae

Fix spark avro reader reading union schema data (#83)
* Fix the Spark Avro reader so it reads correctly structured nested data values
* Make sure the field-ID mapping is correctly maintained for an arbitrary nested schema that contains unions
1 parent a0fc79b commit 43441ae

File tree

5 files changed: +191 -17 lines changed

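At a high level, this commit makes the readers surface a complex (non-option) Avro union as a struct with one nullable field per non-null branch. A minimal sketch of that mapping, using Avro's SchemaBuilder; the class name and the struct shape shown in the comments are illustrative, not part of the commit:

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class UnionToStructSketch {
  public static void main(String[] args) {
    // A complex union: more than one non-null branch.
    Schema union = SchemaBuilder.unionOf()
        .nullType().and()
        .intType().and()
        .stringType()
        .endUnion();

    // After this commit, a value of this union reads back as a struct with
    // one nullable field per non-null branch, e.g. struct<tag_0: int, tag_1: string>;
    // exactly one tag_N is set, and a null value leaves every field null.
    System.out.println(union);  // ["null","int","string"]
  }
}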

core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java

Lines changed: 15 additions & 2 deletions
@@ -52,8 +52,21 @@ public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
       case UNION:
         List<Schema> types = schema.getTypes();
         List<T> options = Lists.newArrayListWithExpectedSize(types.size());
-        for (Schema type : types) {
-          options.add(visit(type, visitor));
+        if (AvroSchemaUtil.isOptionSchema(schema)) {
+          for (Schema type : types) {
+            options.add(visit(type, visitor));
+          }
+        } else {
+          // complex union case
+          int idx = 0;
+          for (Schema type : types) {
+            if (type.getType() != Schema.Type.NULL) {
+              options.add(visitWithName("tag_" + idx, type, visitor));
+              idx += 1;
+            } else {
+              options.add(visit(type, visitor));
+            }
+          }
         }
         return visitor.union(schema, options);
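The positional tag_N names matter because they become the field names (and field-ID anchors) of the struct that replaces the union. A small sketch of how the new naming loop behaves; the class is ours, for illustration only:

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class TagNamingSketch {
  public static void main(String[] args) {
    Schema union = SchemaBuilder.unionOf()
        .nullType().and().intType().and().stringType().endUnion();

    // Mirrors the visitor change: only non-null branches get positional
    // names, and the counter advances only for them, so names stay dense.
    int idx = 0;
    for (Schema branch : union.getTypes()) {
      if (branch.getType() != Schema.Type.NULL) {
        System.out.println("tag_" + idx + " -> " + branch.getType());
        idx += 1;
      }
    }
    // prints: tag_0 -> INT, then tag_1 -> STRING
  }
}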

core/src/main/java/org/apache/iceberg/avro/PruneColumns.java

Lines changed: 17 additions & 1 deletion
@@ -125,10 +125,11 @@ public Schema union(Schema union, List<Schema> options) {
       return null;
     } else {
       // Complex union case
-      return union;
+      return copyUnion(union, options);
     }
   }

+
   @Override
   @SuppressWarnings("checkstyle:CyclomaticComplexity")
   public Schema array(Schema array, Schema element) {
@@ -297,4 +298,19 @@ private static Schema.Field copyField(Schema.Field field, Schema newSchema, Inte
   private static boolean isOptionSchemaWithNonNullFirstOption(Schema schema) {
     return AvroSchemaUtil.isOptionSchema(schema) && schema.getTypes().get(0).getType() != Schema.Type.NULL;
   }
+
+  // For primitive types the visitResult will be null, so we reuse the primitive types from the original
+  // schema; for nested types we use the visitResult, because it carries content from the previous
+  // recursive calls.
+  private static Schema copyUnion(Schema record, List<Schema> visitResults) {
+    List<Schema> alts = Lists.newArrayListWithExpectedSize(visitResults.size());
+    for (int i = 0; i < visitResults.size(); i++) {
+      if (visitResults.get(i) == null) {
+        alts.add(record.getTypes().get(i));
+      } else {
+        alts.add(visitResults.get(i));
+      }
+    }
+    return Schema.createUnion(alts);
+  }
 }
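Why copyUnion must tolerate null visit results: this visitor returns null for branches it did not rewrite (primitives), and only nested branches come back pruned. A hedged demo of the merge rule; the class, schemas, and projection are illustrative:

import java.util.Arrays;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class CopyUnionDemo {
  public static void main(String[] args) {
    // Original union in the file: [null, int, record r1(a, b)].
    Schema r1 = SchemaBuilder.record("r1").fields()
        .requiredInt("a").requiredInt("b").endRecord();
    Schema original = Schema.createUnion(Arrays.asList(
        Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT), r1));

    // What the visitor would hand back if only r1.a were projected:
    // primitives visit to null, nested types come back pruned.
    Schema prunedR1 = SchemaBuilder.record("r1").fields()
        .requiredInt("a").endRecord();
    Schema[] visitResults = {null, null, prunedR1};

    // The merge rule: reuse the original branch wherever the result is null.
    Schema[] merged = new Schema[visitResults.length];
    for (int i = 0; i < visitResults.length; i++) {
      merged[i] = visitResults[i] == null ? original.getTypes().get(i) : visitResults[i];
    }
    System.out.println(Schema.createUnion(Arrays.asList(merged)));
    // union [null, int, r1(a)] -- pruning preserved, primitives untouched
  }
}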

spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ public ValueReader<?> union(Type expected, Schema union, List<ValueReader<?>> op
     if (AvroSchemaUtil.isOptionSchema(union)) {
       return ValueReaders.union(options);
     } else {
-      return SparkValueReaders.union(options);
+      return SparkValueReaders.union(union, options);
     }
   }
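True option schemas keep the old code path; only complex unions take the new struct-producing reader, which now also receives the union schema itself. A sketch of the dispatch under the assumption that isOptionSchema means "exactly two branches, one of them null" (our helper below, not the Iceberg method):

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class UnionDispatchSketch {
  // Assumed semantics of AvroSchemaUtil.isOptionSchema.
  static boolean isOption(Schema union) {
    return union.getTypes().size() == 2
        && union.getTypes().stream().anyMatch(s -> s.getType() == Schema.Type.NULL);
  }

  public static void main(String[] args) {
    Schema option = SchemaBuilder.unionOf().nullType().and().stringType().endUnion();
    Schema complex = SchemaBuilder.unionOf()
        .nullType().and().intType().and().stringType().endUnion();
    System.out.println(isOption(option));   // true  -> plain ValueReaders.union
    System.out.println(isOption(complex));  // false -> struct-producing UnionReader
  }
}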

spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java

Lines changed: 28 additions & 7 deletions
@@ -27,6 +27,8 @@
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import org.apache.avro.Schema;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.util.Utf8;
 import org.apache.iceberg.avro.ValueReader;
@@ -81,8 +83,8 @@ static ValueReader<InternalRow> struct(List<ValueReader<?>> readers, Types.Struc
     return new StructReader(readers, struct, idToConstant);
   }

-  static ValueReader<InternalRow> union(List<ValueReader<?>> readers) {
-    return new UnionReader(readers);
+  static ValueReader<InternalRow> union(Schema schema, List<ValueReader<?>> readers) {
+    return new UnionReader(schema, readers);
   }

   private static class StringReader implements ValueReader<UTF8String> {
@@ -291,9 +293,11 @@ protected void set(InternalRow struct, int pos, Object value) {
   }

   static class UnionReader implements ValueReader<InternalRow> {
+    private final Schema schema;
     private final ValueReader[] readers;

-    private UnionReader(List<ValueReader<?>> readers) {
+    private UnionReader(Schema schema, List<ValueReader<?>> readers) {
+      this.schema = schema;
       this.readers = new ValueReader[readers.size()];
       for (int i = 0; i < this.readers.length; i += 1) {
         this.readers[i] = readers.get(i);
@@ -302,14 +306,31 @@ private UnionReader(List<ValueReader<?>> readers) {

     @Override
     public InternalRow read(Decoder decoder, Object reuse) throws IOException {
-      InternalRow struct = new GenericInternalRow(readers.length);
+      // first we need to filter out the NULL alternative if it exists in the union schema
+      int nullIndex = -1;
+      List<Schema> alts = schema.getTypes();
+      for (int i = 0; i < alts.size(); i++) {
+        Schema alt = alts.get(i);
+        if (Objects.equals(alt.getType(), Schema.Type.NULL)) {
+          nullIndex = i;
+          break;
+        }
+      }
+      InternalRow struct = new GenericInternalRow(nullIndex >= 0 ? alts.size() - 1 : alts.size());
+      for (int i = 0; i < struct.numFields(); i += 1) {
+        struct.setNullAt(i);
+      }
+
       int index = decoder.readIndex();
       Object value = this.readers[index].read(decoder, reuse);

-      for (int i = 0; i < readers.length; i += 1) {
-        struct.setNullAt(i);
+      if (nullIndex < 0) {
+        struct.update(index, value);
+      } else if (index < nullIndex) {
+        struct.update(index, value);
+      } else if (index > nullIndex) {
+        struct.update(index - 1, value);
       }
-      struct.update(index, value);

       return struct;
     }
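The subtle part of read() is the position math: the struct has no slot for the null branch, so branch indexes above nullIndex shift down by one. A worked sketch of that mapping; the helper and the example union are illustrative:

public class UnionIndexSketch {
  // Mirrors UnionReader's position math for a union like [int, null, string],
  // which reads into struct<tag_0: int, tag_1: string>.
  static int structPosition(int branchIndex, int nullIndex) {
    if (nullIndex < 0 || branchIndex < nullIndex) {
      return branchIndex;          // branches before the null keep their index
    } else if (branchIndex > nullIndex) {
      return branchIndex - 1;      // branches after the null shift down by one
    }
    return -1;                     // the null branch sets no field at all
  }

  public static void main(String[] args) {
    int nullIndex = 1;  // position of null in [int, null, string]
    System.out.println(structPosition(0, nullIndex));  // 0  -> tag_0 (int)
    System.out.println(structPosition(1, nullIndex));  // -1 -> row stays all-null
    System.out.println(structPosition(2, nullIndex));  // 1  -> tag_1 (string)
  }
}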

spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java

Lines changed: 130 additions & 6 deletions
@@ -21,6 +21,7 @@

 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.avro.SchemaBuilder;
 import org.apache.avro.file.DataFileWriter;
@@ -59,7 +60,7 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .endRecord();

     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);

@@ -80,6 +81,14 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(2, rows.get(0).getStruct(0, 2).numFields());
+      Assert.assertTrue(rows.get(0).getStruct(0, 2).isNullAt(0));
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+
+      Assert.assertEquals(2, rows.get(1).getStruct(0, 2).numFields());
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(1).getStruct(0, 2).isNullAt(1));
     }
   }

@@ -96,13 +105,15 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
         .and()
         .stringType()
         .endUnion()
-        .noDefault()
+        .nullDefault()
         .endRecord();

     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);
+    GenericData.Record unionRecord3 = new GenericData.Record(avroSchema);
+    unionRecord3.put("unionCol", null);

     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -111,6 +122,7 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
       writer.create(avroSchema, testFile);
       writer.append(unionRecord1);
       writer.append(unionRecord2);
+      writer.append(unionRecord3);
     }

     Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
@@ -121,25 +133,78 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(1));
     }
   }

   @Test
-  public void writeAndValidateSingleComponentUnion() throws IOException {
+  public void writeAndValidateSingleTypeUnion() throws IOException {
     org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
         .fields()
         .name("unionCol")
         .type()
         .unionOf()
+        .nullType()
+        .and()
         .intType()
         .endUnion()
+        .nullDefault()
+        .endRecord();
+
+    GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
+    unionRecord1.put("unionCol", 0);
+    GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
+    unionRecord2.put("unionCol", 1);
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(unionRecord1);
+      writer.append(unionRecord2);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(0, rows.get(0).getInt(0));
+      Assert.assertEquals(1, rows.get(1).getInt(0));
+    }
+  }
+
+  @Test
+  public void testDeeplyNestedUnionSchema1() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .nullType()
+        .and()
+        .intType()
+        .and()
+        .stringType()
+        .endUnion()
         .noDefault()
         .endRecord();

     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", 1);
+    unionRecord1.put("col1", Arrays.asList("foo", 1));
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
-    unionRecord2.put("unionCol", 2);
+    unionRecord2.put("col1", Arrays.asList(2, "bar"));

     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -158,6 +223,65 @@ public void writeAndValidateSingleComponentUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      // make sure it reads correctly structured nested data, based on the union-to-struct transformation
+      Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 2).getString(1));
     }
   }
+
+  @Test
+  public void testDeeplyNestedUnionSchema2() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .record("r1")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .and()
+        .record("r2")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .endUnion()
+        .noDefault()
+        .endRecord();
+
+    GenericData.Record outer = new GenericData.Record(avroSchema);
+    GenericData.Record inner = new GenericData.Record(avroSchema.getFields().get(0).schema()
+        .getElementType().getTypes().get(0));
+
+    inner.put("id", 1);
+    outer.put("col1", Arrays.asList(inner));
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(outer);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      // make sure it reads correctly structured nested data, based on the union-to-struct transformation
+      Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 2).getStruct(0, 1).getInt(0));
+    }
+  }
 }
