From d903c5b48655c2a58335e808772864bbf1df8dbc Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Sun, 24 Aug 2025 16:06:00 -0700 Subject: [PATCH 1/5] Spark shredded variant implementation --- .../iceberg/parquet/ParquetVariantUtil.java | 4 +- .../apache/iceberg/parquet/ParquetWriter.java | 41 +- .../parquet/WriterLazyInitializable.java | 87 +++++ .../iceberg/spark/SparkSQLProperties.java | 10 + .../apache/iceberg/spark/SparkWriteConf.java | 10 + .../iceberg/spark/SparkWriteOptions.java | 3 + .../spark/source/SchemaInferenceVisitor.java | 198 ++++++++++ .../spark/source/SparkFileWriterFactory.java | 12 +- ...parkParquetWriterWithVariantShredding.java | 181 +++++++++ .../iceberg/spark/TestSparkWriteConf.java | 7 + .../spark/variant/TestVariantShredding.java | 363 ++++++++++++++++++ 11 files changed, 912 insertions(+), 4 deletions(-) create mode 100644 parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java create mode 100644 spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java index ac418a1127bd..d94760773e51 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java @@ -57,7 +57,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; -class ParquetVariantUtil { +public class ParquetVariantUtil { private ParquetVariantUtil() {} /** @@ -212,7 +212,7 @@ static int scale(PrimitiveType primitive) { * @param value a variant value * @return a Parquet schema that can fully shred the value */ - static Type toParquetSchema(VariantValue value) { + public static Type toParquetSchema(VariantValue value) { return VariantVisitor.visit(value, new ParquetSchemaProducer()); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index e31df97c2bad..f359a99d72db 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -51,7 +51,7 @@ class ParquetWriter implements FileAppender, Closeable { private final Map metadata; private final ParquetProperties props; private final CodecFactory.BytesCompressor compressor; - private final MessageType parquetSchema; + private MessageType parquetSchema; private final ParquetValueWriter model; private final MetricsConfig metricsConfig; private final int columnIndexTruncateLength; @@ -134,6 +134,30 @@ private void ensureWriterInitialized() { @Override public void add(T value) { + if (model instanceof WriterLazyInitializable) { + WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (lazy.needsInitialization()) { + model.write(0, value); + recordCount += 1; + + if (!lazy.needsInitialization()) { + WriterLazyInitializable.InitializationResult result = + lazy.initialize(props, compressor, rowGroupOrdinal); + this.parquetSchema = result.getSchema(); + this.pageStore = result.getPageStore(); + this.writeStore = result.getWriteStore(); + + // Re-initialize the file writer with the new schema + 
ensureWriterInitialized(); + + // Buffered rows were already written with endRecord() calls + // in the lazy writer's initialization, so we don't call endRecord() here + checkSize(); + } + return; + } + } + recordCount += 1; model.write(0, value); writeStore.endRecord(); @@ -255,6 +279,21 @@ private void startRowGroup() { public void close() throws IOException { if (!closed) { this.closed = true; + + // Force initialization if lazy writer still has buffered data + if (model instanceof WriterLazyInitializable) { + WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (lazy.needsInitialization()) { + WriterLazyInitializable.InitializationResult result = + lazy.initialize(props, compressor, rowGroupOrdinal); + this.parquetSchema = result.getSchema(); + this.pageStore = result.getPageStore(); + this.writeStore = result.getWriteStore(); + + ensureWriterInitialized(); + } + } + flushRowGroup(true); writeStore.close(); if (writer != null) { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java new file mode 100644 index 000000000000..9c5913d7bd9b --- /dev/null +++ b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import org.apache.parquet.column.ColumnWriteStore; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.compression.CompressionCodecFactory; +import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; +import org.apache.parquet.schema.MessageType; + +/** + * Interface for ParquetValueWriters that need to defer initialization until they can analyze the + * data. This is useful for scenarios like variant shredding where the schema needs to be inferred + * from the actual data before creating the writer structures. + * + *
<p>
Writers implementing this interface can buffer initial rows and perform schema inference + * before committing to a final Parquet schema. + */ +public interface WriterLazyInitializable { + /** + * Result returned by lazy initialization of a ParquetValueWriter required by ParquetWriter. + * Contains the finalized schema and write stores after schema inference or other initialization + * logic. + */ + class InitializationResult { + private final MessageType schema; + private final ColumnChunkPageWriteStore pageStore; + private final ColumnWriteStore writeStore; + + public InitializationResult( + MessageType schema, ColumnChunkPageWriteStore pageStore, ColumnWriteStore writeStore) { + this.schema = schema; + this.pageStore = pageStore; + this.writeStore = writeStore; + } + + public MessageType getSchema() { + return schema; + } + + public ColumnChunkPageWriteStore getPageStore() { + return pageStore; + } + + public ColumnWriteStore getWriteStore() { + return writeStore; + } + } + + /** + * Checks if this writer still needs initialization. This will return true until the writer has + * buffered enough data to perform initialization (e.g., schema inference). + * + * @return true if initialization is still needed, false if already initialized + */ + boolean needsInitialization(); + + /** + * Performs initialization and returns the result containing updated schema and write stores. This + * method should only be called when {@link #needsInitialization()} returns true. + * + * @param props Parquet properties needed for creating write stores + * @param compressor Bytes compressor for compression + * @param rowGroupOrdinal The ordinal number of the current row group + * @return InitializationResult containing the finalized schema and write stores + */ + InitializationResult initialize( + ParquetProperties props, + CompressionCodecFactory.BytesInputCompressor compressor, + int rowGroupOrdinal); +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 81139969f746..b12606d23948 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,4 +109,14 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + public static final boolean SHRED_VARIANTS_DEFAULT = true; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant.inference.buffer-size"; + public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 96131e0e56dd..4baf5585b220 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -509,6 +509,7 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { 
writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); break; case AVRO: @@ -729,4 +730,13 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) + .parse(); + } } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 33db70bae587..f8fb41696f76 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,4 +85,7 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java new file mode 100644 index 000000000000..0eed88a8eb66 --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; +import org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; +import org.apache.parquet.schema.Types.MessageTypeBuilder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VariantType; +import org.apache.spark.unsafe.types.VariantVal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A visitor that infers variant shredding schemas by analyzing buffered rows of data. This visitor + * can be plugged into ParquetWithSparkSchemaVisitor.visit() to create a shredded MessageType based + * on actual variant data content. + * + *
<p>
The visitor uses the field names tracked during traversal to look up the correct field index + * in the Spark schema, allowing it to access the corresponding value in the rows for schema + * inference. It searches through all buffered rows to find the first non-null variant value for + * schema inference. + */ +public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { + private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); + + private final List bufferedRows; + private final StructType sparkSchema; + + public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { + this.bufferedRows = bufferedRows; + this.sparkSchema = sparkSchema; + } + + @Override + public Type message(StructType sStruct, MessageType message, List fields) { + MessageTypeBuilder builder = Types.buildMessage(); + + for (Type field : fields) { + if (field != null) { + builder.addField(field); + } + } + + return builder.named(message.getName()); + } + + @Override + public Type struct(StructType sStruct, GroupType struct, List fields) { + Types.GroupBuilder builder = Types.buildGroup(struct.getRepetition()); + + if (struct.getId() != null) { + builder = builder.id(struct.getId().intValue()); + } + + for (Type field : fields) { + if (field != null) { + builder = builder.addField(field); + } + } + + return builder.named(struct.getName()); + } + + @Override + public Type primitive(DataType sPrimitive, PrimitiveType primitive) { + return primitive; + } + + @Override + public Type list(ArrayType sArray, GroupType array, Type element) { + Types.GroupBuilder builder = + Types.buildGroup(array.getRepetition()).as(LogicalTypeAnnotation.listType()); + + if (array.getId() != null) { + builder = builder.id(array.getId().intValue()); + } + + if (element != null) { + builder = builder.addField(element); + } + + return builder.named(array.getName()); + } + + @Override + public Type map(MapType sMap, GroupType map, Type key, Type value) { + Types.GroupBuilder builder = + Types.buildGroup(map.getRepetition()).as(LogicalTypeAnnotation.mapType()); + + if (map.getId() != null) { + builder = builder.id(map.getId().intValue()); + } + + if (key != null) { + builder = builder.addField(key); + } + if (value != null) { + builder = builder.addField(value); + } + + return builder.named(map.getName()); + } + + @Override + public Type variant(VariantType sVariant, GroupType variant) { + int variantFieldIndex = getFieldIndex(currentPath()); + + // Find the first non-null variant value from buffered rows for schema inference + // This ensures we can infer a schema even if the first rows has null variant values + if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + + Type shreddedType = ParquetVariantUtil.toParquetSchema(variantValue); + if (shreddedType != null) { + return Types.buildGroup(variant.getRepetition()) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .id(variant.getId().intValue()) + .required(BINARY) + .named("metadata") + .optional(BINARY) + .named("value") + .addField(shreddedType) + .named(variant.getName()); + } + } + } + } + } + + return variant; + } + + 
private int getFieldIndex(String[] path) { + if (path == null || path.length == 0) { + return -1; + } + + // TODO: For now, we only support top-level variant fields. To support nested variants, we would + // need to navigate the struct hierarchy + if (path.length == 1) { + String fieldName = path[0]; + for (int i = 0; i < sparkSchema.fields().length; i++) { + if (sparkSchema.fields()[i].name().equals(fieldName)) { + return i; + } + } + } else { + LOG.warn( + "Nested variant fields are not yet supported for schema inference. Path: {}", + String.join(".", path)); + } + + return -1; + } +} diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index a93db17e4a0f..8c74c65fc1b4 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -151,7 +151,17 @@ protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { @Override protected void configureDataWrite(Parquet.DataWriteBuilder builder) { - builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + if (SparkParquetWriterWithVariantShredding.shouldUseVariantShredding( + writeProperties, dataSchema())) { + builder.createWriterFunc( + msgType -> + new SparkParquetWriterWithVariantShredding( + dataSparkType(), msgType, writeProperties)); + } else { + builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + } + builder.setAll(writeProperties); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java new file mode 100644 index 000000000000..8f1a61d60c6f --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.source; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.Schema; +import org.apache.iceberg.parquet.ParquetValueWriter; +import org.apache.iceberg.parquet.TripleWriter; +import org.apache.iceberg.parquet.WriterLazyInitializable; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; +import org.apache.iceberg.spark.data.SparkParquetWriters; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnWriteStore; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.compression.CompressionCodecFactory; +import org.apache.parquet.hadoop.ColumnChunkPageWriteStore; +import org.apache.parquet.schema.MessageType; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; + +/** + * A Parquet output writer that performs variant shredding with schema inference. This is similar to + * Spark's ParquetOutputWriterWithVariantShredding but adapted for Iceberg. + * + *
<p>
The writer works in two phases: 1. Schema inference phase: Buffers initial rows and analyzes + * variant data to infer schemas 2. Writing phase: Creates the actual Parquet writer with inferred + * schemas and writes all data + */ +public class SparkParquetWriterWithVariantShredding + implements ParquetValueWriter, WriterLazyInitializable { + private final StructType sparkSchema; + private final MessageType parquetType; + + private final List bufferedRows; + private ParquetValueWriter actualWriter; + private boolean writerInitialized = false; + private final int bufferSize; + + private static class BufferedRow { + private final int repetitionLevel; + private final InternalRow row; + + BufferedRow(int repetitionLevel, InternalRow row) { + this.repetitionLevel = repetitionLevel; + this.row = row; + } + } + + public SparkParquetWriterWithVariantShredding( + StructType sparkSchema, MessageType parquetType, Map properties) { + this.sparkSchema = sparkSchema; + this.parquetType = parquetType; + + this.bufferSize = + Integer.parseInt( + properties.getOrDefault( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, + String.valueOf(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT))); + this.bufferedRows = Lists.newArrayList(); + } + + @Override + public void write(int repetitionLevel, InternalRow row) { + if (!writerInitialized) { + bufferedRows.add( + new BufferedRow( + repetitionLevel, row.copy())); /* Make a copy of the object since row gets reused */ + + if (bufferedRows.size() >= bufferSize) { + writerInitialized = true; + } + } else { + actualWriter.write(repetitionLevel, row); + } + } + + @Override + public List> columns() { + if (actualWriter != null) { + return actualWriter.columns(); + } + return Collections.emptyList(); + } + + @Override + public void setColumnStore(ColumnWriteStore columnStore) { + // Ignored for lazy initialization - will be set on actualWriter after initialization + } + + @Override + public Stream> metrics() { + if (actualWriter != null) { + return actualWriter.metrics(); + } + return Stream.empty(); + } + + @Override + public boolean needsInitialization() { + return !writerInitialized; + } + + @Override + public InitializationResult initialize( + ParquetProperties props, + CompressionCodecFactory.BytesInputCompressor compressor, + int rowGroupOrdinal) { + if (bufferedRows.isEmpty()) { + throw new IllegalStateException("No buffered rows available for schema inference"); + } + + List rows = Lists.newLinkedList(); + for (BufferedRow bufferedRow : bufferedRows) { + rows.add(bufferedRow.row); + } + + MessageType shreddedSchema = + (MessageType) + ParquetWithSparkSchemaVisitor.visit( + sparkSchema, parquetType, new SchemaInferenceVisitor(rows, sparkSchema)); + + actualWriter = SparkParquetWriters.buildWriter(sparkSchema, shreddedSchema); + + ColumnChunkPageWriteStore pageStore = + new ColumnChunkPageWriteStore( + compressor, + shreddedSchema, + props.getAllocator(), + 64, + ParquetProperties.DEFAULT_PAGE_WRITE_CHECKSUM_ENABLED, + null, + rowGroupOrdinal); + + ColumnWriteStore columnStore = props.newColumnWriteStore(shreddedSchema, pageStore, pageStore); + + actualWriter.setColumnStore(columnStore); + + for (BufferedRow bufferedRow : bufferedRows) { + actualWriter.write(bufferedRow.repetitionLevel, bufferedRow.row); + columnStore.endRecord(); + } + + bufferedRows.clear(); + writerInitialized = true; + + return new InitializationResult(shreddedSchema, pageStore, columnStore); + } + + public static boolean shouldUseVariantShredding(Map properties, Schema schema) { + 
boolean shreddingEnabled = + properties.containsKey(SparkSQLProperties.SHRED_VARIANTS) + && Boolean.parseBoolean(properties.get(SparkSQLProperties.SHRED_VARIANTS)); + + boolean hasVariantFields = + schema.columns().stream().anyMatch(field -> field.type() instanceof Types.VariantType); + + return shreddingEnabled && hasVariantFields; + } +} diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 61aacfa4589d..d97579f29e86 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,6 +41,7 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -339,6 +340,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -460,6 +463,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -531,6 +536,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java new file mode 100644 index 000000000000..d82a241ba148 --- /dev/null +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.variant; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.CatalogTestBase; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.Variant; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; + +public class TestVariantShredding extends CatalogTestBase { + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get())); + + private static final Schema SCHEMA2 = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.optional(2, "address", Types.VariantType.get()), + Types.NestedField.optional(3, "metadata", Types.VariantType.get())); + + @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}") + protected static Object[][] parameters() { + return new Object[][] { + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties() + }, + }; + } + + @BeforeAll + public static void startMetastoreAndSpark() { + // First call parent to initialize metastore and spark with local[2] + CatalogTestBase.startMetastoreAndSpark(); + + // Now stop and recreate spark with local[1] to write all rows to a single file + if (spark != null) { + spark.stop(); + } + + spark = + SparkSession.builder() + .master("local[1]") // Use one thread to write the rows to a single parquet file + .config("spark.driver.host", InetAddress.getLoopbackAddress().getHostAddress()) + .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic") + .config("spark.hadoop." 
+ METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname)) + .config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true") + .enableHiveSupport() + .getOrCreate(); + + sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); + } + + @BeforeEach + public void before() { + super.before(); + validationCatalog.createTable( + tableIdent, SCHEMA, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + } + + @AfterEach + public void after() { + validationCatalog.dropTable(tableIdent, true); + } + + @TestTemplate + public void testVariantShreddingWrite() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + String values = + "(1, parse_json('{\"name\": \"Joe\", \"streets\": [\"Apt #3\", \"1234 Ave\"], \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType streets = + field( + "streets", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); + GroupType address = variant("address", 2, objectFields(name, streets, zip)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithNullFirstRow() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = "(1, null), (2, parse_json('{\"city\": \"Seattle\", \"state\": \"WA\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType state = + field( + "state", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, objectFields(city, state)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithTwoVariantColumns() throws IOException { + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}'), parse_json('{\"type\": \"home\", \"verified\": true}')), " + + "(2, null, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); + GroupType address = variant("address", 2, objectFields(city, zip)); + + GroupType type = + field( + "type", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType verified = + field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType metadata = variant("metadata", 3, 
objectFields(type, verified)); + + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingWithTwoVariantColumnsOneNull() throws IOException { + validationCatalog.dropTable(tableIdent, true); + validationCatalog.createTable( + tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); + + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // First row: address is null, metadata has value + // Second row: address has value, metadata is null + String values = + "(1, null, parse_json('{\"label\": \"primary\"}'))," + + " (2, parse_json('{\"street\": \"Main St\"}'), null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType street = + field( + "street", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType address = variant("address", 2, objectFields(street)); + + GroupType label = + field( + "label", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType metadata = variant("metadata", 3, objectFields(label)); + + MessageType expectedSchema = parquetSchema(address, metadata); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testVariantShreddingDisabled() throws IOException { + // Test with shredding explicitly disabled + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); + + String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType address = variant("address", 2); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { + try (CloseableIterable tasks = table.newScan().planFiles()) { + assertThat(tasks).isNotEmpty(); + + FileScanTask task = tasks.iterator().next(); + String path = task.file().location(); + + HadoopInputFile inputFile = + HadoopInputFile.fromPath(new org.apache.hadoop.fs.Path(path), new Configuration()); + + try (ParquetFileReader reader = ParquetFileReader.open(inputFile)) { + MessageType actualSchema = reader.getFileMetaData().getSchema(); + assertThat(actualSchema).isEqualTo(expectedSchema); + } + } + } + + private static MessageType parquetSchema(Type... 
variantTypes) { + return org.apache.parquet.schema.Types.buildMessage() + .required(PrimitiveType.PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addFields(variantTypes) + .named("table"); + } + + private static GroupType variant(String name, int fieldId) { + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static GroupType variant(String name, int fieldId, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .required(PrimitiveType.PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { + return org.apache.parquet.schema.Types.optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return org.apache.parquet.schema.Types.optional(primitive).as(annotation).named("typed_value"); + } + + private static GroupType objectFields(GroupType... fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fields) + .named("typed_value"); + } + + private static GroupType field(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static GroupType element(Type shreddedType) { + return field("element", shreddedType); + } + + private static GroupType list(GroupType elementType) { + return org.apache.parquet.schema.Types.optionalList().element(elementType).named("typed_value"); + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } +} From c570ed8582ea49c815e4430bf6a2edb6431695d1 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Sat, 1 Nov 2025 21:26:35 -0700 Subject: [PATCH 2/5] Add heuristics to determine the shredding schema --- .../parquet/ParquetVariantWriters.java | 58 +- .../iceberg/parquet/VariantWriterBuilder.java | 16 +- .../iceberg/spark/SparkSQLProperties.java | 12 + .../apache/iceberg/spark/SparkWriteConf.java | 18 + .../spark/source/SchemaInferenceVisitor.java | 83 +-- ...parkParquetWriterWithVariantShredding.java | 9 +- .../source/VariantShreddingAnalyzer.java | 545 ++++++++++++++++++ .../spark/variant/TestVariantShredding.java | 
396 ++++++++++++- 8 files changed, 1077 insertions(+), 60 deletions(-) create mode 100644 spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java index 9e94b1bbd6cd..42cdee7a1a5c 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java @@ -65,6 +65,11 @@ static ParquetValueWriter primitive( return new PrimitiveWriter<>(writer, Sets.immutableEnumSet(Arrays.asList(types))); } + static ParquetValueWriter decimal( + ParquetValueWriter writer, int expectedScale, PhysicalType... types) { + return new DecimalWriter(writer, expectedScale, Sets.immutableEnumSet(Arrays.asList(types))); + } + @SuppressWarnings("unchecked") static ParquetValueWriter shredded( int valueDefinitionLevel, @@ -253,6 +258,49 @@ public void setColumnStore(ColumnWriteStore columnStore) { } } + /** + * A TypedWriter for decimals that validates scale before writing. + * If the scale doesn't match, it returns false from canWrite() to trigger fallback to value field. + */ + private static class DecimalWriter implements TypedWriter { + private final Set types; + private final ParquetValueWriter writer; + private final int expectedScale; + + private DecimalWriter( + ParquetValueWriter writer, int expectedScale, Set types) { + this.types = types; + this.writer = (ParquetValueWriter) writer; + this.expectedScale = expectedScale; + } + + @Override + public Set types() { + return types; + } + + @Override + public void write(int repetitionLevel, VariantValue value) { + java.math.BigDecimal decimal = (java.math.BigDecimal) value.asPrimitive().get(); + // Validate scale matches before writing + if (decimal.scale() != expectedScale) { + throw new IllegalArgumentException( + "Cannot write decimal with scale " + decimal.scale() + " to schema expecting scale " + expectedScale); + } + writer.write(repetitionLevel, decimal); + } + + @Override + public List> columns() { + return writer.columns(); + } + + @Override + public void setColumnStore(ColumnWriteStore columnStore) { + writer.setColumnStore(columnStore); + } + } + private static class ShreddedVariantWriter implements ParquetValueWriter { private final int valueDefinitionLevel; private final ParquetValueWriter valueWriter; @@ -275,8 +323,14 @@ private ShreddedVariantWriter( @Override public void write(int repetitionLevel, VariantValue value) { if (typedWriter.types().contains(value.type())) { - typedWriter.write(repetitionLevel, value); - writeNull(valueWriter, repetitionLevel, valueDefinitionLevel); + try { + typedWriter.write(repetitionLevel, value); + writeNull(valueWriter, repetitionLevel, valueDefinitionLevel); + } catch (IllegalArgumentException e) { + // Fall back to value field if typed write fails (e.g., decimal scale mismatch) + valueWriter.write(repetitionLevel, value); + writeNull(typedWriter, repetitionLevel, typedDefinitionLevel); + } } else { valueWriter.write(repetitionLevel, value); writeNull(typedWriter, repetitionLevel, typedDefinitionLevel); diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java index a447a102690a..53cf5d9933d6 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java +++ 
b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java @@ -198,27 +198,31 @@ public Optional> visit(StringLogicalTypeAnnotation ignored @Override public Optional> visit(DecimalLogicalTypeAnnotation decimal) { ParquetValueWriter writer; + int scale = decimal.getScale(); switch (desc.getPrimitiveType().getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsFixed( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL16); return Optional.of(writer); case INT64: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsLong( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL8); return Optional.of(writer); case INT32: writer = - ParquetVariantWriters.primitive( + ParquetVariantWriters.decimal( ParquetValueWriters.decimalAsInteger( - desc, decimal.getPrecision(), decimal.getScale()), + desc, decimal.getPrecision(), scale), + scale, PhysicalType.DECIMAL4); return Optional.of(writer); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b12606d23948..e111becad89e 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -119,4 +119,16 @@ private SparkSQLProperties() {} public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; + + // Controls the minimum occurrence threshold for variant fields during shredding + // Fields that appear in fewer than this percentage of rows will be dropped + public static final String VARIANT_MIN_OCCURRENCE_THRESHOLD = + "spark.sql.iceberg.variant.min-occurrence-threshold"; + public static final double VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT = 0.1; // 10% + + // Controls the maximum number of fields to shred in a variant column + // This prevents creating overly wide Parquet schemas + public static final String VARIANT_MAX_SHREDDED_FIELDS = + "spark.sql.iceberg.variant.max-shredded-fields"; + public static final int VARIANT_MAX_SHREDDED_FIELDS_DEFAULT = 300; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 4baf5585b220..34fcd2f1e467 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -510,6 +510,24 @@ private Map dataWriteProperties() { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + + // Add variant shredding configuration properties + if (shredVariants()) { + String variantMaxFields = sessionConf.get(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, null); + if (variantMaxFields != null) { + writeProperties.put(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, variantMaxFields); + } + + String variantMinOccurrence = sessionConf.get(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, null); + if (variantMinOccurrence != 
null) { + writeProperties.put(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, variantMinOccurrence); + } + + String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + if (variantBufferSize != null) { + writeProperties.put(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + } + } break; case AVRO: diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 0eed88a8eb66..c03fc74f00e3 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -23,7 +23,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.List; +import java.util.Map; import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.variants.Variant; import org.apache.iceberg.variants.VariantMetadata; @@ -46,24 +48,33 @@ import org.slf4j.LoggerFactory; /** - * A visitor that infers variant shredding schemas by analyzing buffered rows of data. This visitor - * can be plugged into ParquetWithSparkSchemaVisitor.visit() to create a shredded MessageType based - * on actual variant data content. - * - *
<p>
The visitor uses the field names tracked during traversal to look up the correct field index - * in the Spark schema, allowing it to access the corresponding value in the rows for schema - * inference. It searches through all buffered rows to find the first non-null variant value for - * schema inference. + * A visitor that infers variant shredding schemas by analyzing buffered rows of data. */ public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); private final List bufferedRows; private final StructType sparkSchema; + private final VariantShreddingAnalyzer analyzer; - public SchemaInferenceVisitor(List bufferedRows, StructType sparkSchema) { + public SchemaInferenceVisitor( + List bufferedRows, StructType sparkSchema, Map properties) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; + + double minOccurrenceThreshold = + Double.parseDouble( + properties.getOrDefault( + SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, + String.valueOf(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT))); + + int maxFields = + Integer.parseInt( + properties.getOrDefault( + SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, + String.valueOf(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS_DEFAULT))); + + this.analyzer = new VariantShreddingAnalyzer(minOccurrenceThreshold, maxFields); } @Override @@ -140,33 +151,22 @@ public Type map(MapType sMap, GroupType map, Type key, Type value) { public Type variant(VariantType sVariant, GroupType variant) { int variantFieldIndex = getFieldIndex(currentPath()); - // Find the first non-null variant value from buffered rows for schema inference - // This ensures we can infer a schema even if the first rows has null variant values + // Apply heuristics to determine the shredding schema: + // - Fields must appear in at least the configured percentage of rows + // - Type consistency determines if typed_value is created + // - Maximum field count to avoid overly wide schemas if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { - for (InternalRow row : bufferedRows) { - if (!row.isNullAt(variantFieldIndex)) { - VariantVal variantVal = row.getVariant(variantFieldIndex); - if (variantVal != null) { - VariantValue variantValue = - VariantValue.from( - VariantMetadata.from( - ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), - ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); - - Type shreddedType = ParquetVariantUtil.toParquetSchema(variantValue); - if (shreddedType != null) { - return Types.buildGroup(variant.getRepetition()) - .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) - .id(variant.getId().intValue()) - .required(BINARY) - .named("metadata") - .optional(BINARY) - .named("value") - .addField(shreddedType) - .named(variant.getName()); - } - } - } + Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); + if (shreddedType != null) { + return Types.buildGroup(variant.getRepetition()) + .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) + .id(variant.getId().intValue()) + .required(BINARY) + .named("metadata") + .optional(BINARY) + .named("value") + .addField(shreddedType) + .named(variant.getName()); } } @@ -178,9 +178,9 @@ private int getFieldIndex(String[] path) { return -1; } - // TODO: For now, we only support top-level variant fields. 
To support nested variants, we would - // need to navigate the struct hierarchy + // Support nested variant fields by navigating the struct hierarchy if (path.length == 1) { + // Top-level field - direct lookup String fieldName = path[0]; for (int i = 0; i < sparkSchema.fields().length; i++) { if (sparkSchema.fields()[i].name().equals(fieldName)) { @@ -188,8 +188,15 @@ private int getFieldIndex(String[] path) { } } } else { + // Nested field - navigate through struct hierarchy + // For now, we only support direct struct nesting (not arrays/maps) + LOG.debug( + "Attempting to resolve nested variant field path: {}", String.join(".", path)); + // TODO: Implement full nested field resolution when needed + // This would require tracking the current struct context during traversal + // and maintaining a stack of field indices LOG.warn( - "Nested variant fields are not yet supported for schema inference. Path: {}", + "Multi-level nested variant fields require struct context tracking. Path: {}", String.join(".", path)); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java index 8f1a61d60c6f..6a2ed1e85324 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -41,8 +41,7 @@ import org.apache.spark.sql.types.StructType; /** - * A Parquet output writer that performs variant shredding with schema inference. This is similar to - * Spark's ParquetOutputWriterWithVariantShredding but adapted for Iceberg. + * A Parquet output writer that performs variant shredding with schema inference. * *
<p>
The writer works in two phases: 1. Schema inference phase: Buffers initial rows and analyzes * variant data to infer schemas 2. Writing phase: Creates the actual Parquet writer with inferred @@ -52,6 +51,7 @@ public class SparkParquetWriterWithVariantShredding implements ParquetValueWriter, WriterLazyInitializable { private final StructType sparkSchema; private final MessageType parquetType; + private final Map properties; private final List bufferedRows; private ParquetValueWriter actualWriter; @@ -72,6 +72,7 @@ public SparkParquetWriterWithVariantShredding( StructType sparkSchema, MessageType parquetType, Map properties) { this.sparkSchema = sparkSchema; this.parquetType = parquetType; + this.properties = properties; this.bufferSize = Integer.parseInt( @@ -139,7 +140,9 @@ public InitializationResult initialize( MessageType shreddedSchema = (MessageType) ParquetWithSparkSchemaVisitor.visit( - sparkSchema, parquetType, new SchemaInferenceVisitor(rows, sparkSchema)); + sparkSchema, + parquetType, + new SchemaInferenceVisitor(rows, sparkSchema, properties)); actualWriter = SparkParquetWriters.buildWriter(sparkSchema, shreddedSchema); diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java new file mode 100644 index 000000000000..581043cd802e --- /dev/null +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -0,0 +1,545 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.source; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.iceberg.parquet.ParquetVariantUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.variants.PhysicalType; +import org.apache.iceberg.variants.VariantArray; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantPrimitive; +import org.apache.iceberg.variants.VariantValue; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.unsafe.types.VariantVal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Analyzes variant data across buffered rows to determine an optimal shredding schema. + ** + *

    + *
  • If a field appears consistently with a consistent type → create both {@code value} and + * {@code typed_value} + *
  • If a field appears with inconsistent types → only create {@code value} + *
  • Drop fields that occur in less than the configured threshold of sampled rows + *
  • Cap the maximum fields to shred + *
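+ *
+ * <p>Illustrative sketch only (the field name {@code zip} is hypothetical): a field that always
+ * appears as a small integer would be shredded with the same {@code Types} builder pattern used
+ * below, producing both {@code value} and {@code typed_value}:
+ *
+ * <pre>{@code
+ * // required group zip { optional binary value; optional int32 typed_value (INTEGER(8,true)); }
+ * Types.buildGroup(Type.Repetition.REQUIRED)
+ *     .optional(PrimitiveType.PrimitiveTypeName.BINARY)
+ *     .named("value")
+ *     .addField(
+ *         Types.optional(PrimitiveType.PrimitiveTypeName.INT32)
+ *             .as(LogicalTypeAnnotation.intType(8, true))
+ *             .named("typed_value"))
+ *     .named("zip");
+ * }</pre>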
+ */ +public class VariantShreddingAnalyzer { + private static final Logger LOG = LoggerFactory.getLogger(VariantShreddingAnalyzer.class); + + private final double minOccurrenceThreshold; + private final int maxFields; + + /** + * Creates a new analyzer with the specified configuration. + * + * @param minOccurrenceThreshold minimum occurrence threshold (e.g., 0.1 for 10%) + * @param maxFields maximum number of fields to shred + */ + public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { + this.minOccurrenceThreshold = minOccurrenceThreshold; + this.maxFields = maxFields; + } + + /** + * Analyzes buffered variant values to determine the optimal shredding schema. + * + * @param bufferedRows the buffered rows to analyze + * @param variantFieldIndex the index of the variant field in the rows + * @return the shredded schema type, or null if no shredding should be performed + */ + public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { + if (bufferedRows.isEmpty()) { + return null; + } + + List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); + if (variantValues.isEmpty()) { + return null; + } + + FieldStats stats = analyzeFields(variantValues); + return buildShreddedSchema(stats, variantValues.size()); + } + + private static List extractVariantValues( + List bufferedRows, int variantFieldIndex) { + List values = new java.util.ArrayList<>(); + + for (InternalRow row : bufferedRows) { + if (!row.isNullAt(variantFieldIndex)) { + VariantVal variantVal = row.getVariant(variantFieldIndex); + if (variantVal != null) { + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); + } + } + } + + return values; + } + + private static FieldStats analyzeFields(List variantValues) { + FieldStats stats = new FieldStats(); + + for (VariantValue value : variantValues) { + if (value.type() == PhysicalType.OBJECT) { + VariantObject obj = value.asObject(); + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + stats.recordField(fieldName, fieldValue); + } + } + } + } + + return stats; + } + + private Type buildShreddedSchema(FieldStats stats, int totalRows) { + int minOccurrences = (int) Math.ceil(totalRows * minOccurrenceThreshold); + + // Get fields that meet the occurrence threshold + Set candidateFields = Sets.newTreeSet(); + for (Map.Entry entry : stats.fieldInfoMap.entrySet()) { + String fieldName = entry.getKey(); + FieldInfo info = entry.getValue(); + + if (info.occurrenceCount >= minOccurrences) { + candidateFields.add(fieldName); + } else { + LOG.debug( + "Field '{}' appears only {} times out of {} (< {}%), dropping", + fieldName, + info.occurrenceCount, + totalRows, + (int) (minOccurrenceThreshold * 100)); + } + } + + if (candidateFields.isEmpty()) { + return null; + } + + // Build the typed_value struct with field count limit + Types.GroupBuilder objectBuilder = Types.buildGroup(Type.Repetition.OPTIONAL); + int fieldCount = 0; + + for (String fieldName : candidateFields) { + FieldInfo info = stats.fieldInfoMap.get(fieldName); + + if (info.hasConsistentType()) { + Type shreddedFieldType = createShreddedFieldType(fieldName, info); + if (shreddedFieldType != null) { + if (fieldCount + 2 > maxFields) { + LOG.debug( + "Reached maximum field limit ({}) while processing field '{}', 
stopping", + maxFields, + fieldName); + break; + } + objectBuilder.addField(shreddedFieldType); + fieldCount += 2; + } + } else { + Type valueOnlyField = createValueOnlyField(fieldName); + if (fieldCount + 1 > maxFields) { + LOG.debug( + "Reached maximum field limit ({}) while processing field '{}', stopping", + maxFields, + fieldName); + break; + } + objectBuilder.addField(valueOnlyField); + fieldCount += 1; + LOG.debug( + "Field '{}' has inconsistent types ({}), creating value-only field", + fieldName, + info.observedTypes); + } + } + + if (fieldCount == 0) { + return null; + } + + LOG.info("Created shredded schema with {} fields for {} candidate fields", fieldCount, candidateFields.size()); + return objectBuilder.named("typed_value"); + } + + private static Type createShreddedFieldType(String fieldName, FieldInfo info) { + PhysicalType physicalType = info.getConsistentType(); + if (physicalType == null) { + return null; + } + + // For array types, analyze the first value to determine element type + Type typedValue; + if (physicalType == PhysicalType.ARRAY) { + typedValue = createArrayTypedValue(info); + } else if (physicalType == PhysicalType.DECIMAL4 + || physicalType == PhysicalType.DECIMAL8 + || physicalType == PhysicalType.DECIMAL16) { + // For decimals, infer precision and scale from actual values + typedValue = createDecimalTypedValue(info, physicalType); + } else if (physicalType == PhysicalType.OBJECT) { + // For nested objects, attempt recursive shredding + typedValue = createNestedObjectTypedValue(info); + } else { + // Convert the physical type to a Parquet type for typed_value + typedValue = convertPhysicalTypeToParquet(physicalType); + } + + if (typedValue == null) { + // If we can't create a typed_value (e.g., inconsistent decimal scales), + // create a value-only field instead of skipping the field entirely + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(fieldName); + } + + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(typedValue) + .named(fieldName); + } + + private static Type createDecimalTypedValue(FieldInfo info, PhysicalType decimalType) { + // Analyze decimal values to determine precision and scale + // All values must have the same scale to be considered consistent + Integer consistentScale = null; + int maxPrecision = 0; + + for (VariantValue value : info.observedValues) { + if (value.type() == decimalType) { + try { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal) { + BigDecimal bd = (BigDecimal) decimalValue; + int precision = bd.precision(); + int scale = bd.scale(); + + // Check scale consistency + if (consistentScale == null) { + consistentScale = scale; + } else if (consistentScale != scale) { + // Different scales mean inconsistent types - no typed_value + LOG.debug( + "Decimal values have inconsistent scales ({} vs {}), skipping typed_value", + consistentScale, + scale); + return null; + } + + maxPrecision = Math.max(maxPrecision, precision); + } + } catch (Exception e) { + LOG.debug("Failed to analyze decimal value", e); + } + } + } + + if (maxPrecision == 0 || consistentScale == null) { + LOG.debug("Could not determine decimal precision/scale, skipping typed_value"); + return null; + } + + // Determine the appropriate Parquet type based on precision + PrimitiveType.PrimitiveTypeName primitiveType; + if 
(maxPrecision <= 9) { + primitiveType = PrimitiveType.PrimitiveTypeName.INT32; + } else if (maxPrecision <= 18) { + primitiveType = PrimitiveType.PrimitiveTypeName.INT64; + } else { + primitiveType = PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY; + } + + return Types.optional(primitiveType) + .as(LogicalTypeAnnotation.decimalType(consistentScale, maxPrecision)) + .named("typed_value"); + } + + private static Type createNestedObjectTypedValue(FieldInfo info) { + // For nested objects, we can recursively analyze their fields + // For now, we'll create a simpler representation + // A full implementation would recursively build the object structure + + // Get a sample object to analyze its fields + for (VariantValue value : info.observedValues) { + if (value.type() == PhysicalType.OBJECT) { + try { + VariantObject obj = value.asObject(); + int numFields = obj.numFields(); + + // Only shred simple nested objects (not too many fields) + if (numFields > 0 && numFields <= 20) { + // Analyze fields in the nested object + Map> nestedFieldTypes = Maps.newHashMap(); + + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + nestedFieldTypes + .computeIfAbsent(fieldName, k -> Sets.newHashSet()) + .add(fieldValue.type()); + } + } + + // Build nested struct with fields that have consistent types + Types.GroupBuilder nestedBuilder = + Types.buildGroup(Type.Repetition.OPTIONAL); + int fieldCount = 0; + + for (Map.Entry> entry : nestedFieldTypes.entrySet()) { + String fieldName = entry.getKey(); + Set types = entry.getValue(); + + // Only include fields with consistent types + if (types.size() == 1) { + PhysicalType fieldType = types.iterator().next(); + Type fieldParquetType = convertPhysicalTypeToParquet(fieldType); + if (fieldParquetType != null) { + GroupType nestedField = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(fieldParquetType) + .named(fieldName); + nestedBuilder.addField(nestedField); + fieldCount++; + } + } + } + + if (fieldCount > 0) { + return nestedBuilder.named("typed_value"); + } + } + } catch (Exception e) { + LOG.debug("Failed to analyze nested object", e); + } + break; + } + } + + LOG.debug("Skipping nested object - complex structure or analysis failed"); + return null; + } + + private static Type createArrayTypedValue(FieldInfo info) { + // Get a sample array value to analyze element types + for (VariantValue value : info.observedValues) { + if (value.type() == PhysicalType.ARRAY) { + try { + VariantArray array = value.asArray(); + int numElements = array.numElements(); + if (numElements > 0) { + // Analyze elements to determine if they have consistent type + Set elementTypes = Sets.newHashSet(); + for (int i = 0; i < numElements; i++) { + elementTypes.add(array.get(i).type()); + } + + // If all elements have consistent type, create typed array + if (elementTypes.size() == 1 + || (elementTypes.size() == 2 + && elementTypes.contains(PhysicalType.BOOLEAN_TRUE) + && elementTypes.contains(PhysicalType.BOOLEAN_FALSE))) { + PhysicalType elementType = elementTypes.iterator().next(); + if (elementType == PhysicalType.BOOLEAN_FALSE + || elementType == PhysicalType.BOOLEAN_TRUE) { + elementType = PhysicalType.BOOLEAN_TRUE; + } + Type elementParquetType = convertPhysicalTypeToParquet(elementType); + if (elementParquetType != null) { + // Create list with typed element + GroupType element = + Types.buildGroup(Type.Repetition.REQUIRED) + 
.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .addField(elementParquetType) + .named("element"); + return Types.optionalList().element(element).named("typed_value"); + } + } + } + } catch (Exception e) { + LOG.debug("Failed to analyze array elements", e); + } + break; + } + } + return null; + } + + private static Type createValueOnlyField(String fieldName) { + // Create a field with only the value field (no typed_value) + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(fieldName); + } + + private static Type convertPhysicalTypeToParquet(PhysicalType physicalType) { + switch (physicalType) { + case BOOLEAN_TRUE: + case BOOLEAN_FALSE: + return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named("typed_value"); + + case INT8: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named("typed_value"); + + case INT16: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named("typed_value"); + + case INT32: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named("typed_value"); + + case INT64: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named("typed_value"); + + case FLOAT: + return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named("typed_value"); + + case DOUBLE: + return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named("typed_value"); + + case STRING: + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("typed_value"); + + case BINARY: + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named("typed_value"); + + case DATE: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named("typed_value"); + + case TIMESTAMPTZ: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named("typed_value"); + + case TIMESTAMPNTZ: + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named("typed_value"); + + case DECIMAL4: + case DECIMAL8: + case DECIMAL16: + // Decimals are now handled in createDecimalTypedValue() + // This case should not be reached for consistent decimal types + LOG.debug("Decimal type {} should be handled by createDecimalTypedValue()", physicalType); + return null; + + case ARRAY: + // Arrays are now handled in createArrayTypedValue() + LOG.debug("Array type should be handled by createArrayTypedValue()"); + return null; + + case OBJECT: + // Nested objects are now handled in createNestedObjectTypedValue() + LOG.debug("Object type should be handled by createNestedObjectTypedValue()"); + return null; + + default: + LOG.debug("Unknown physical type: {}", physicalType); + return null; + } + } + + /** Tracks statistics about fields across multiple variant values. */ + private static class FieldStats { + private final Map fieldInfoMap = Maps.newHashMap(); + + void recordField(String fieldName, VariantValue value) { + FieldInfo info = fieldInfoMap.computeIfAbsent(fieldName, k -> new FieldInfo()); + info.observe(value); + } + } + + /** Tracks occurrence count and type consistency for a single field. 
*/ + private static class FieldInfo { + private int occurrenceCount = 0; + private final Set observedTypes = Sets.newHashSet(); + private final List observedValues = new java.util.ArrayList<>(); + + void observe(VariantValue value) { + occurrenceCount++; + observedTypes.add(value.type()); + observedValues.add(value); + } + + boolean hasConsistentType() { + // Handle boolean types specially - both TRUE and FALSE map to BOOLEAN + if (observedTypes.size() == 2 + && observedTypes.contains(PhysicalType.BOOLEAN_TRUE) + && observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { + return true; + } + return observedTypes.size() == 1; + } + + PhysicalType getConsistentType() { + if (!hasConsistentType()) { + return null; + } + + // Handle boolean types + if (observedTypes.contains(PhysicalType.BOOLEAN_TRUE) + || observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { + return PhysicalType.BOOLEAN_TRUE; // Use TRUE as canonical boolean type + } + + return observedTypes.iterator().next(); + } + } +} + diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index d82a241ba148..083242c6b743 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.net.InetAddress; +import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.FileScanTask; @@ -136,7 +137,7 @@ public void testVariantShreddingWrite() throws IOException { "zip", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); - GroupType address = variant("address", 2, objectFields(name, streets, zip)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(name, streets, zip)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); @@ -160,7 +161,7 @@ public void testVariantShreddingWithNullFirstRow() throws IOException { "state", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, objectFields(city, state)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, state)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); @@ -190,7 +191,7 @@ public void testVariantShreddingWithTwoVariantColumns() throws IOException { "zip", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); - GroupType address = variant("address", 2, objectFields(city, zip)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, zip)); GroupType type = field( @@ -199,7 +200,7 @@ public void testVariantShreddingWithTwoVariantColumns() throws IOException { PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType verified = field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType metadata = variant("metadata", 3, objectFields(type, verified)); + GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(type, verified)); MessageType expectedSchema = parquetSchema(address, metadata); @@ -227,14 +228,14 @@ public void 
testVariantShreddingWithTwoVariantColumnsOneNull() throws IOExceptio "street", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, objectFields(street)); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(street)); GroupType label = field( "label", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType metadata = variant("metadata", 3, objectFields(label)); + GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(label)); MessageType expectedSchema = parquetSchema(address, metadata); @@ -250,13 +251,385 @@ public void testVariantShreddingDisabled() throws IOException { String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; sql("INSERT INTO %s VALUES %s", tableName, values); - GroupType address = variant("address", 2); + GroupType address = variant("address", 2, Type.Repetition.OPTIONAL); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testConsistentTypeCreatesTypedValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field("age", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 2: Inconsistent Type → Value Only + * + *

When a field appears with different types across rows, only the "value" field should be + * created (no "typed_value"). + */ + @TestTemplate + public void testInconsistentTypeCreatesValueOnly() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "age" appears as both string and int - inconsistent type + String values = + "(1, parse_json('{\"age\": \"25\"}'))," + + " (2, parse_json('{\"age\": 30}'))," + + " (3, parse_json('{\"age\": \"35\"}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "age" should have only "value" field, no "typed_value" + GroupType age = valueOnlyField("age"); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 3: Rare Fields Are Dropped + * + *

Fields that appear in less than the configured threshold percentage of rows should be + * dropped from the shredded schema. + */ + @TestTemplate + public void testRareFieldIsDropped() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set threshold to 20% (0.2) + spark.conf().set(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, "0.2"); + + // "common" appears in all 10 rows (100%), "rare" appears in 1 row (10%) + String values = + "(1, parse_json('{\"common\": 1, \"rare\": 100}'))," + + " (2, parse_json('{\"common\": 2}'))," + + " (3, parse_json('{\"common\": 3}'))," + + " (4, parse_json('{\"common\": 4}'))," + + " (5, parse_json('{\"common\": 5}'))," + + " (6, parse_json('{\"common\": 6}'))," + + " (7, parse_json('{\"common\": 7}'))," + + " (8, parse_json('{\"common\": 8}'))," + + " (9, parse_json('{\"common\": 9}'))," + + " (10, parse_json('{\"common\": 10}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Only "common" should be present (appears in 100% of rows) + // "rare" should be dropped (appears in only 10% of rows, below 20% threshold) + GroupType common = + field("common", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(common)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset threshold to default to avoid interference with other tests + spark.conf().unset(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD); + } + + /** + * Test Heuristic 4: Boolean Type Handling + * + *

Both "true" and "false" values should be treated as the same consistent boolean type, and a + * typed_value field should be created. + */ + @TestTemplate + public void testBooleanTypeHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "active" field has both true and false values - should be treated as consistent boolean + String values = + "(1, parse_json('{\"active\": true}'))," + + " (2, parse_json('{\"active\": false}'))," + + " (3, parse_json('{\"active\": true}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "active" should have typed_value with boolean type + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset field limit to default to avoid interference from previous tests + spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); + } + + /** + * Test Heuristic 5: Mixed Fields (Consistent and Inconsistent) + * + *

Tests a realistic scenario with multiple fields where some have consistent types and others + * don't. + */ + @TestTemplate + public void testMixedFieldsConsistentAndInconsistent() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // "name": always string (consistent) + // "age": mixed int/string (inconsistent) + // "active": boolean (consistent) + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"active\": true}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": \"25\", \"active\": false}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": \"35\", \"active\": true}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "name" should have typed_value (consistent string) + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + + // "age" should NOT have typed_value (inconsistent types) + GroupType age = valueOnlyField("age"); + + // "active" should have typed_value (consistent boolean) + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active, age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 6: Field Limit Enforcement + * + *

Verify that the analyzer respects the maximum field limit and stops adding fields once the + * limit is reached. + */ + @TestTemplate + public void testMaxFieldLimitEnforcement() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set very low field limit + spark.conf().set(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, "4"); + + // Create rows with many fields (a, b, c, d, e, f) + String values = + "(1, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))," + + " (2, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // With limit 4: field "a" (2 fields: value + typed_value) + field "b" (2 fields) = 4 total + // Fields are added alphabetically, so only "a" and "b" should be present + GroupType a = + field("a", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType b = + field("b", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(a, b)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Reset field limit to default to avoid interference from previous tests + spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); + } + + /** + * Test Heuristic 7: Decimal Type Handling - Inconsistent Scales + * + *

Verify that decimal fields with different scales are treated as inconsistent types + * and only get a value field (no typed_value). + */ + @TestTemplate + public void testDecimalTypeHandlingInconsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Decimal values with different scales: scale 6, 2, 2 + // 123.456789 → precision 9, scale 6 + // 678.90 → precision 5, scale 2 + // 999.99 → precision 5, scale 2 + // These are treated as inconsistent types due to different scales + String values = + "(1, parse_json('{\"price\": 123.456789}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "price" has inconsistent scales, so only "value" field (no typed_value) + GroupType price = valueOnlyField("price"); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 7b: Decimal Type Handling - Consistent Scales + * + *

Verify that decimal fields with the same scale get proper typed_value with inferred + * precision/scale. + */ + @TestTemplate + public void testDecimalTypeHandlingConsistentScales() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Decimal values with consistent scale (all 2 decimal places) + String values = + "(1, parse_json('{\"price\": 123.45}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "price" should have typed_value with inferred DECIMAL(5,2) type + GroupType price = + field( + "price", + org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(2, 5)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 7c: Decimal Type Handling - Inconsistent After Buffering + * + *

Verify that when buffered rows have consistent decimal scales but subsequent unbuffered rows + * have inconsistent scales, the inconsistent values are written to the value field only. + * The schema is inferred from buffered rows and should include typed_value for the consistent type. + */ + @TestTemplate + public void testDecimalTypeHandlingInconsistentAfterBuffering() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + // Set a small buffer size to test the scenario + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + // First 3 rows (buffered): consistent scale (2 decimal places) + // 4th row onwards (unbuffered): different scale (6 decimal places) + // Schema should be inferred from buffered rows with DECIMAL(5,2) + // The unbuffered row with different scale should still write successfully to value field + String values = + "(1, parse_json('{\"price\": 123.45}'))," + + " (2, parse_json('{\"price\": 678.90}'))," + + " (3, parse_json('{\"price\": 999.99}'))," + + " (4, parse_json('{\"price\": 111.111111}'))"; // Different scale - should write to value only + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Schema should have typed_value with DECIMAL(5,2) based on buffered rows + GroupType price = + field( + "price", + org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(2, 5)) + .named("typed_value")); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + // Verify all rows were written successfully + List result = sql("SELECT id, address FROM %s ORDER BY id", tableName); + assertThat(result).hasSize(4); + + // Reset buffer size to default + spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); + } + + /** + * Test Heuristic 8: Array Type Handling + * + *

Verify that array fields with consistent element types get proper typed_value. + */ + @TestTemplate + public void testArrayTypeHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Arrays with consistent element types (all strings) + String values = + "(1, parse_json('{\"tags\": [\"java\", \"scala\", \"python\"]}'))," + + " (2, parse_json('{\"tags\": [\"rust\", \"go\"]}'))," + + " (3, parse_json('{\"tags\": [\"javascript\"]}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // "tags" should have typed_value with list of strings + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(tags)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** + * Test Heuristic 9: Nested Object Handling + * + *

Verify that simple nested objects are recursively shredded. + */ + @TestTemplate + public void testNestedObjectHandling() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Nested objects with consistent structure + String values = + "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}}'))," + + " (2, parse_json('{\"location\": {\"city\": \"Portland\", \"zip\": 97201}}'))," + + " (3, parse_json('{\"location\": {\"city\": \"NYC\", \"zip\": 10001}}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + // Nested "location" object should be shredded with its fields + GroupType city = + field( + "city", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType zip = + field("zip", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(zip, city)); + + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(location)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + /** Helper method to create a value-only field (no typed_value) for inconsistent types. */ + private static GroupType valueOnlyField(String name) { + return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); @@ -283,8 +656,8 @@ private static MessageType parquetSchema(Type... 
variantTypes) { .named("table"); } - private static GroupType variant(String name, int fieldId) { - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + private static GroupType variant(String name, int fieldId, Type.Repetition repetition) { + return org.apache.parquet.schema.Types.buildGroup(repetition) .id(fieldId) .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) .required(PrimitiveType.PrimitiveTypeName.BINARY) @@ -294,9 +667,10 @@ private static GroupType variant(String name, int fieldId) { .named(name); } - private static GroupType variant(String name, int fieldId, Type shreddedType) { + private static GroupType variant( + String name, int fieldId, Type.Repetition repetition, Type shreddedType) { checkShreddedType(shreddedType); - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.OPTIONAL) + return org.apache.parquet.schema.Types.buildGroup(repetition) .id(fieldId) .as(LogicalTypeAnnotation.variantType(Variant.VARIANT_SPEC_VERSION)) .required(PrimitiveType.PrimitiveTypeName.BINARY) From eeeb35c6dee4d948660d6f9c395562376ec84821 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Fri, 9 Jan 2026 12:20:07 -0800 Subject: [PATCH 3/5] Simplify heuristics to most common type --- .../iceberg/parquet/ParquetVariantUtil.java | 4 +- .../parquet/ParquetVariantWriters.java | 48 -- .../iceberg/parquet/VariantWriterBuilder.java | 16 +- .../iceberg/spark/SparkSQLProperties.java | 12 - .../apache/iceberg/spark/SparkWriteConf.java | 18 +- .../spark/source/SchemaInferenceVisitor.java | 33 +- .../source/VariantShreddingAnalyzer.java | 513 +++++------------- .../spark/variant/TestVariantShredding.java | 418 +++----------- 8 files changed, 238 insertions(+), 824 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java index d94760773e51..ac418a1127bd 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java @@ -57,7 +57,7 @@ import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; -public class ParquetVariantUtil { +class ParquetVariantUtil { private ParquetVariantUtil() {} /** @@ -212,7 +212,7 @@ static int scale(PrimitiveType primitive) { * @param value a variant value * @return a Parquet schema that can fully shred the value */ - public static Type toParquetSchema(VariantValue value) { + static Type toParquetSchema(VariantValue value) { return VariantVisitor.visit(value, new ParquetSchemaProducer()); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java index 42cdee7a1a5c..08016667bdab 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantWriters.java @@ -65,11 +65,6 @@ static ParquetValueWriter primitive( return new PrimitiveWriter<>(writer, Sets.immutableEnumSet(Arrays.asList(types))); } - static ParquetValueWriter decimal( - ParquetValueWriter writer, int expectedScale, PhysicalType... 
types) { - return new DecimalWriter(writer, expectedScale, Sets.immutableEnumSet(Arrays.asList(types))); - } - @SuppressWarnings("unchecked") static ParquetValueWriter shredded( int valueDefinitionLevel, @@ -258,49 +253,6 @@ public void setColumnStore(ColumnWriteStore columnStore) { } } - /** - * A TypedWriter for decimals that validates scale before writing. - * If the scale doesn't match, it returns false from canWrite() to trigger fallback to value field. - */ - private static class DecimalWriter implements TypedWriter { - private final Set types; - private final ParquetValueWriter writer; - private final int expectedScale; - - private DecimalWriter( - ParquetValueWriter writer, int expectedScale, Set types) { - this.types = types; - this.writer = (ParquetValueWriter) writer; - this.expectedScale = expectedScale; - } - - @Override - public Set types() { - return types; - } - - @Override - public void write(int repetitionLevel, VariantValue value) { - java.math.BigDecimal decimal = (java.math.BigDecimal) value.asPrimitive().get(); - // Validate scale matches before writing - if (decimal.scale() != expectedScale) { - throw new IllegalArgumentException( - "Cannot write decimal with scale " + decimal.scale() + " to schema expecting scale " + expectedScale); - } - writer.write(repetitionLevel, decimal); - } - - @Override - public List> columns() { - return writer.columns(); - } - - @Override - public void setColumnStore(ColumnWriteStore columnStore) { - writer.setColumnStore(columnStore); - } - } - private static class ShreddedVariantWriter implements ParquetValueWriter { private final int valueDefinitionLevel; private final ParquetValueWriter valueWriter; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java index 53cf5d9933d6..a447a102690a 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantWriterBuilder.java @@ -198,31 +198,27 @@ public Optional> visit(StringLogicalTypeAnnotation ignored @Override public Optional> visit(DecimalLogicalTypeAnnotation decimal) { ParquetValueWriter writer; - int scale = decimal.getScale(); switch (desc.getPrimitiveType().getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsFixed( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL16); return Optional.of(writer); case INT64: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsLong( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL8); return Optional.of(writer); case INT32: writer = - ParquetVariantWriters.decimal( + ParquetVariantWriters.primitive( ParquetValueWriters.decimalAsInteger( - desc, decimal.getPrecision(), scale), - scale, + desc, decimal.getPrecision(), decimal.getScale()), PhysicalType.DECIMAL4); return Optional.of(writer); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index e111becad89e..b12606d23948 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -119,16 +119,4 @@ private SparkSQLProperties() {} public static final String VARIANT_INFERENCE_BUFFER_SIZE = "spark.sql.iceberg.variant.inference.buffer-size"; public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; - - // Controls the minimum occurrence threshold for variant fields during shredding - // Fields that appear in fewer than this percentage of rows will be dropped - public static final String VARIANT_MIN_OCCURRENCE_THRESHOLD = - "spark.sql.iceberg.variant.min-occurrence-threshold"; - public static final double VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT = 0.1; // 10% - - // Controls the maximum number of fields to shred in a variant column - // This prevents creating overly wide Parquet schemas - public static final String VARIANT_MAX_SHREDDED_FIELDS = - "spark.sql.iceberg.variant.max-shredded-fields"; - public static final int VARIANT_MAX_SHREDDED_FIELDS_DEFAULT = 300; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 34fcd2f1e467..80d245712e6b 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -510,22 +510,14 @@ private Map dataWriteProperties() { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); - + // Add variant shredding configuration properties if (shredVariants()) { - String variantMaxFields = sessionConf.get(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, null); - if (variantMaxFields != null) { - writeProperties.put(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, variantMaxFields); - } - - String variantMinOccurrence = sessionConf.get(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, null); - if (variantMinOccurrence != null) { - writeProperties.put(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, variantMinOccurrence); - } - - String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + String variantBufferSize = + sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); if (variantBufferSize != null) { - writeProperties.put(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + writeProperties.put( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); } } break; diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index c03fc74f00e3..6903f1f03353 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -20,16 +20,10 @@ import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.List; import java.util.Map; -import org.apache.iceberg.parquet.ParquetVariantUtil; -import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.spark.data.ParquetWithSparkSchemaVisitor; import org.apache.iceberg.variants.Variant; -import org.apache.iceberg.variants.VariantMetadata; -import org.apache.iceberg.variants.VariantValue; import 
org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; @@ -43,13 +37,10 @@ import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.VariantType; -import org.apache.spark.unsafe.types.VariantVal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * A visitor that infers variant shredding schemas by analyzing buffered rows of data. - */ +/** A visitor that infers variant shredding schemas by analyzing buffered rows of data. */ public class SchemaInferenceVisitor extends ParquetWithSparkSchemaVisitor { private static final Logger LOG = LoggerFactory.getLogger(SchemaInferenceVisitor.class); @@ -61,20 +52,7 @@ public SchemaInferenceVisitor( List bufferedRows, StructType sparkSchema, Map properties) { this.bufferedRows = bufferedRows; this.sparkSchema = sparkSchema; - - double minOccurrenceThreshold = - Double.parseDouble( - properties.getOrDefault( - SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, - String.valueOf(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD_DEFAULT))); - - int maxFields = - Integer.parseInt( - properties.getOrDefault( - SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, - String.valueOf(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS_DEFAULT))); - - this.analyzer = new VariantShreddingAnalyzer(minOccurrenceThreshold, maxFields); + this.analyzer = new VariantShreddingAnalyzer(); } @Override @@ -151,10 +129,6 @@ public Type map(MapType sMap, GroupType map, Type key, Type value) { public Type variant(VariantType sVariant, GroupType variant) { int variantFieldIndex = getFieldIndex(currentPath()); - // Apply heuristics to determine the shredding schema: - // - Fields must appear in at least the configured percentage of rows - // - Type consistency determines if typed_value is created - // - Maximum field count to avoid overly wide schemas if (!bufferedRows.isEmpty() && variantFieldIndex >= 0) { Type shreddedType = analyzer.analyzeAndCreateSchema(bufferedRows, variantFieldIndex); if (shreddedType != null) { @@ -190,8 +164,7 @@ private int getFieldIndex(String[] path) { } else { // Nested field - navigate through struct hierarchy // For now, we only support direct struct nesting (not arrays/maps) - LOG.debug( - "Attempting to resolve nested variant field path: {}", String.join(".", path)); + LOG.debug("Attempting to resolve nested variant field path: {}", String.join(".", path)); // TODO: Implement full nested field resolution when needed // This would require tracking the current struct context during traversal // and maintaining a stack of field indices diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 581043cd802e..27b526134737 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import org.apache.iceberg.parquet.ParquetVariantUtil; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.variants.PhysicalType; @@ -40,36 +39,20 @@ import org.apache.parquet.schema.Types; import org.apache.spark.sql.catalyst.InternalRow; 
import org.apache.spark.unsafe.types.VariantVal; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Analyzes variant data across buffered rows to determine an optimal shredding schema. - ** + * *

    - *
  • If a field appears consistently with a consistent type → create both {@code value} and - * {@code typed_value} - *
  • If a field appears with inconsistent types → only create {@code value} - *
  • Drop fields that occur in less than the configured threshold of sampled rows - *
  • Cap the maximum fields to shred + *
  • shred to the most common type *
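+ *
+ * <p>Sketch of the intended behavior with hypothetical data (not taken from this patch): given
+ * sampled values {@code {"age": 30}}, {@code {"age": 31}}, and {@code {"age": "n/a"}}, the most
+ * common type of {@code age} is a small integer, so the inferred field is
+ *
+ * <pre>{@code
+ * required group age {
+ *   optional binary value;
+ *   optional int32 typed_value (INTEGER(8,true));
+ * }
+ * }</pre>
+ *
+ * <p>Rows whose {@code age} is a string are still written, but only to the untyped {@code value}
+ * field.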
*/ public class VariantShreddingAnalyzer { - private static final Logger LOG = LoggerFactory.getLogger(VariantShreddingAnalyzer.class); + private static final String TYPED_VALUE = "typed_value"; + private static final String VALUE = "value"; + private static final String ELEMENT = "element"; - private final double minOccurrenceThreshold; - private final int maxFields; - - /** - * Creates a new analyzer with the specified configuration. - * - * @param minOccurrenceThreshold minimum occurrence threshold (e.g., 0.1 for 10%) - * @param maxFields maximum number of fields to shred - */ - public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { - this.minOccurrenceThreshold = minOccurrenceThreshold; - this.maxFields = maxFields; - } + public VariantShreddingAnalyzer() {} /** * Analyzes buffered variant values to determine the optimal shredding schema. @@ -79,17 +62,13 @@ public VariantShreddingAnalyzer(double minOccurrenceThreshold, int maxFields) { * @return the shredded schema type, or null if no shredding should be performed */ public Type analyzeAndCreateSchema(List bufferedRows, int variantFieldIndex) { - if (bufferedRows.isEmpty()) { - return null; - } - List variantValues = extractVariantValues(bufferedRows, variantFieldIndex); if (variantValues.isEmpty()) { return null; } - FieldStats stats = analyzeFields(variantValues); - return buildShreddedSchema(stats, variantValues.size()); + PathNode root = buildPathTree(variantValues); + return buildTypedValue(root, root.info.getMostCommonType()); } private static List extractVariantValues( @@ -100,12 +79,12 @@ private static List extractVariantValues( if (!row.isNullAt(variantFieldIndex)) { VariantVal variantVal = row.getVariant(variantFieldIndex); if (variantVal != null) { - VariantValue variantValue = - VariantValue.from( - VariantMetadata.from( - ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), - ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); - values.add(variantValue); + VariantValue variantValue = + VariantValue.from( + VariantMetadata.from( + ByteBuffer.wrap(variantVal.getMetadata()).order(ByteOrder.LITTLE_ENDIAN)), + ByteBuffer.wrap(variantVal.getValue()).order(ByteOrder.LITTLE_ENDIAN)); + values.add(variantValue); } } } @@ -113,433 +92,231 @@ private static List extractVariantValues( return values; } - private static FieldStats analyzeFields(List variantValues) { - FieldStats stats = new FieldStats(); + private static PathNode buildPathTree(List variantValues) { + PathNode root = new PathNode(null); + root.info = new FieldInfo(); for (VariantValue value : variantValues) { - if (value.type() == PhysicalType.OBJECT) { - VariantObject obj = value.asObject(); - for (String fieldName : obj.fieldNames()) { - VariantValue fieldValue = obj.get(fieldName); - if (fieldValue != null) { - stats.recordField(fieldName, fieldValue); - } - } - } + traverse(root, value); } - return stats; + return root; } - private Type buildShreddedSchema(FieldStats stats, int totalRows) { - int minOccurrences = (int) Math.ceil(totalRows * minOccurrenceThreshold); - - // Get fields that meet the occurrence threshold - Set candidateFields = Sets.newTreeSet(); - for (Map.Entry entry : stats.fieldInfoMap.entrySet()) { - String fieldName = entry.getKey(); - FieldInfo info = entry.getValue(); - - if (info.occurrenceCount >= minOccurrences) { - candidateFields.add(fieldName); - } else { - LOG.debug( - "Field '{}' appears only {} times out of {} (< {}%), dropping", - fieldName, - info.occurrenceCount, - 
totalRows, - (int) (minOccurrenceThreshold * 100)); - } + private static void traverse(PathNode node, VariantValue value) { + if (value == null) { + return; } - if (candidateFields.isEmpty()) { - return null; - } + node.info.observe(value); - // Build the typed_value struct with field count limit - Types.GroupBuilder objectBuilder = Types.buildGroup(Type.Repetition.OPTIONAL); - int fieldCount = 0; - - for (String fieldName : candidateFields) { - FieldInfo info = stats.fieldInfoMap.get(fieldName); - - if (info.hasConsistentType()) { - Type shreddedFieldType = createShreddedFieldType(fieldName, info); - if (shreddedFieldType != null) { - if (fieldCount + 2 > maxFields) { - LOG.debug( - "Reached maximum field limit ({}) while processing field '{}', stopping", - maxFields, - fieldName); - break; + if (value.type() == PhysicalType.OBJECT) { + VariantObject obj = value.asObject(); + for (String fieldName : obj.fieldNames()) { + VariantValue fieldValue = obj.get(fieldName); + if (fieldValue != null) { + PathNode childNode = node.objectChildren.computeIfAbsent(fieldName, PathNode::new); + if (childNode.info == null) { + childNode.info = new FieldInfo(); } - objectBuilder.addField(shreddedFieldType); - fieldCount += 2; + traverse(childNode, fieldValue); } - } else { - Type valueOnlyField = createValueOnlyField(fieldName); - if (fieldCount + 1 > maxFields) { - LOG.debug( - "Reached maximum field limit ({}) while processing field '{}', stopping", - maxFields, - fieldName); - break; + } + } else if (value.type() == PhysicalType.ARRAY) { + VariantArray array = value.asArray(); + int numElements = array.numElements(); + if (node.arrayElement == null) { + node.arrayElement = new PathNode(null); + node.arrayElement.info = new FieldInfo(); + } + for (int i = 0; i < numElements; i++) { + VariantValue element = array.get(i); + if (element != null) { + traverse(node.arrayElement, element); } - objectBuilder.addField(valueOnlyField); - fieldCount += 1; - LOG.debug( - "Field '{}' has inconsistent types ({}), creating value-only field", - fieldName, - info.observedTypes); } } - - if (fieldCount == 0) { - return null; - } - - LOG.info("Created shredded schema with {} fields for {} candidate fields", fieldCount, candidateFields.size()); - return objectBuilder.named("typed_value"); } - private static Type createShreddedFieldType(String fieldName, FieldInfo info) { - PhysicalType physicalType = info.getConsistentType(); - if (physicalType == null) { - return null; - } + private static Type buildFieldGroup(PathNode node) { + Type typedValue = buildTypedValue(node, node.info.getMostCommonType()); + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(typedValue) + .named(node.fieldName); + } - // For array types, analyze the first value to determine element type + private static Type buildTypedValue(PathNode node, PhysicalType physicalType) { Type typedValue; if (physicalType == PhysicalType.ARRAY) { - typedValue = createArrayTypedValue(info); - } else if (physicalType == PhysicalType.DECIMAL4 - || physicalType == PhysicalType.DECIMAL8 - || physicalType == PhysicalType.DECIMAL16) { - // For decimals, infer precision and scale from actual values - typedValue = createDecimalTypedValue(info, physicalType); + typedValue = createArrayTypedValue(node); } else if (physicalType == PhysicalType.OBJECT) { - // For nested objects, attempt recursive shredding - typedValue = createNestedObjectTypedValue(info); + typedValue = createObjectTypedValue(node); 
} else { - // Convert the physical type to a Parquet type for typed_value - typedValue = convertPhysicalTypeToParquet(physicalType); + typedValue = createPrimitiveTypedValue(node.info, physicalType); } - if (typedValue == null) { - // If we can't create a typed_value (e.g., inconsistent decimal scales), - // create a value-only field instead of skipping the field entirely - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(fieldName); - } - - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(typedValue) - .named(fieldName); + return typedValue; } - private static Type createDecimalTypedValue(FieldInfo info, PhysicalType decimalType) { - // Analyze decimal values to determine precision and scale - // All values must have the same scale to be considered consistent - Integer consistentScale = null; - int maxPrecision = 0; - - for (VariantValue value : info.observedValues) { - if (value.type() == decimalType) { - try { - VariantPrimitive primitive = value.asPrimitive(); - Object decimalValue = primitive.get(); - if (decimalValue instanceof BigDecimal) { - BigDecimal bd = (BigDecimal) decimalValue; - int precision = bd.precision(); - int scale = bd.scale(); - - // Check scale consistency - if (consistentScale == null) { - consistentScale = scale; - } else if (consistentScale != scale) { - // Different scales mean inconsistent types - no typed_value - LOG.debug( - "Decimal values have inconsistent scales ({} vs {}), skipping typed_value", - consistentScale, - scale); - return null; - } - - maxPrecision = Math.max(maxPrecision, precision); - } - } catch (Exception e) { - LOG.debug("Failed to analyze decimal value", e); - } - } - } - - if (maxPrecision == 0 || consistentScale == null) { - LOG.debug("Could not determine decimal precision/scale, skipping typed_value"); + private static Type createObjectTypedValue(PathNode node) { + if (node.objectChildren.isEmpty()) { return null; } - // Determine the appropriate Parquet type based on precision - PrimitiveType.PrimitiveTypeName primitiveType; - if (maxPrecision <= 9) { - primitiveType = PrimitiveType.PrimitiveTypeName.INT32; - } else if (maxPrecision <= 18) { - primitiveType = PrimitiveType.PrimitiveTypeName.INT64; - } else { - primitiveType = PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY; + Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); + for (PathNode child : node.objectChildren.values()) { + builder.addField(buildFieldGroup(child)); } - return Types.optional(primitiveType) - .as(LogicalTypeAnnotation.decimalType(consistentScale, maxPrecision)) - .named("typed_value"); + return builder.named(TYPED_VALUE); } - private static Type createNestedObjectTypedValue(FieldInfo info) { - // For nested objects, we can recursively analyze their fields - // For now, we'll create a simpler representation - // A full implementation would recursively build the object structure - - // Get a sample object to analyze its fields - for (VariantValue value : info.observedValues) { - if (value.type() == PhysicalType.OBJECT) { - try { - VariantObject obj = value.asObject(); - int numFields = obj.numFields(); - - // Only shred simple nested objects (not too many fields) - if (numFields > 0 && numFields <= 20) { - // Analyze fields in the nested object - Map> nestedFieldTypes = Maps.newHashMap(); - - for (String fieldName : obj.fieldNames()) { - VariantValue fieldValue = 
obj.get(fieldName); - if (fieldValue != null) { - nestedFieldTypes - .computeIfAbsent(fieldName, k -> Sets.newHashSet()) - .add(fieldValue.type()); - } - } - - // Build nested struct with fields that have consistent types - Types.GroupBuilder nestedBuilder = - Types.buildGroup(Type.Repetition.OPTIONAL); - int fieldCount = 0; - - for (Map.Entry> entry : nestedFieldTypes.entrySet()) { - String fieldName = entry.getKey(); - Set types = entry.getValue(); - - // Only include fields with consistent types - if (types.size() == 1) { - PhysicalType fieldType = types.iterator().next(); - Type fieldParquetType = convertPhysicalTypeToParquet(fieldType); - if (fieldParquetType != null) { - GroupType nestedField = - Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(fieldParquetType) - .named(fieldName); - nestedBuilder.addField(nestedField); - fieldCount++; - } - } - } - - if (fieldCount > 0) { - return nestedBuilder.named("typed_value"); - } - } - } catch (Exception e) { - LOG.debug("Failed to analyze nested object", e); - } - break; - } - } + private static Type createArrayTypedValue(PathNode node) { + PathNode elementNode = node.arrayElement; + PhysicalType elementType = elementNode.info.getMostCommonType(); + Type elementTypedValue = buildTypedValue(elementNode, elementType); - LOG.debug("Skipping nested object - complex structure or analysis failed"); - return null; + GroupType elementGroup = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveType.PrimitiveTypeName.BINARY) + .named(VALUE) + .addField(elementTypedValue) + .named(ELEMENT); + + return Types.optionalList().element(elementGroup).named(TYPED_VALUE); } - private static Type createArrayTypedValue(FieldInfo info) { - // Get a sample array value to analyze element types - for (VariantValue value : info.observedValues) { - if (value.type() == PhysicalType.ARRAY) { - try { - VariantArray array = value.asArray(); - int numElements = array.numElements(); - if (numElements > 0) { - // Analyze elements to determine if they have consistent type - Set elementTypes = Sets.newHashSet(); - for (int i = 0; i < numElements; i++) { - elementTypes.add(array.get(i).type()); - } - - // If all elements have consistent type, create typed array - if (elementTypes.size() == 1 - || (elementTypes.size() == 2 - && elementTypes.contains(PhysicalType.BOOLEAN_TRUE) - && elementTypes.contains(PhysicalType.BOOLEAN_FALSE))) { - PhysicalType elementType = elementTypes.iterator().next(); - if (elementType == PhysicalType.BOOLEAN_FALSE - || elementType == PhysicalType.BOOLEAN_TRUE) { - elementType = PhysicalType.BOOLEAN_TRUE; - } - Type elementParquetType = convertPhysicalTypeToParquet(elementType); - if (elementParquetType != null) { - // Create list with typed element - GroupType element = - Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .addField(elementParquetType) - .named("element"); - return Types.optionalList().element(element).named("typed_value"); - } - } - } - } catch (Exception e) { - LOG.debug("Failed to analyze array elements", e); - } - break; - } + private static class PathNode { + private final String fieldName; + private final Map objectChildren = Maps.newTreeMap(); + private PathNode arrayElement = null; + private FieldInfo info = null; + + private PathNode(String fieldName) { + this.fieldName = fieldName; } - return null; } - private static Type createValueOnlyField(String fieldName) { - // Create a 
field with only the value field (no typed_value) - return Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(fieldName); + /** Use DECIMAL with maximum precision and scale as the shredding type */ + private static Type createDecimalTypedValue(FieldInfo info) { + int maxPrecision = info.maxDecimalPrecision; + int maxScale = info.maxDecimalScale; + + if (maxPrecision <= 9) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else if (maxPrecision <= 18) { + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } else { + return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(maxScale, maxPrecision)) + .named(TYPED_VALUE); + } } - private static Type convertPhysicalTypeToParquet(PhysicalType physicalType) { - switch (physicalType) { + private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { + switch (primitiveType) { case BOOLEAN_TRUE: case BOOLEAN_FALSE: - return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); case INT8: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(8, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT16: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(16, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT32: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(32, true)) - .named("typed_value"); + .named(TYPED_VALUE); case INT64: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); case FLOAT: - return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); case DOUBLE: - return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); case STRING: return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) .as(LogicalTypeAnnotation.stringType()) - .named("typed_value"); + .named(TYPED_VALUE); case BINARY: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named("typed_value"); + return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); case DATE: return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.dateType()) - .named("typed_value"); + .named(TYPED_VALUE); case TIMESTAMPTZ: return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named("typed_value"); + .named(TYPED_VALUE); case TIMESTAMPNTZ: return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named("typed_value"); + .named(TYPED_VALUE); case DECIMAL4: case DECIMAL8: case DECIMAL16: - // Decimals are now handled in createDecimalTypedValue() - // This case should not be reached for consistent decimal 
types - LOG.debug("Decimal type {} should be handled by createDecimalTypedValue()", physicalType); - return null; - - case ARRAY: - // Arrays are now handled in createArrayTypedValue() - LOG.debug("Array type should be handled by createArrayTypedValue()"); - return null; - - case OBJECT: - // Nested objects are now handled in createNestedObjectTypedValue() - LOG.debug("Object type should be handled by createNestedObjectTypedValue()"); - return null; + return createDecimalTypedValue(info); default: - LOG.debug("Unknown physical type: {}", physicalType); - return null; - } - } - - /** Tracks statistics about fields across multiple variant values. */ - private static class FieldStats { - private final Map fieldInfoMap = Maps.newHashMap(); - - void recordField(String fieldName, VariantValue value) { - FieldInfo info = fieldInfoMap.computeIfAbsent(fieldName, k -> new FieldInfo()); - info.observe(value); + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); } } - /** Tracks occurrence count and type consistency for a single field. */ + /** Tracks occurrence count and types for a single field. */ private static class FieldInfo { - private int occurrenceCount = 0; private final Set observedTypes = Sets.newHashSet(); - private final List observedValues = new java.util.ArrayList<>(); + private final Map typeCounts = Maps.newHashMap(); + private int maxDecimalPrecision = 0; + private int maxDecimalScale = 0; void observe(VariantValue value) { - occurrenceCount++; - observedTypes.add(value.type()); - observedValues.add(value); - } - - boolean hasConsistentType() { - // Handle boolean types specially - both TRUE and FALSE map to BOOLEAN - if (observedTypes.size() == 2 - && observedTypes.contains(PhysicalType.BOOLEAN_TRUE) - && observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { - return true; + // Use BOOLEAN_TRUE for both TRUE/FALSE values + PhysicalType type = + value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + observedTypes.add(type); + typeCounts.compute(type, (k, v) -> (v == null) ? 
1 : v + 1); + + // Track max precision and scale for decimal types + if (type == PhysicalType.DECIMAL4 + || type == PhysicalType.DECIMAL8 + || type == PhysicalType.DECIMAL16) { + VariantPrimitive primitive = value.asPrimitive(); + Object decimalValue = primitive.get(); + if (decimalValue instanceof BigDecimal) { + BigDecimal bd = (BigDecimal) decimalValue; + maxDecimalPrecision = Math.max(maxDecimalPrecision, bd.precision()); + maxDecimalScale = Math.max(maxDecimalScale, bd.scale()); + } } - return observedTypes.size() == 1; } - PhysicalType getConsistentType() { - if (!hasConsistentType()) { - return null; - } - - // Handle boolean types - if (observedTypes.contains(PhysicalType.BOOLEAN_TRUE) - || observedTypes.contains(PhysicalType.BOOLEAN_FALSE)) { - return PhysicalType.BOOLEAN_TRUE; // Use TRUE as canonical boolean type - } - - return observedTypes.iterator().next(); + PhysicalType getMostCommonType() { + return typeCounts.entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .orElse(null); } } } - diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index 083242c6b743..5f4eb2a2732f 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -19,11 +19,11 @@ package org.apache.iceberg.spark.variant; import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS; +import static org.apache.parquet.schema.Types.optional; import static org.assertj.core.api.Assertions.assertThat; import java.io.IOException; import java.net.InetAddress; -import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.FileScanTask; @@ -112,140 +112,8 @@ public void after() { validationCatalog.dropTable(tableIdent, true); } - @TestTemplate - public void testVariantShreddingWrite() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - String values = - "(1, parse_json('{\"name\": \"Joe\", \"streets\": [\"Apt #3\", \"1234 Ave\"], \"zip\": 10001}')), (2, null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType name = - field( - "name", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType streets = - field( - "streets", - list( - element( - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, - LogicalTypeAnnotation.stringType())))); - GroupType zip = - field( - "zip", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16))); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(name, streets, zip)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithNullFirstRow() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = "(1, null), (2, parse_json('{\"city\": \"Seattle\", \"state\": \"WA\"}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType city = - field( - "city", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType state = - field( - "state", - shreddedPrimitive( 
- PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, state)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithTwoVariantColumns() throws IOException { - validationCatalog.dropTable(tableIdent, true); - validationCatalog.createTable( - tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); - - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - String values = - "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}'), parse_json('{\"type\": \"home\", \"verified\": true}')), " - + "(2, null, null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType city = - field( - "city", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType zip = - field( - "zip", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true))); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(city, zip)); - - GroupType type = - field( - "type", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType verified = - field("verified", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(type, verified)); - - MessageType expectedSchema = parquetSchema(address, metadata); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - - @TestTemplate - public void testVariantShreddingWithTwoVariantColumnsOneNull() throws IOException { - validationCatalog.dropTable(tableIdent, true); - validationCatalog.createTable( - tableIdent, SCHEMA2, null, Map.of(TableProperties.FORMAT_VERSION, "3")); - - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - - // First row: address is null, metadata has value - // Second row: address has value, metadata is null - String values = - "(1, null, parse_json('{\"label\": \"primary\"}'))," - + " (2, parse_json('{\"street\": \"Main St\"}'), null)"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - GroupType street = - field( - "street", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType address = variant("address", 2, Type.Repetition.OPTIONAL, objectFields(street)); - - GroupType label = - field( - "label", - shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - GroupType metadata = variant("metadata", 3, Type.Repetition.OPTIONAL, objectFields(label)); - - MessageType expectedSchema = parquetSchema(address, metadata); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - } - @TestTemplate public void testVariantShreddingDisabled() throws IOException { - // Test with shredding explicitly disabled spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "false"); String values = "(1, parse_json('{\"city\": \"NYC\", \"zip\": 10001}')), (2, null)"; @@ -259,7 +127,7 @@ public void testVariantShreddingDisabled() throws IOException { } @TestTemplate - public void testConsistentTypeCreatesTypedValue() throws IOException { + public void testConsistentType() throws IOException { 
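// Sketch of how the shredded type is chosen for the tests below (an illustration of
// FieldInfo.getMostCommonType(), not additional test code): each observed value bumps a
// per-type counter and the most common physical type wins. With "age" values of "25",
// 30 and "35":
//
//   Map<PhysicalType, Integer> typeCounts =
//       Map.of(PhysicalType.STRING, 2, PhysicalType.INT8, 1);
//   PhysicalType winner =
//       typeCounts.entrySet().stream()
//           .max(Map.Entry.comparingByValue())
//           .map(Map.Entry::getKey)
//           .orElse(null); // STRING, so "age" is shredded as a string typed_value
//
// The row holding the integer 30 is then written to the binary "value" column.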
spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); String values = @@ -274,7 +142,10 @@ public void testConsistentTypeCreatesTypedValue() throws IOException { shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType age = - field("age", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); MessageType expectedSchema = parquetSchema(address); @@ -282,25 +153,21 @@ public void testConsistentTypeCreatesTypedValue() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 2: Inconsistent Type → Value Only - * - *
<p>
When a field appears with different types across rows, only the "value" field should be - * created (no "typed_value"). - */ @TestTemplate - public void testInconsistentTypeCreatesValueOnly() throws IOException { + public void testInconsistentType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "age" appears as both string and int - inconsistent type String values = "(1, parse_json('{\"age\": \"25\"}'))," + " (2, parse_json('{\"age\": 30}'))," + " (3, parse_json('{\"age\": \"35\"}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "age" should have only "value" field, no "typed_value" - GroupType age = valueOnlyField("age"); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age)); MessageType expectedSchema = parquetSchema(address); @@ -308,172 +175,80 @@ public void testInconsistentTypeCreatesValueOnly() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 3: Rare Fields Are Dropped - * - *
<p>
Fields that appear in less than the configured threshold percentage of rows should be - * dropped from the shredded schema. - */ - @TestTemplate - public void testRareFieldIsDropped() throws IOException { - spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set threshold to 20% (0.2) - spark.conf().set(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD, "0.2"); - - // "common" appears in all 10 rows (100%), "rare" appears in 1 row (10%) - String values = - "(1, parse_json('{\"common\": 1, \"rare\": 100}'))," - + " (2, parse_json('{\"common\": 2}'))," - + " (3, parse_json('{\"common\": 3}'))," - + " (4, parse_json('{\"common\": 4}'))," - + " (5, parse_json('{\"common\": 5}'))," - + " (6, parse_json('{\"common\": 6}'))," - + " (7, parse_json('{\"common\": 7}'))," - + " (8, parse_json('{\"common\": 8}'))," - + " (9, parse_json('{\"common\": 9}'))," - + " (10, parse_json('{\"common\": 10}'))"; - sql("INSERT INTO %s VALUES %s", tableName, values); - - // Only "common" should be present (appears in 100% of rows) - // "rare" should be dropped (appears in only 10% of rows, below 20% threshold) - GroupType common = - field("common", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(common)); - MessageType expectedSchema = parquetSchema(address); - - Table table = validationCatalog.loadTable(tableIdent); - verifyParquetSchema(table, expectedSchema); - - // Reset threshold to default to avoid interference with other tests - spark.conf().unset(SparkSQLProperties.VARIANT_MIN_OCCURRENCE_THRESHOLD); - } - - /** - * Test Heuristic 4: Boolean Type Handling - * - *
<p>
Both "true" and "false" values should be treated as the same consistent boolean type, and a - * typed_value field should be created. - */ @TestTemplate - public void testBooleanTypeHandling() throws IOException { + public void testPrimitiveType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "active" field has both true and false values - should be treated as consistent boolean - String values = - "(1, parse_json('{\"active\": true}'))," - + " (2, parse_json('{\"active\": false}'))," - + " (3, parse_json('{\"active\": true}'))"; + String values = "(1, parse_json('123')), (2, parse_json('\"abc\"')), (3, parse_json('12'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "active" should have typed_value with boolean type - GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Reset field limit to default to avoid interference from previous tests - spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); } - /** - * Test Heuristic 5: Mixed Fields (Consistent and Inconsistent) - * - *
<p>
Tests a realistic scenario with multiple fields where some have consistent types and others - * don't. - */ @TestTemplate - public void testMixedFieldsConsistentAndInconsistent() throws IOException { + public void testPrimitiveDecimalType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // "name": always string (consistent) - // "age": mixed int/string (inconsistent) - // "active": boolean (consistent) String values = - "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"active\": true}'))," - + " (2, parse_json('{\"name\": \"Bob\", \"age\": \"25\", \"active\": false}'))," - + " (3, parse_json('{\"name\": \"Charlie\", \"age\": \"35\", \"active\": true}'))"; + "(1, parse_json('123.56')), (2, parse_json('\"abc\"')), (3, parse_json('12.56'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "name" should have typed_value (consistent string) - GroupType name = - field( - "name", + GroupType address = + variant( + "address", + 2, + Type.Repetition.REQUIRED, shreddedPrimitive( - PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); - - // "age" should NOT have typed_value (inconsistent types) - GroupType age = valueOnlyField("age"); - - // "active" should have typed_value (consistent boolean) - GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); - - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active, age, name)); + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 6: Field Limit Enforcement - * - *
<p>
Verify that the analyzer respects the maximum field limit and stops adding fields once the - * limit is reached. - */ @TestTemplate - public void testMaxFieldLimitEnforcement() throws IOException { + public void testBooleanType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set very low field limit - spark.conf().set(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS, "4"); - // Create rows with many fields (a, b, c, d, e, f) String values = - "(1, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))," - + " (2, parse_json('{\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4, \"e\": 5, \"f\": 6}'))"; + "(1, parse_json('{\"active\": true}'))," + + " (2, parse_json('{\"active\": false}'))," + + " (3, parse_json('{\"active\": true}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // With limit 4: field "a" (2 fields: value + typed_value) + field "b" (2 fields) = 4 total - // Fields are added alphabetically, so only "a" and "b" should be present - GroupType a = - field("a", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - GroupType b = - field("b", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); - - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(a, b)); + GroupType active = field("active", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.BOOLEAN)); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(active)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Reset field limit to default to avoid interference from previous tests - spark.conf().unset(SparkSQLProperties.VARIANT_MAX_SHREDDED_FIELDS); } - /** - * Test Heuristic 7: Decimal Type Handling - Inconsistent Scales - * - *
<p>
Verify that decimal fields with different scales are treated as inconsistent types - * and only get a value field (no typed_value). - */ @TestTemplate - public void testDecimalTypeHandlingInconsistentScales() throws IOException { + public void testDecimalTypeWithInconsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Decimal values with different scales: scale 6, 2, 2 - // 123.456789 → precision 9, scale 6 - // 678.90 → precision 5, scale 2 - // 999.99 → precision 5, scale 2 - // These are treated as inconsistent types due to different scales String values = "(1, parse_json('{\"price\": 123.456789}'))," + " (2, parse_json('{\"price\": 678.90}'))," + " (3, parse_json('{\"price\": 999.99}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "price" has inconsistent scales, so only "value" field (no typed_value) - GroupType price = valueOnlyField("price"); + GroupType price = + field( + "price", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(6, 9))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); MessageType expectedSchema = parquetSchema(address); @@ -481,30 +256,21 @@ public void testDecimalTypeHandlingInconsistentScales() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 7b: Decimal Type Handling - Consistent Scales - * - *
<p>
Verify that decimal fields with the same scale get proper typed_value with inferred - * precision/scale. - */ @TestTemplate - public void testDecimalTypeHandlingConsistentScales() throws IOException { + public void testDecimalTypeWithConsistentScales() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Decimal values with consistent scale (all 2 decimal places) String values = "(1, parse_json('{\"price\": 123.45}'))," + " (2, parse_json('{\"price\": 678.90}'))," + " (3, parse_json('{\"price\": 999.99}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "price" should have typed_value with inferred DECIMAL(5,2) type GroupType price = field( "price", - org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.decimalType(2, 5)) - .named("typed_value")); + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(2, 5))); GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); MessageType expectedSchema = parquetSchema(address); @@ -512,68 +278,38 @@ public void testDecimalTypeHandlingConsistentScales() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 7c: Decimal Type Handling - Inconsistent After Buffering - * - *
<p>
Verify that when buffered rows have consistent decimal scales but subsequent unbuffered rows - * have inconsistent scales, the inconsistent values are written to the value field only. - * The schema is inferred from buffered rows and should include typed_value for the consistent type. - */ @TestTemplate - public void testDecimalTypeHandlingInconsistentAfterBuffering() throws IOException { + public void testArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Set a small buffer size to test the scenario - spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); - // First 3 rows (buffered): consistent scale (2 decimal places) - // 4th row onwards (unbuffered): different scale (6 decimal places) - // Schema should be inferred from buffered rows with DECIMAL(5,2) - // The unbuffered row with different scale should still write successfully to value field String values = - "(1, parse_json('{\"price\": 123.45}'))," - + " (2, parse_json('{\"price\": 678.90}'))," - + " (3, parse_json('{\"price\": 999.99}'))," - + " (4, parse_json('{\"price\": 111.111111}'))"; // Different scale - should write to value only + "(1, parse_json('[\"java\", \"scala\", \"python\"]'))," + + " (2, parse_json('[\"rust\", \"go\"]'))," + + " (3, parse_json('[\"javascript\"]'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // Schema should have typed_value with DECIMAL(5,2) based on buffered rows - GroupType price = - field( - "price", - org.apache.parquet.schema.Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.decimalType(2, 5)) - .named("typed_value")); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(price)); + GroupType arr = + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, arr); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); - - // Verify all rows were written successfully - List result = sql("SELECT id, address FROM %s ORDER BY id", tableName); - assertThat(result).hasSize(4); - - // Reset buffer size to default - spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE); } - /** - * Test Heuristic 8: Array Type Handling - * - *
<p>
Verify that array fields with consistent element types get proper typed_value. - */ @TestTemplate - public void testArrayTypeHandling() throws IOException { + public void testNestedArrayType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Arrays with consistent element types (all strings) String values = "(1, parse_json('{\"tags\": [\"java\", \"scala\", \"python\"]}'))," + " (2, parse_json('{\"tags\": [\"rust\", \"go\"]}'))," + " (3, parse_json('{\"tags\": [\"javascript\"]}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // "tags" should have typed_value with list of strings GroupType tags = field( "tags", @@ -589,47 +325,44 @@ public void testArrayTypeHandling() throws IOException { verifyParquetSchema(table, expectedSchema); } - /** - * Test Heuristic 9: Nested Object Handling - * - *
<p>
Verify that simple nested objects are recursively shredded. - */ @TestTemplate - public void testNestedObjectHandling() throws IOException { + public void testNestedObjectType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); - // Nested objects with consistent structure String values = - "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}}'))," + "(1, parse_json('{\"location\": {\"city\": \"Seattle\", \"zip\": 98101}, \"tags\": [\"java\", \"scala\", \"python\"]}'))," + " (2, parse_json('{\"location\": {\"city\": \"Portland\", \"zip\": 97201}}'))," + " (3, parse_json('{\"location\": {\"city\": \"NYC\", \"zip\": 10001}}'))"; sql("INSERT INTO %s VALUES %s", tableName, values); - // Nested "location" object should be shredded with its fields GroupType city = field( "city", shreddedPrimitive( PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); GroupType zip = - field("zip", shreddedPrimitive(PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); - GroupType location = field("location", objectFields(zip, city)); + field( + "zip", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, true))); + GroupType location = field("location", objectFields(city, zip)); + GroupType tags = + field( + "tags", + list( + element( + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, + LogicalTypeAnnotation.stringType())))); - GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(location)); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(location, tags)); MessageType expectedSchema = parquetSchema(address); Table table = validationCatalog.loadTable(tableIdent); verifyParquetSchema(table, expectedSchema); } - /** Helper method to create a value-only field (no typed_value) for inconsistent types. */ - private static GroupType valueOnlyField(String name) { - return org.apache.parquet.schema.Types.buildGroup(Type.Repetition.REQUIRED) - .optional(PrimitiveType.PrimitiveTypeName.BINARY) - .named("value") - .named(name); - } - private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); @@ -644,6 +377,9 @@ private void verifyParquetSchema(Table table, MessageType expectedSchema) throws MessageType actualSchema = reader.getFileMetaData().getSchema(); assertThat(actualSchema).isEqualTo(expectedSchema); } + + // Print the result + spark.read().format("iceberg").load(tableName).orderBy("id").show(false); } } @@ -682,12 +418,12 @@ private static GroupType variant( } private static Type shreddedPrimitive(PrimitiveType.PrimitiveTypeName primitive) { - return org.apache.parquet.schema.Types.optional(primitive).named("typed_value"); + return optional(primitive).named("typed_value"); } private static Type shreddedPrimitive( PrimitiveType.PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { - return org.apache.parquet.schema.Types.optional(primitive).as(annotation).named("typed_value"); + return optional(primitive).as(annotation).named("typed_value"); } private static GroupType objectFields(GroupType... 
fields) { From 5c0533e90cc7663b5468b73382671fa15f359a93 Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Thu, 15 Jan 2026 11:23:33 -0800 Subject: [PATCH 4/5] Add to 4.1 --- .../apache/iceberg/parquet/ParquetWriter.java | 10 +- .../iceberg/spark/SparkSQLProperties.java | 10 -- .../apache/iceberg/spark/SparkWriteConf.java | 20 --- .../iceberg/spark/SparkWriteOptions.java | 3 - .../spark/source/SparkFileWriterFactory.java | 12 +- .../iceberg/spark/TestSparkWriteConf.java | 7 - .../iceberg/spark/SparkSQLProperties.java | 10 ++ .../apache/iceberg/spark/SparkWriteConf.java | 20 +++ .../iceberg/spark/SparkWriteOptions.java | 3 + .../spark/source/SchemaInferenceVisitor.java | 12 +- .../spark/source/SparkFileWriterFactory.java | 12 +- ...parkParquetWriterWithVariantShredding.java | 0 .../source/VariantShreddingAnalyzer.java | 129 +++++++++--------- .../iceberg/spark/TestSparkWriteConf.java | 7 + .../spark/variant/TestVariantShredding.java | 62 +++++++++ 15 files changed, 187 insertions(+), 130 deletions(-) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java (89%) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java (100%) rename spark/{v4.0 => v4.1}/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java (74%) rename spark/{v4.0 => v4.1}/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java (87%) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index f359a99d72db..7144b474a0dc 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -134,8 +134,7 @@ private void ensureWriterInitialized() { @Override public void add(T value) { - if (model instanceof WriterLazyInitializable) { - WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (model instanceof WriterLazyInitializable lazy) { if (lazy.needsInitialization()) { model.write(0, value); recordCount += 1; @@ -144,7 +143,9 @@ public void add(T value) { WriterLazyInitializable.InitializationResult result = lazy.initialize(props, compressor, rowGroupOrdinal); this.parquetSchema = result.getSchema(); + this.pageStore.close(); this.pageStore = result.getPageStore(); + this.writeStore.close(); this.writeStore = result.getWriteStore(); // Re-initialize the file writer with the new schema @@ -281,13 +282,14 @@ public void close() throws IOException { this.closed = true; // Force initialization if lazy writer still has buffered data - if (model instanceof WriterLazyInitializable) { - WriterLazyInitializable lazy = (WriterLazyInitializable) model; + if (model instanceof WriterLazyInitializable lazy) { if (lazy.needsInitialization()) { WriterLazyInitializable.InitializationResult result = lazy.initialize(props, compressor, rowGroupOrdinal); this.parquetSchema = result.getSchema(); + this.pageStore.close(); this.pageStore = result.getPageStore(); + this.writeStore.close(); this.writeStore = result.getWriteStore(); ensureWriterInitialized(); diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index b12606d23948..81139969f746 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ 
b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,14 +109,4 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; - - // Controls whether to shred variant columns during write operations - public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; - public static final boolean SHRED_VARIANTS_DEFAULT = true; - - // Controls the buffer size for variant schema inference during writes - // This determines how many rows are buffered before inferring shredded schema - public static final String VARIANT_INFERENCE_BUFFER_SIZE = - "spark.sql.iceberg.variant.inference.buffer-size"; - public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 80d245712e6b..96131e0e56dd 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -509,17 +509,6 @@ private Map dataWriteProperties() { if (parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } - writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); - - // Add variant shredding configuration properties - if (shredVariants()) { - String variantBufferSize = - sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); - if (variantBufferSize != null) { - writeProperties.put( - SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); - } - } break; case AVRO: @@ -740,13 +729,4 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } - - public boolean shredVariants() { - return confParser - .booleanConf() - .option(SparkWriteOptions.SHRED_VARIANTS) - .sessionConf(SparkSQLProperties.SHRED_VARIANTS) - .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) - .parse(); - } } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index f8fb41696f76..33db70bae587 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,7 +85,4 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; - - // Controls whether to shred variant columns during write operations - public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index 8c74c65fc1b4..a93db17e4a0f 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -151,17 +151,7 @@ protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { @Override protected void configureDataWrite(Parquet.DataWriteBuilder builder) { - if (SparkParquetWriterWithVariantShredding.shouldUseVariantShredding( - writeProperties, dataSchema())) { 
- builder.createWriterFunc( - msgType -> - new SparkParquetWriterWithVariantShredding( - dataSparkType(), msgType, writeProperties)); - } else { - builder.createWriterFunc( - msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); - } - + builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); builder.setAll(writeProperties); } diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index d97579f29e86..61aacfa4589d 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,7 +41,6 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; -import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -340,8 +339,6 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -463,8 +460,6 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -536,8 +531,6 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( - SHRED_VARIANTS, - "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index 81139969f746..b12606d23948 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -109,4 +109,14 @@ private SparkSQLProperties() {} // Prefix for custom snapshot properties public static final String SNAPSHOT_PROPERTY_PREFIX = "spark.sql.iceberg.snapshot-property."; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "spark.sql.iceberg.shred-variants"; + public static final boolean SHRED_VARIANTS_DEFAULT = true; + + // Controls the buffer size for variant schema inference during writes + // This determines how many rows are buffered before inferring shredded schema + public static final String VARIANT_INFERENCE_BUFFER_SIZE = + "spark.sql.iceberg.variant.inference.buffer-size"; + public static final int VARIANT_INFERENCE_BUFFER_SIZE_DEFAULT = 10; } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 96131e0e56dd..80d245712e6b 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -509,6 +509,17 @@ private Map dataWriteProperties() { if 
(parquetCompressionLevel != null) { writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + + // Add variant shredding configuration properties + if (shredVariants()) { + String variantBufferSize = + sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); + if (variantBufferSize != null) { + writeProperties.put( + SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, variantBufferSize); + } + } break; case AVRO: @@ -729,4 +740,13 @@ public DeleteGranularity deleteGranularity() { .defaultValue(DeleteGranularity.FILE) .parse(); } + + public boolean shredVariants() { + return confParser + .booleanConf() + .option(SparkWriteOptions.SHRED_VARIANTS) + .sessionConf(SparkSQLProperties.SHRED_VARIANTS) + .defaultValue(SparkSQLProperties.SHRED_VARIANTS_DEFAULT) + .parse(); + } } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java index 33db70bae587..f8fb41696f76 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java @@ -85,4 +85,7 @@ private SparkWriteOptions() {} // Overrides the delete granularity public static final String DELETE_GRANULARITY = "delete-granularity"; + + // Controls whether to shred variant columns during write operations + public static final String SHRED_VARIANTS = "shred-variants"; } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java similarity index 89% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java index 6903f1f03353..06a79b8dcef0 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SchemaInferenceVisitor.java @@ -152,7 +152,6 @@ private int getFieldIndex(String[] path) { return -1; } - // Support nested variant fields by navigating the struct hierarchy if (path.length == 1) { // Top-level field - direct lookup String fieldName = path[0]; @@ -162,15 +161,8 @@ private int getFieldIndex(String[] path) { } } } else { - // Nested field - navigate through struct hierarchy - // For now, we only support direct struct nesting (not arrays/maps) - LOG.debug("Attempting to resolve nested variant field path: {}", String.join(".", path)); - // TODO: Implement full nested field resolution when needed - // This would require tracking the current struct context during traversal - // and maintaining a stack of field indices - LOG.warn( - "Multi-level nested variant fields require struct context tracking. Path: {}", - String.join(".", path)); + // TODO: Implement full nested field resolution + LOG.warn("Nested variant shredding is not supported. 
Path: {}", String.join(".", path)); } return -1; diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java index a93db17e4a0f..8c74c65fc1b4 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkFileWriterFactory.java @@ -151,7 +151,17 @@ protected void configurePositionDelete(Avro.DeleteWriteBuilder builder) { @Override protected void configureDataWrite(Parquet.DataWriteBuilder builder) { - builder.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + if (SparkParquetWriterWithVariantShredding.shouldUseVariantShredding( + writeProperties, dataSchema())) { + builder.createWriterFunc( + msgType -> + new SparkParquetWriterWithVariantShredding( + dataSparkType(), msgType, writeProperties)); + } else { + builder.createWriterFunc( + msgType -> SparkParquetWriters.buildWriter(dataSparkType(), msgType)); + } + builder.setAll(writeProperties); } diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java similarity index 100% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java similarity index 74% rename from spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java rename to spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 27b526134737..9487c2dc0141 100644 --- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -104,7 +104,7 @@ private static PathNode buildPathTree(List variantValues) { } private static void traverse(PathNode node, VariantValue value) { - if (value == null) { + if (value == null || value.type() == PhysicalType.NULL) { return; } @@ -139,7 +139,12 @@ private static void traverse(PathNode node, VariantValue value) { } private static Type buildFieldGroup(PathNode node) { - Type typedValue = buildTypedValue(node, node.info.getMostCommonType()); + PhysicalType commonType = node.info.getMostCommonType(); + if (commonType == null) { + return null; + } + + Type typedValue = buildTypedValue(node, commonType); return Types.buildGroup(Type.Repetition.REQUIRED) .optional(PrimitiveType.PrimitiveTypeName.BINARY) .named(VALUE) @@ -167,7 +172,12 @@ private static Type createObjectTypedValue(PathNode node) { Types.GroupBuilder builder = Types.buildGroup(Type.Repetition.OPTIONAL); for (PathNode child : node.objectChildren.values()) { - builder.addField(buildFieldGroup(child)); + Type fieldType = buildFieldGroup(child); + if (fieldType == null) { + continue; + } + + builder.addField(fieldType); } return builder.named(TYPED_VALUE); @@ -221,67 +231,58 @@ private static Type createDecimalTypedValue(FieldInfo info) { } private static Type 
createPrimitiveTypedValue(FieldInfo info, PhysicalType primitiveType) { - switch (primitiveType) { - case BOOLEAN_TRUE: - case BOOLEAN_FALSE: - return Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); - - case INT8: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(8, true)) - .named(TYPED_VALUE); - - case INT16: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(16, true)) - .named(TYPED_VALUE); - - case INT32: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.intType(32, true)) - .named(TYPED_VALUE); - - case INT64: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); - - case FLOAT: - return Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); - - case DOUBLE: - return Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); - - case STRING: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.stringType()) - .named(TYPED_VALUE); - - case BINARY: - return Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); - - case DATE: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) - .as(LogicalTypeAnnotation.dateType()) - .named(TYPED_VALUE); - - case TIMESTAMPTZ: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - - case TIMESTAMPNTZ: - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) - .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) - .named(TYPED_VALUE); - - case DECIMAL4: - case DECIMAL8: - case DECIMAL16: - return createDecimalTypedValue(info); - - default: - throw new UnsupportedOperationException( - "Unknown primitive physical type: " + primitiveType); - } + return switch (primitiveType) { + case BOOLEAN_TRUE, BOOLEAN_FALSE -> + Types.optional(PrimitiveType.PrimitiveTypeName.BOOLEAN).named(TYPED_VALUE); + case INT8 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(8, true)) + .named(TYPED_VALUE); + case INT16 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(16, true)) + .named(TYPED_VALUE); + case INT32 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.intType(32, true)) + .named(TYPED_VALUE); + case INT64 -> Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); + case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); + case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); + case STRING -> + Types.optional(PrimitiveType.PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named(TYPED_VALUE); + case BINARY -> Types.optional(PrimitiveType.PrimitiveTypeName.BINARY).named(TYPED_VALUE); + case TIME -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timeType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case DATE -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + .as(LogicalTypeAnnotation.dateType()) + .named(TYPED_VALUE); + case TIMESTAMPTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case 
TIMESTAMPNTZ -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)) + .named(TYPED_VALUE); + case TIMESTAMPTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case TIMESTAMPNTZ_NANOS -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)) + .named(TYPED_VALUE); + case DECIMAL4, DECIMAL8, DECIMAL16 -> createDecimalTypedValue(info); + default -> + throw new UnsupportedOperationException( + "Unknown primitive physical type: " + primitiveType); + }; } /** Tracks occurrence count and types for a single field. */ diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java index 61aacfa4589d..d97579f29e86 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkWriteConf.java @@ -41,6 +41,7 @@ import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE_RANGE; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_CODEC; import static org.apache.iceberg.spark.SparkSQLProperties.COMPRESSION_LEVEL; +import static org.apache.iceberg.spark.SparkSQLProperties.SHRED_VARIANTS; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE; import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE; @@ -339,6 +340,8 @@ public void testSparkConfOverride() { TableProperties.DELETE_PARQUET_COMPRESSION, "snappy"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -460,6 +463,8 @@ public void testDataPropsDefaultsAsDeleteProps() { PARQUET_COMPRESSION_LEVEL, "5"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, @@ -531,6 +536,8 @@ public void testDeleteFileWriteConf() { DELETE_PARQUET_COMPRESSION_LEVEL, "6"), ImmutableMap.of( + SHRED_VARIANTS, + "true", DELETE_PARQUET_COMPRESSION, "zstd", PARQUET_COMPRESSION, diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java similarity index 87% rename from spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java rename to spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java index 5f4eb2a2732f..df239a674ba7 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java @@ -153,6 +153,33 @@ public void testConsistentType() throws IOException { verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testExcludingNullValue() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30, \"dummy\": null}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))"; + 
sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + @TestTemplate public void testInconsistentType() throws IOException { spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); @@ -363,6 +390,41 @@ public void testNestedObjectType() throws IOException { verifyParquetSchema(table, expectedSchema); } + @TestTemplate + public void testLazyInitializationWithBufferedRows() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "5"); + + String values = + "(1, parse_json('{\"name\": \"Alice\", \"age\": 30}'))," + + " (2, parse_json('{\"name\": \"Bob\", \"age\": 25}'))," + + " (3, parse_json('{\"name\": \"Charlie\", \"age\": 35}'))," + + " (4, parse_json('{\"name\": \"David\", \"age\": 28}'))," + + " (5, parse_json('{\"name\": \"Eve\", \"age\": 32}'))," + + " (6, parse_json('{\"name\": \"Frank\", \"age\": 40}'))," + + " (7, parse_json('{\"name\": \"Grace\", \"age\": 27}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(7); + } + private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException { try (CloseableIterable tasks = table.newScan().planFiles()) { assertThat(tasks).isNotEmpty(); From d7e15a718ef0b6510ff2a60eec6a3720b7f9650f Mon Sep 17 00:00:00 2001 From: Aihua Xu Date: Mon, 26 Jan 2026 22:40:09 -0800 Subject: [PATCH 5/5] Add tie break and INT/DECIMAL promotion --- .../apache/iceberg/parquet/ParquetWriter.java | 11 +- .../parquet/WriterLazyInitializable.java | 8 +- .../apache/iceberg/spark/SparkWriteConf.java | 5 +- ...parkParquetWriterWithVariantShredding.java | 11 +- .../source/VariantShreddingAnalyzer.java | 61 ++++++- .../spark/variant/TestVariantShredding.java | 159 +++++++++++++++++- 6 files changed, 239 insertions(+), 16 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java index 7144b474a0dc..d19ab9f16125 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetWriter.java @@ -51,7 +51,6 @@ class ParquetWriter implements FileAppender, Closeable { private final Map metadata; private final ParquetProperties props; private final CodecFactory.BytesCompressor compressor; - private 
MessageType parquetSchema;
   private final ParquetValueWriter<T> model;
   private final MetricsConfig metricsConfig;
   private final int columnIndexTruncateLength;
@@ -60,6 +59,7 @@ class ParquetWriter<T> implements FileAppender<T>, Closeable {
   private final Configuration conf;
   private final InternalFileEncryptor fileEncryptor;
 
+  private MessageType parquetSchema;
   private ColumnChunkPageWriteStore pageStore = null;
   private ColumnWriteStore writeStore;
   private long recordCount = 0;
@@ -141,7 +141,8 @@ public void add(T value) {
 
         if (!lazy.needsInitialization()) {
           WriterLazyInitializable.InitializationResult result =
-              lazy.initialize(props, compressor, rowGroupOrdinal);
+              lazy.initialize(
+                  props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor);
           this.parquetSchema = result.getSchema();
           this.pageStore.close();
           this.pageStore = result.getPageStore();
@@ -281,11 +282,13 @@ public void close() throws IOException {
     if (!closed) {
       this.closed = true;
 
-      // Force initialization if lazy writer still has buffered data
       if (model instanceof WriterLazyInitializable lazy) {
+        // If initialization was never triggered because too few rows were buffered, the lazy
+        // writer must initialize now and flush the remaining buffered rows
        if (lazy.needsInitialization()) {
           WriterLazyInitializable.InitializationResult result =
-              lazy.initialize(props, compressor, rowGroupOrdinal);
+              lazy.initialize(
+                  props, compressor, rowGroupOrdinal, columnIndexTruncateLength, fileEncryptor);
           this.parquetSchema = result.getSchema();
           this.pageStore.close();
           this.pageStore = result.getPageStore();
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java
index 9c5913d7bd9b..f7b6c591fa49 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/WriterLazyInitializable.java
@@ -21,6 +21,7 @@
 import org.apache.parquet.column.ColumnWriteStore;
 import org.apache.parquet.column.ParquetProperties;
 import org.apache.parquet.compression.CompressionCodecFactory;
+import org.apache.parquet.crypto.InternalFileEncryptor;
 import org.apache.parquet.hadoop.ColumnChunkPageWriteStore;
 import org.apache.parquet.schema.MessageType;
 
@@ -78,10 +79,15 @@ public ColumnWriteStore getWriteStore() {
    * @param props Parquet properties needed for creating write stores
    * @param compressor Bytes compressor for compression
    * @param rowGroupOrdinal The ordinal number of the current row group
+   * @param columnIndexTruncateLength The column index truncate length from ParquetWriter config
+   * @param fileEncryptor The file encryptor from ParquetWriter, may be null if encryption is
+   *     disabled
    * @return InitializationResult containing the finalized schema and write stores
    */
   InitializationResult initialize(
       ParquetProperties props,
       CompressionCodecFactory.BytesInputCompressor compressor,
-      int rowGroupOrdinal);
+      int rowGroupOrdinal,
+      int columnIndexTruncateLength,
+      InternalFileEncryptor fileEncryptor);
 }
diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index 80d245712e6b..e72aa706dfa6 100644
--- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -509,10 +509,11 @@ private Map<String, String> dataWriteProperties() {
     if (parquetCompressionLevel != null) {
writeProperties.put(PARQUET_COMPRESSION_LEVEL, parquetCompressionLevel); } - writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shredVariants())); + boolean shouldShredVariants = shredVariants(); + writeProperties.put(SparkSQLProperties.SHRED_VARIANTS, String.valueOf(shouldShredVariants)); // Add variant shredding configuration properties - if (shredVariants()) { + if (shouldShredVariants) { String variantBufferSize = sessionConf.get(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, null); if (variantBufferSize != null) { diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java index 6a2ed1e85324..5b9c10ff548f 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkParquetWriterWithVariantShredding.java @@ -108,6 +108,9 @@ public List> columns() { @Override public void setColumnStore(ColumnWriteStore columnStore) { // Ignored for lazy initialization - will be set on actualWriter after initialization + if (actualWriter != null) { + actualWriter.setColumnStore(columnStore); + } } @Override @@ -127,7 +130,9 @@ public boolean needsInitialization() { public InitializationResult initialize( ParquetProperties props, CompressionCodecFactory.BytesInputCompressor compressor, - int rowGroupOrdinal) { + int rowGroupOrdinal, + int columnIndexTruncateLength, + org.apache.parquet.crypto.InternalFileEncryptor fileEncryptor) { if (bufferedRows.isEmpty()) { throw new IllegalStateException("No buffered rows available for schema inference"); } @@ -151,9 +156,9 @@ public InitializationResult initialize( compressor, shreddedSchema, props.getAllocator(), - 64, + columnIndexTruncateLength, ParquetProperties.DEFAULT_PAGE_WRITE_CHECKSUM_ENABLED, - null, + fileEncryptor, rowGroupOrdinal); ColumnWriteStore columnStore = props.newColumnWriteStore(shreddedSchema, pageStore, pageStore); diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java index 9487c2dc0141..fba3d258995b 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/VariantShreddingAnalyzer.java @@ -246,7 +246,10 @@ private static Type createPrimitiveTypedValue(FieldInfo info, PhysicalType primi Types.optional(PrimitiveType.PrimitiveTypeName.INT32) .as(LogicalTypeAnnotation.intType(32, true)) .named(TYPED_VALUE); - case INT64 -> Types.optional(PrimitiveType.PrimitiveTypeName.INT64).named(TYPED_VALUE); + case INT64 -> + Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + .as(LogicalTypeAnnotation.intType(64, true)) + .named(TYPED_VALUE); case FLOAT -> Types.optional(PrimitiveType.PrimitiveTypeName.FLOAT).named(TYPED_VALUE); case DOUBLE -> Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named(TYPED_VALUE); case STRING -> @@ -296,6 +299,7 @@ void observe(VariantValue value) { // Use BOOLEAN_TRUE for both TRUE/FALSE values PhysicalType type = value.type() == PhysicalType.BOOLEAN_FALSE ? PhysicalType.BOOLEAN_TRUE : value.type(); + observedTypes.add(type); typeCounts.compute(type, (k, v) -> (v == null) ? 
1 : v + 1);
@@ -314,10 +318,61 @@ void observe(VariantValue value) {
     }
 
     PhysicalType getMostCommonType() {
-      return typeCounts.entrySet().stream()
-          .max(Map.Entry.comparingByValue())
+      Map<PhysicalType, Integer> combinedCounts = Maps.newHashMap();
+
+      int integerTotalCount = 0;
+      PhysicalType mostCapableInteger = null;
+
+      int decimalTotalCount = 0;
+      PhysicalType mostCapableDecimal = null;
+
+      for (Map.Entry<PhysicalType, Integer> entry : typeCounts.entrySet()) {
+        PhysicalType type = entry.getKey();
+        int count = entry.getValue();
+
+        if (isIntegerType(type)) {
+          integerTotalCount += count;
+          if (mostCapableInteger == null || type.ordinal() > mostCapableInteger.ordinal()) {
+            mostCapableInteger = type;
+          }
+        } else if (isDecimalType(type)) {
+          decimalTotalCount += count;
+          if (mostCapableDecimal == null || type.ordinal() > mostCapableDecimal.ordinal()) {
+            mostCapableDecimal = type;
+          }
+        } else {
+          combinedCounts.put(type, count);
+        }
+      }
+
+      if (mostCapableInteger != null) {
+        combinedCounts.put(mostCapableInteger, integerTotalCount);
+      }
+
+      if (mostCapableDecimal != null) {
+        combinedCounts.put(mostCapableDecimal, decimalTotalCount);
+      }
+
+      // Pick the most common type with tie-breaking
+      return combinedCounts.entrySet().stream()
+          .max(
+              Map.Entry.<PhysicalType, Integer>comparingByValue()
+                  .thenComparingInt(entry -> entry.getKey().ordinal()))
           .map(Map.Entry::getKey)
           .orElse(null);
     }
+
+    private boolean isIntegerType(PhysicalType type) {
+      return type == PhysicalType.INT8
+          || type == PhysicalType.INT16
+          || type == PhysicalType.INT32
+          || type == PhysicalType.INT64;
+    }
+
+    private boolean isDecimalType(PhysicalType type) {
+      return type == PhysicalType.DECIMAL4
+          || type == PhysicalType.DECIMAL8
+          || type == PhysicalType.DECIMAL16;
+    }
   }
 }
diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
index df239a674ba7..ec668c2043f8 100644
--- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
+++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/variant/TestVariantShredding.java
@@ -19,6 +19,7 @@
 package org.apache.iceberg.spark.variant;
 
 import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
+import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
 import static org.apache.parquet.schema.Types.optional;
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -109,6 +110,8 @@ public void before() {
 
   @AfterEach
   public void after() {
+    spark.conf().unset(SparkSQLProperties.SHRED_VARIANTS);
+    spark.conf().unset(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE);
     validationCatalog.dropTable(tableIdent, true);
   }
 
@@ -425,6 +428,159 @@ public void testLazyInitializationWithBufferedRows() throws IOException {
     assertThat(rowCount).isEqualTo(7);
   }
 
+  @TestTemplate
+  public void testTieBreakingWithEqualCounts() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    String values =
+        "(1, parse_json('{\"value\": 10}')),"
+            + " (2, parse_json('{\"value\": 20}')),"
+            + " (3, parse_json('{\"value\": \"hello\"}')),"
+            + " (4, parse_json('{\"value\": \"world\"}'))";
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    // When counts are tied, the tie-break picks the type with the higher ordinal (string here)
+    GroupType value =
+        field(
+            "value",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType()));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED,
objectFields(value)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + } + + @TestTemplate + public void testMultipleRowGroups() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int numRows = 1000; + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= numRows; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + valuesBuilder.append( + String.format("(%d, parse_json('{\"name\": \"User%d\", \"age\": %d}'))", i, i, 20 + i)); + } + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, PARQUET_ROW_GROUP_SIZE_BYTES, 1024); + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType name = + field( + "name", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType age = + field( + "age", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(age, name)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(numRows); + } + + @TestTemplate + public void testColumnIndexTruncateLength() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + spark.conf().set(SparkSQLProperties.VARIANT_INFERENCE_BUFFER_SIZE, "3"); + + int customTruncateLength = 10; + sql( + "ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", + tableName, "parquet.columnindex.truncate.length", customTruncateLength); + + StringBuilder valuesBuilder = new StringBuilder(); + for (int i = 1; i <= 10; i++) { + if (i > 1) { + valuesBuilder.append(", "); + } + String longValue = "A".repeat(20); + valuesBuilder.append( + String.format( + "(%d, parse_json('{\"description\": \"%s\", \"id\": %d}'))", i, longValue, i)); + } + sql("INSERT INTO %s VALUES %s", tableName, valuesBuilder.toString()); + + GroupType description = + field( + "description", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.BINARY, LogicalTypeAnnotation.stringType())); + GroupType id = + field( + "id", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true))); + GroupType address = + variant("address", 2, Type.Repetition.REQUIRED, objectFields(description, id)); + MessageType expectedSchema = parquetSchema(address); + + Table table = validationCatalog.loadTable(tableIdent); + verifyParquetSchema(table, expectedSchema); + + long rowCount = spark.read().format("iceberg").load(tableName).count(); + assertThat(rowCount).isEqualTo(10); + } + + @TestTemplate + public void testIntegerFamilyPromotion() throws IOException { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + + // Mix of INT8, INT16, INT32, INT64 - should promote to INT64 + String values = + "(1, parse_json('{\"value\": 10}'))," + + " (2, parse_json('{\"value\": 1000}'))," + + " (3, parse_json('{\"value\": 100000}'))," + + " (4, parse_json('{\"value\": 10000000000}'))"; + sql("INSERT INTO %s VALUES %s", tableName, values); + + GroupType value = + field( + "value", + shreddedPrimitive( + PrimitiveType.PrimitiveTypeName.INT64, 
LogicalTypeAnnotation.intType(64, true)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
+  @TestTemplate
+  public void testDecimalFamilyPromotion() throws IOException {
+    spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true");
+
+    // Test that they get promoted to the most capable decimal type observed
+    String values =
+        "(1, parse_json('{\"value\": 1.5}')),"
+            + " (2, parse_json('{\"value\": 123.456789}')),"
+            + " (3, parse_json('{\"value\": 123456789123456.789}'))";
+    sql("INSERT INTO %s VALUES %s", tableName, values);
+
+    GroupType value =
+        field(
+            "value",
+            shreddedPrimitive(
+                PrimitiveType.PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(6, 18)));
+    GroupType address = variant("address", 2, Type.Repetition.REQUIRED, objectFields(value));
+    MessageType expectedSchema = parquetSchema(address);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    verifyParquetSchema(table, expectedSchema);
+  }
+
   private void verifyParquetSchema(Table table, MessageType expectedSchema) throws IOException {
     try (CloseableIterable<FileScanTask> tasks = table.newScan().planFiles()) {
       assertThat(tasks).isNotEmpty();
@@ -439,9 +595,6 @@ private void verifyParquetSchema(Table table, MessageType expectedSchema) throws
       MessageType actualSchema = reader.getFileMetaData().getSchema();
       assertThat(actualSchema).isEqualTo(expectedSchema);
     }
-
-    // Print the result
-    spark.read().format("iceberg").load(tableName).orderBy("id").show(false);
   }
 }