|
19 | 19 |
|
20 | 20 | package org.apache.iceberg.parquet; |
21 | 21 |
|
| 22 | +import java.io.IOException; |
22 | 23 | import java.util.List; |
23 | | -import java.util.Map; |
24 | | -import java.util.Set; |
25 | 24 | import org.apache.iceberg.FileFormat; |
26 | | -import org.apache.iceberg.Schema; |
| 25 | +import org.apache.iceberg.TestMergingMetrics; |
27 | 26 | import org.apache.iceberg.data.GenericAppenderFactory; |
28 | | -import org.apache.iceberg.data.GenericRecord; |
29 | | -import org.apache.iceberg.data.RandomGenericData; |
30 | 27 | import org.apache.iceberg.data.Record; |
31 | 28 | import org.apache.iceberg.io.FileAppender; |
32 | | -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; |
33 | | -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; |
34 | | -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; |
35 | | -import org.apache.iceberg.types.Types; |
36 | | -import org.junit.Assert; |
37 | | -import org.junit.Rule; |
38 | | -import org.junit.Test; |
39 | | -import org.junit.rules.TemporaryFolder; |
40 | 29 |
|
41 | | -import static org.apache.iceberg.types.Types.NestedField.optional; |
42 | | -import static org.apache.iceberg.types.Types.NestedField.required; |
43 | | - |
44 | | -public class TestParquetMergingMetrics { |
45 | | - |
46 | | - // all supported fields, except for UUID which is on deprecation path: see https://github.com/apache/iceberg/pull/1611 |
47 | | - private static final Types.NestedField ID_FIELD = required(1, "id", Types.IntegerType.get()); |
48 | | - private static final Types.NestedField DATA_FIELD = optional(2, "data", Types.StringType.get()); |
49 | | - private static final Types.NestedField FLOAT_FIELD = required(3, "float", Types.FloatType.get()); |
50 | | - private static final Types.NestedField DOUBLE_FIELD = optional(4, "double", Types.DoubleType.get()); |
51 | | - private static final Types.NestedField DECIMAL_FIELD = optional(5, "decimal", Types.DecimalType.of(5, 3)); |
52 | | - private static final Types.NestedField FIXED_FIELD = optional(7, "fixed", Types.FixedType.ofLength(4)); |
53 | | - private static final Types.NestedField BINARY_FIELD = optional(8, "binary", Types.BinaryType.get()); |
54 | | - private static final Types.NestedField FLOAT_LIST = optional(9, "floatlist", |
55 | | - Types.ListType.ofRequired(10, Types.FloatType.get())); |
56 | | - private static final Types.NestedField LONG_FIELD = optional(11, "long", Types.LongType.get()); |
57 | | - |
58 | | - private static final Types.NestedField MAP_FIELD_1 = optional(17, "map1", |
59 | | - Types.MapType.ofOptional(18, 19, Types.FloatType.get(), Types.StringType.get()) |
60 | | - ); |
61 | | - private static final Types.NestedField MAP_FIELD_2 = optional(20, "map2", |
62 | | - Types.MapType.ofOptional(21, 22, Types.IntegerType.get(), Types.DoubleType.get()) |
63 | | - ); |
64 | | - private static final Types.NestedField STRUCT_FIELD = optional(23, "structField", Types.StructType.of( |
65 | | - required(24, "booleanField", Types.BooleanType.get()), |
66 | | - optional(25, "date", Types.DateType.get()), |
67 | | - optional(26, "time", Types.TimeType.get()), |
68 | | - optional(27, "timestamp", Types.TimestampType.withZone()), |
69 | | - optional(28, "timestampWithoutZone", Types.TimestampType.withoutZone()) |
70 | | - )); |
71 | | - |
72 | | - private static final Set<Integer> IDS_WITH_ZERO_NAN_COUNT = ImmutableSet.of(1, 2, 5, 7, 8, 11, 24, 25, 26, 27, |
73 | | - 28); |
74 | | - private static final Map<Types.NestedField, Integer> FIELDS_WITH_NAN_COUNT_TO_ID = ImmutableMap.of( |
75 | | - FLOAT_FIELD, 3, DOUBLE_FIELD, 4, FLOAT_LIST, 10, MAP_FIELD_1, 18, MAP_FIELD_2, 22 |
76 | | - ); |
77 | | - |
78 | | - // create a schema with all supported fields |
79 | | - private static final Schema SCHEMA = new Schema( |
80 | | - ID_FIELD, |
81 | | - DATA_FIELD, |
82 | | - FLOAT_FIELD, |
83 | | - DOUBLE_FIELD, |
84 | | - DECIMAL_FIELD, |
85 | | - FIXED_FIELD, |
86 | | - BINARY_FIELD, |
87 | | - FLOAT_LIST, |
88 | | - LONG_FIELD, |
89 | | - MAP_FIELD_1, |
90 | | - MAP_FIELD_2, |
91 | | - STRUCT_FIELD |
92 | | - ); |
93 | | - |
94 | | - @Rule |
95 | | - public TemporaryFolder temp = new TemporaryFolder(); |
96 | | - |
97 | | - @Test |
98 | | - public void verifySingleRecordMetric() throws Exception { |
99 | | - Record record = GenericRecord.create(SCHEMA); |
100 | | - record.setField("id", 3); |
101 | | - record.setField("float", Float.NaN); // FLOAT_FIELD - 1 |
102 | | - record.setField("double", Double.NaN); // DOUBLE_FIELD - 1 |
103 | | - record.setField("floatlist", ImmutableList.of(3.3F, 2.8F, Float.NaN, -25.1F, Float.NaN)); // FLOAT_LIST - 2 |
104 | | - record.setField("map1", ImmutableMap.of(Float.NaN, "a", 0F, "b")); // MAP_FIELD_1 - 1 |
105 | | - record.setField("map2", ImmutableMap.of( |
106 | | - 0, 0D, 1, Double.NaN, 2, 2D, 3, Double.NaN, 4, Double.NaN)); // MAP_FIELD_2 - 3 |
| 30 | +public class TestParquetMergingMetrics extends TestMergingMetrics<Record> { |
107 | 31 |
|
| 32 | + @Override |
| 33 | + protected FileAppender<Record> writeAndGetAppender(List<Record> records) throws IOException { |
108 | 34 | FileAppender<Record> appender = new GenericAppenderFactory(SCHEMA).newAppender( |
109 | 35 | org.apache.iceberg.Files.localOutput(temp.newFile()), FileFormat.PARQUET); |
110 | 36 | try (FileAppender<Record> fileAppender = appender) { |
111 | | - fileAppender.add(record); |
| 37 | + records.forEach(fileAppender::add); |
112 | 38 | } |
113 | | - Map<Integer, Long> nanValueCount = appender.metrics().nanValueCounts(); |
114 | | - |
115 | | - assertNaNCountMatch(1L, nanValueCount, FLOAT_FIELD); |
116 | | - assertNaNCountMatch(1L, nanValueCount, DOUBLE_FIELD); |
117 | | - assertNaNCountMatch(2L, nanValueCount, FLOAT_LIST); |
118 | | - assertNaNCountMatch(1L, nanValueCount, MAP_FIELD_1); |
119 | | - assertNaNCountMatch(3L, nanValueCount, MAP_FIELD_2); |
120 | | - } |
121 | | - |
122 | | - private void assertNaNCountMatch(Long expected, Map<Integer, Long> nanValueCount, Types.NestedField field) { |
123 | | - Assert.assertEquals( |
124 | | - String.format("NaN count for field %s does not match expected", field.name()), |
125 | | - expected, nanValueCount.get(FIELDS_WITH_NAN_COUNT_TO_ID.get(field))); |
126 | | - } |
127 | | - |
128 | | - @Test |
129 | | - public void verifyRandomlyGeneratedRecordsMetric() throws Exception { |
130 | | - List<Record> recordList = RandomGenericData.generate(SCHEMA, 50, 250L); |
131 | | - |
132 | | - FileAppender<Record> appender = new GenericAppenderFactory(SCHEMA).newAppender( |
133 | | - org.apache.iceberg.Files.localOutput(temp.newFile()), FileFormat.PARQUET); |
134 | | - try (FileAppender<Record> fileAppender = appender) { |
135 | | - fileAppender.addAll(recordList); |
136 | | - } |
137 | | - Map<Integer, Long> nanValueCount = appender.metrics().nanValueCounts(); |
138 | | - |
139 | | - IDS_WITH_ZERO_NAN_COUNT.forEach(i -> Assert.assertEquals(String.format("Field %s " + |
140 | | - "shouldn't have non-zero nanValueCount", i), Long.valueOf(0), nanValueCount.get(i))); |
141 | | - |
142 | | - FIELDS_WITH_NAN_COUNT_TO_ID.forEach((key, value) -> Assert.assertEquals( |
143 | | - String.format("NaN count for field %s does not match expected", key.name()), |
144 | | - getExpectedNaNCount(recordList, key), |
145 | | - nanValueCount.get(value))); |
146 | | - } |
147 | | - |
148 | | - private Long getExpectedNaNCount(List<Record> expectedRecords, Types.NestedField field) { |
149 | | - return expectedRecords.stream() |
150 | | - .mapToLong(e -> { |
151 | | - Object value = e.getField(field.name()); |
152 | | - if (value == null) { |
153 | | - return 0; |
154 | | - } |
155 | | - if (FLOAT_FIELD.equals(field)) { |
156 | | - return value.equals(Float.NaN) ? 1 : 0; |
157 | | - } else if (DOUBLE_FIELD.equals(field)) { |
158 | | - return value.equals(Double.NaN) ? 1 : 0; |
159 | | - } else if (FLOAT_LIST.equals(field)) { |
160 | | - return ((List<Float>) value).stream() |
161 | | - .filter(val -> val != null && val.equals(Float.NaN)) |
162 | | - .count(); |
163 | | - } else if (MAP_FIELD_1.equals(field)) { |
164 | | - return ((Map<Float, ?>) value).keySet().stream() |
165 | | - .filter(key -> key.equals(Float.NaN)) |
166 | | - .count(); |
167 | | - } else if (MAP_FIELD_2.equals(field)) { |
168 | | - return ((Map<?, Double>) value).values().stream() |
169 | | - .filter(val -> val != null && val.equals(Double.NaN)) |
170 | | - .count(); |
171 | | - } else { |
172 | | - throw new RuntimeException("unknown field name for getting expected NaN count: " + field.name()); |
173 | | - } |
174 | | - }).sum(); |
| 39 | + return appender; |
175 | 40 | } |
176 | 41 | } |
0 commit comments