diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java index 87b594af9e..e09b99f670 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java @@ -17,10 +17,14 @@ package org.apache.arrow.adapter.avro; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer; import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer; import org.apache.arrow.adapter.avro.producers.AvroBytesProducer; +import org.apache.arrow.adapter.avro.producers.AvroEnumProducer; import org.apache.arrow.adapter.avro.producers.AvroFixedSizeBinaryProducer; import org.apache.arrow.adapter.avro.producers.AvroFixedSizeListProducer; import org.apache.arrow.adapter.avro.producers.AvroFloat2Producer; @@ -41,6 +45,7 @@ import org.apache.arrow.adapter.avro.producers.AvroUint8Producer; import org.apache.arrow.adapter.avro.producers.BaseAvroProducer; import org.apache.arrow.adapter.avro.producers.CompositeAvroProducer; +import org.apache.arrow.adapter.avro.producers.DictionaryDecodingProducer; import org.apache.arrow.adapter.avro.producers.Producer; import org.apache.arrow.adapter.avro.producers.logical.AvroDateDayProducer; import org.apache.arrow.adapter.avro.producers.logical.AvroDateMilliProducer; @@ -59,6 +64,7 @@ import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecProducer; import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecTzProducer; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; @@ -96,11 +102,14 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.Text; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; @@ -162,17 +171,29 @@ public class ArrowToAvroUtils { * may be nullable. Record types must contain at least one child field and cannot contain multiple * fields with the same name * + *

String fields that are dictionary-encoded will be represented as an Avro enum, so long as + * all the values meet the restrictions on Avro enums (non-null, valid identifiers). Other data + * types that are dictionary encoded, or string fields that do not meet the avro requirements, + * will be output as their decoded type. + * * @param arrowFields The arrow fields used to generate the Avro schema * @param typeName Name of the top level Avro record type * @param namespace Namespace of the top level Avro record type + * @param dictionaries A dictionary provider is required if any fields use dictionary encoding * @return An Avro record schema for the given list of fields, with the specified name and * namespace */ public static Schema createAvroSchema( - List arrowFields, String typeName, String namespace) { + List arrowFields, String typeName, String namespace, DictionaryProvider dictionaries) { SchemaBuilder.RecordBuilder assembler = SchemaBuilder.record(typeName).namespace(namespace); - return buildRecordSchema(assembler, arrowFields, namespace); + return buildRecordSchema(assembler, arrowFields, namespace, dictionaries); + } + + /** Overload provided for convenience, sets dictionaries = null. */ + public static Schema createAvroSchema( + List arrowFields, String typeName, String namespace) { + return createAvroSchema(arrowFields, typeName, namespace, null); } /** Overload provided for convenience, sets namespace = null. */ @@ -185,61 +206,83 @@ public static Schema createAvroSchema(List arrowFields) { return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME); } + /** + * Overload provided for convenience, sets name = GENERIC_RECORD_TYPE_NAME and namespace = null. + */ + public static Schema createAvroSchema(List arrowFields, DictionaryProvider dictionaries) { + return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME, null, dictionaries); + } + private static T buildRecordSchema( - SchemaBuilder.RecordBuilder builder, List fields, String namespace) { + SchemaBuilder.RecordBuilder builder, + List fields, + String namespace, + DictionaryProvider dictionaries) { if (fields.isEmpty()) { throw new IllegalArgumentException("Record field must have at least one child field"); } SchemaBuilder.FieldAssembler assembler = builder.namespace(namespace).fields(); for (Field field : fields) { - assembler = buildFieldSchema(assembler, field, namespace); + assembler = buildFieldSchema(assembler, field, namespace, dictionaries); } return assembler.endRecord(); } private static SchemaBuilder.FieldAssembler buildFieldSchema( - SchemaBuilder.FieldAssembler assembler, Field field, String namespace) { + SchemaBuilder.FieldAssembler assembler, + Field field, + String namespace, + DictionaryProvider dictionaries) { return assembler .name(field.getName()) - .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace)) + .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace, dictionaries)) .noDefault(); } private static T buildTypeSchema( - SchemaBuilder.TypeBuilder builder, Field field, String namespace) { + SchemaBuilder.TypeBuilder builder, + Field field, + String namespace, + DictionaryProvider dictionaries) { // Nullable unions need special handling, since union types cannot be directly nested if (field.getType().getTypeID() == ArrowType.ArrowTypeID.Union) { boolean unionNullable = field.getChildren().stream().anyMatch(Field::isNullable); if (unionNullable) { SchemaBuilder.UnionAccumulator union = builder.unionOf().nullType(); - return addTypesToUnion(union, field.getChildren(), namespace); + return addTypesToUnion(union, field.getChildren(), namespace, dictionaries); } else { Field headType = field.getChildren().get(0); List tailTypes = field.getChildren().subList(1, field.getChildren().size()); SchemaBuilder.UnionAccumulator union = - buildBaseTypeSchema(builder.unionOf(), headType, namespace); - return addTypesToUnion(union, tailTypes, namespace); + buildBaseTypeSchema(builder.unionOf(), headType, namespace, dictionaries); + return addTypesToUnion(union, tailTypes, namespace, dictionaries); } } else if (field.isNullable()) { - return buildBaseTypeSchema(builder.nullable(), field, namespace); + return buildBaseTypeSchema(builder.nullable(), field, namespace, dictionaries); } else { - return buildBaseTypeSchema(builder, field, namespace); + return buildBaseTypeSchema(builder, field, namespace, dictionaries); } } private static T buildArraySchema( - SchemaBuilder.ArrayBuilder builder, Field listField, String namespace) { + SchemaBuilder.ArrayBuilder builder, + Field listField, + String namespace, + DictionaryProvider dictionaries) { if (listField.getChildren().size() != 1) { throw new IllegalArgumentException("List field must have exactly one child field"); } Field itemField = listField.getChildren().get(0); - return buildTypeSchema(builder.items(), itemField, namespace); + return buildTypeSchema(builder.items(), itemField, namespace, dictionaries); } private static T buildMapSchema( - SchemaBuilder.MapBuilder builder, Field mapField, String namespace) { + SchemaBuilder.MapBuilder builder, + Field mapField, + String namespace, + DictionaryProvider dictionaries) { if (mapField.getChildren().size() != 1) { throw new IllegalArgumentException("Map field must have exactly one child field"); } @@ -253,11 +296,14 @@ private static T buildMapSchema( throw new IllegalArgumentException( "Map keys must be of type string and cannot be nullable for conversion to Avro"); } - return buildTypeSchema(builder.values(), valueField, namespace); + return buildTypeSchema(builder.values(), valueField, namespace, dictionaries); } private static T buildBaseTypeSchema( - SchemaBuilder.BaseTypeBuilder builder, Field field, String namespace) { + SchemaBuilder.BaseTypeBuilder builder, + Field field, + String namespace, + DictionaryProvider dictionaries) { ArrowType.ArrowTypeID typeID = field.getType().getTypeID(); @@ -269,6 +315,33 @@ private static T buildBaseTypeSchema( return builder.booleanType(); case Int: + if (field.getDictionary() != null) { + if (dictionaries == null) { + throw new IllegalArgumentException( + "Field references a dictionary but no dictionaries were provided: " + + field.getName()); + } + Dictionary dictionary = dictionaries.lookup(field.getDictionary().getId()); + if (dictionary == null) { + throw new IllegalArgumentException( + "Field references a dictionary that does not exist: " + + field.getName() + + ", dictionary ID = " + + field.getDictionary().getId()); + } + if (dictionaryIsValidEnum(dictionary)) { + String[] symbols = dictionarySymbols(dictionary); + return builder.enumeration(field.getName()).symbols(symbols); + } else { + Field decodedField = + new Field( + field.getName(), + dictionary.getVector().getField().getFieldType(), + dictionary.getVector().getField().getChildren()); + return buildBaseTypeSchema(builder, decodedField, namespace, dictionaries); + } + } + ArrowType.Int intType = (ArrowType.Int) field.getType(); if (intType.getBitWidth() > 32 || (intType.getBitWidth() == 32 && !intType.getIsSigned())) { return builder.longType(); @@ -328,7 +401,7 @@ private static T buildBaseTypeSchema( String childNamespace = namespace == null ? field.getName() : namespace + "." + field.getName(); return buildRecordSchema( - builder.record(field.getName()), field.getChildren(), childNamespace); + builder.record(field.getName()), field.getChildren(), childNamespace, dictionaries); case List: case FixedSizeList: @@ -339,13 +412,13 @@ private static T buildBaseTypeSchema( new Field("item", itemField.getFieldType(), itemField.getChildren()); Field safeListField = new Field(field.getName(), field.getFieldType(), List.of(safeItemField)); - return buildArraySchema(builder.array(), safeListField, namespace); + return buildArraySchema(builder.array(), safeListField, namespace, dictionaries); } else { - return buildArraySchema(builder.array(), field, namespace); + return buildArraySchema(builder.array(), field, namespace, dictionaries); } case Map: - return buildMapSchema(builder.map(), field, namespace); + return buildMapSchema(builder.map(), field, namespace, dictionaries); default: throw new IllegalArgumentException( @@ -354,9 +427,12 @@ private static T buildBaseTypeSchema( } private static T addTypesToUnion( - SchemaBuilder.UnionAccumulator accumulator, List unionFields, String namespace) { + SchemaBuilder.UnionAccumulator accumulator, + List unionFields, + String namespace, + DictionaryProvider dictionaries) { for (var field : unionFields) { - accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace); + accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace, dictionaries); } return accumulator.endUnion(); } @@ -373,30 +449,88 @@ private static LogicalType timestampLogicalType(ArrowType.Timestamp timestampTyp } } + private static boolean dictionaryIsValidEnum(Dictionary dictionary) { + + if (dictionary.getVectorType().getTypeID() != ArrowType.ArrowTypeID.Utf8) { + return false; + } + + VarCharVector vector = (VarCharVector) dictionary.getVector(); + Set symbols = new HashSet<>(); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + return false; + } + Text text = vector.getObject(i); + if (text == null) { + return false; + } + String symbol = text.toString(); + if (!ENUM_REGEX.matcher(symbol).matches()) { + return false; + } + if (symbols.contains(symbol)) { + return false; + } + symbols.add(symbol); + } + + return true; + } + + private static String[] dictionarySymbols(Dictionary dictionary) { + + VarCharVector vector = (VarCharVector) dictionary.getVector(); + String[] symbols = new String[vector.getValueCount()]; + + for (int i = 0; i < vector.getValueCount(); i++) { + Text text = vector.getObject(i); + // This should never happen if dictionaryIsValidEnum() succeeded + if (text == null) { + throw new IllegalArgumentException("Illegal null value in enum"); + } + symbols[i] = text.toString(); + } + + return symbols; + } + + private static final Pattern ENUM_REGEX = Pattern.compile("^[A-Za-z_][A-Za-z0-9_]*$"); + /** * Create a composite Avro producer for a set of field vectors (typically the root set of a VSR). * * @param vectors The vectors that will be used to produce Avro data * @return The resulting composite Avro producer */ - public static CompositeAvroProducer createCompositeProducer(List vectors) { + public static CompositeAvroProducer createCompositeProducer( + List vectors, DictionaryProvider dictionaries) { List> producers = new ArrayList<>(vectors.size()); for (FieldVector vector : vectors) { - BaseAvroProducer producer = createProducer(vector); + BaseAvroProducer producer = createProducer(vector, dictionaries); producers.add(producer); } return new CompositeAvroProducer(producers); } - private static BaseAvroProducer createProducer(FieldVector vector) { + /** Overload provided for convenience, sets dictionaries = null. */ + public static CompositeAvroProducer createCompositeProducer(List vectors) { + + return createCompositeProducer(vectors, null); + } + + private static BaseAvroProducer createProducer( + FieldVector vector, DictionaryProvider dictionaries) { boolean nullable = vector.getField().isNullable(); - return createProducer(vector, nullable); + return createProducer(vector, nullable, dictionaries); } - private static BaseAvroProducer createProducer(FieldVector vector, boolean nullable) { + private static BaseAvroProducer createProducer( + FieldVector vector, boolean nullable, DictionaryProvider dictionaries) { Preconditions.checkNotNull(vector, "Arrow vector object can't be null"); @@ -405,10 +539,34 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu // Avro understands nullable types as a union of type | null // Most nullable fields in a VSR will not be unions, so provide a special wrapper if (nullable && minorType != Types.MinorType.UNION) { - final BaseAvroProducer innerProducer = createProducer(vector, false); + final BaseAvroProducer innerProducer = createProducer(vector, false, dictionaries); return new AvroNullableProducer<>(innerProducer); } + if (vector.getField().getDictionary() != null) { + if (dictionaries == null) { + throw new IllegalArgumentException( + "Field references a dictionary but no dictionaries were provided: " + + vector.getField().getName()); + } + Dictionary dictionary = dictionaries.lookup(vector.getField().getDictionary().getId()); + if (dictionary == null) { + throw new IllegalArgumentException( + "Field references a dictionary that does not exist: " + + vector.getField().getName() + + ", dictionary ID = " + + vector.getField().getDictionary().getId()); + } + // If a field is dictionary-encoded but cannot be represented as an Avro enum, + // then decode it before writing + if (dictionaryIsValidEnum(dictionary)) { + return new AvroEnumProducer((BaseIntVector) vector); + } else { + BaseAvroProducer dictProducer = createProducer(dictionary.getVector(), false, null); + return new DictionaryDecodingProducer<>((BaseIntVector) vector, dictProducer); + } + } + switch (minorType) { case NULL: return new AvroNullProducer((NullVector) vector); @@ -486,21 +644,23 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu Producer[] childProducers = new Producer[childVectors.size()]; for (int i = 0; i < childVectors.size(); i++) { FieldVector childVector = childVectors.get(i); - childProducers[i] = createProducer(childVector, childVector.getField().isNullable()); + childProducers[i] = + createProducer(childVector, childVector.getField().isNullable(), dictionaries); } return new AvroStructProducer(structVector, childProducers); case LIST: ListVector listVector = (ListVector) vector; FieldVector itemVector = listVector.getDataVector(); - Producer itemProducer = createProducer(itemVector, itemVector.getField().isNullable()); + Producer itemProducer = + createProducer(itemVector, itemVector.getField().isNullable(), dictionaries); return new AvroListProducer(listVector, itemProducer); case FIXED_SIZE_LIST: FixedSizeListVector fixedListVector = (FixedSizeListVector) vector; FieldVector fixedItemVector = fixedListVector.getDataVector(); Producer fixedItemProducer = - createProducer(fixedItemVector, fixedItemVector.getField().isNullable()); + createProducer(fixedItemVector, fixedItemVector.getField().isNullable(), dictionaries); return new AvroFixedSizeListProducer(fixedListVector, fixedItemProducer); case MAP: @@ -514,7 +674,7 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu FieldVector valueVector = entryVector.getChildrenFromFields().get(1); Producer keyProducer = new AvroStringProducer(keyVector); Producer valueProducer = - createProducer(valueVector, valueVector.getField().isNullable()); + createProducer(valueVector, valueVector.getField().isNullable(), dictionaries); Producer entryProducer = new AvroStructProducer(entryVector, new Producer[] {keyProducer, valueProducer}); return new AvroMapProducer(mapVector, entryProducer); diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java index 068566493e..eebfb7d241 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java @@ -17,22 +17,22 @@ package org.apache.arrow.adapter.avro.producers; import java.io.IOException; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BaseIntVector; import org.apache.avro.io.Encoder; /** - * Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an - * Avro encoder. + * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data + * to an Avro encoder. */ -public class AvroEnumProducer extends BaseAvroProducer { +public class AvroEnumProducer extends BaseAvroProducer { /** Instantiate an AvroEnumProducer. */ - public AvroEnumProducer(IntVector vector) { + public AvroEnumProducer(BaseIntVector vector) { super(vector); } @Override public void produce(Encoder encoder) throws IOException { - encoder.writeEnum(vector.get(currentIndex++)); + encoder.writeEnum((int) vector.getValueAsLong(currentIndex++)); } } diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java new file mode 100644 index 0000000000..afeba08511 --- /dev/null +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adapter.avro.producers; + +import java.io.IOException; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.avro.io.Encoder; + +/** + * Producer that decodes values from a dictionary-encoded {@link FieldVector}, writes the resulting + * values to an Avro encoder. + * + * @param Type of the underlying dictionary vector + */ +public class DictionaryDecodingProducer + extends BaseAvroProducer { + + private final Producer dictProducer; + + /** Instantiate a DictionaryDecodingProducer. */ + public DictionaryDecodingProducer(BaseIntVector indexVector, Producer dictProducer) { + super(indexVector); + this.dictProducer = dictProducer; + } + + @Override + public void produce(Encoder encoder) throws IOException { + int dicIndex = (int) vector.getValueAsLong(currentIndex++); + dictProducer.setPosition(dicIndex); + dictProducer.produce(encoder); + } +} diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java index 2d70b45021..6d66ee9d45 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java @@ -76,10 +76,14 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.JsonStringArrayList; @@ -2817,4 +2821,81 @@ record = datumReader.read(record, decoder); } } } + + @Test + public void testWriteDictEnumEncoded() throws Exception { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + // Field definition + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector stringVector = + new VarCharVector(new Field("enumField", stringField, null), allocator); + stringVector.allocateNew(10); + stringVector.setSafe(0, "apple".getBytes()); + stringVector.setSafe(1, "banana".getBytes()); + stringVector.setSafe(2, "cherry".getBytes()); + stringVector.setSafe(3, "cherry".getBytes()); + stringVector.setSafe(4, "apple".getBytes()); + stringVector.setSafe(5, "banana".getBytes()); + stringVector.setSafe(6, "apple".getBytes()); + stringVector.setSafe(7, "cherry".getBytes()); + stringVector.setSafe(8, "banana".getBytes()); + stringVector.setSafe(9, "apple".getBytes()); + stringVector.setValueCount(10); + + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); + + // Set up VSR + List vectors = Arrays.asList(encodedVector); + int rowCount = 10; + + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { + + File dataFile = new File(TMP, "testWriteEnumEncoded.avro"); + + // Write an AVRO block using the producer classes + try (FileOutputStream fos = new FileOutputStream(dataFile)) { + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); + CompositeAvroProducer producer = + ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); + for (int row = 0; row < rowCount; row++) { + producer.produce(encoder); + } + encoder.flush(); + } + + // Set up reading the AVRO block as a GenericRecord + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); + GenericDatumReader datumReader = new GenericDatumReader<>(schema); + + try (InputStream inputStream = new FileInputStream(dataFile)) { + + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); + GenericRecord record = null; + + // Read and check values + for (int row = 0; row < rowCount; row++) { + record = datumReader.read(record, decoder); + // Values read from Avro should be the decoded enum values + assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); + } + } + } + } } diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java index d3e12e763a..d5e0357a8c 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java @@ -20,11 +20,18 @@ import java.util.Arrays; import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.avro.LogicalTypes; @@ -1389,4 +1396,126 @@ public void testConvertUnionTypes() { Schema.Type.STRING, schema.getField("nullableDenseUnionField").schema().getTypes().get(3).getType()); } + + @Test + public void testWriteDictEnumEncoded() { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + List fields = + Arrays.asList( + new Field( + "enumField", + new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null), + null)); + + Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); + + assertEquals(Schema.Type.RECORD, schema.getType()); + assertEquals(1, schema.getFields().size()); + + Schema.Field enumField = schema.getField("enumField"); + + assertEquals(Schema.Type.ENUM, enumField.schema().getType()); + assertEquals(3, enumField.schema().getEnumSymbols().size()); + assertEquals("apple", enumField.schema().getEnumSymbols().get(0)); + assertEquals("banana", enumField.schema().getEnumSymbols().get(1)); + assertEquals("cherry", enumField.schema().getEnumSymbols().get(2)); + } + + @Test + public void testWriteDictEnumInvalid() { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "passion fruit".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + List fields = + Arrays.asList( + new Field( + "enumField", + new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null), + null)); + + // Dictionary field contains values that are not valid enums + // Should be decoded and output as a string field + + Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); + + assertEquals(Schema.Type.RECORD, schema.getType()); + assertEquals(1, schema.getFields().size()); + + Schema.Field enumField = schema.getField("enumField"); + assertEquals(Schema.Type.STRING, enumField.schema().getType()); + } + + @Test + public void testWriteDictEnumInvalid2() { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null); + BigIntVector dictionaryVector = + new BigIntVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, 123L); + dictionaryVector.set(1, 456L); + dictionaryVector.set(2, 789L); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + List fields = + Arrays.asList( + new Field( + "enumField", + new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null), + null)); + + // Dictionary field encodes LONG values rather than STRING + // Should be doecded and output as a LONG field + + Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); + + assertEquals(Schema.Type.RECORD, schema.getType()); + assertEquals(1, schema.getFields().size()); + + Schema.Field enumField = schema.getField("enumField"); + assertEquals(Schema.Type.LONG, enumField.schema().getType()); + } } diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java index 85e6a960b0..ceaf59aa72 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java @@ -52,6 +52,7 @@ import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.TimeStampNanoTZVector; import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -60,10 +61,14 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.avro.Schema; @@ -78,16 +83,21 @@ public class RoundTripDataTest { @TempDir public static File TMP; - private static AvroToArrowConfig basicConfig(BufferAllocator allocator) { - return new AvroToArrowConfig(allocator, 1000, null, Collections.emptySet(), false); + private static AvroToArrowConfig basicConfig( + BufferAllocator allocator, DictionaryProvider.MapDictionaryProvider dictionaries) { + return new AvroToArrowConfig(allocator, 1000, dictionaries, Collections.emptySet(), false); } private static VectorSchemaRoot readDataFile( - Schema schema, File dataFile, BufferAllocator allocator) throws Exception { + Schema schema, + File dataFile, + BufferAllocator allocator, + DictionaryProvider.MapDictionaryProvider dictionaries) + throws Exception { try (FileInputStream fis = new FileInputStream(dataFile)) { BinaryDecoder decoder = new DecoderFactory().directBinaryDecoder(fis, null); - return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator)); + return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator, dictionaries)); } } @@ -95,11 +105,22 @@ private static void roundTripTest( VectorSchemaRoot root, BufferAllocator allocator, File dataFile, int rowCount) throws Exception { + roundTripTest(root, allocator, dataFile, rowCount, null); + } + + private static void roundTripTest( + VectorSchemaRoot root, + BufferAllocator allocator, + File dataFile, + int rowCount, + DictionaryProvider dictionaries) + throws Exception { + // Write an AVRO block using the producer classes try (FileOutputStream fos = new FileOutputStream(dataFile)) { BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); CompositeAvroProducer producer = - ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors()); + ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors(), dictionaries); for (int row = 0; row < rowCount; row++) { producer.produce(encoder); } @@ -107,10 +128,14 @@ private static void roundTripTest( } // Generate AVRO schema - Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields()); + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); + + DictionaryProvider.MapDictionaryProvider roundTripDictionaries = + new DictionaryProvider.MapDictionaryProvider(); // Read back in and compare - try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) { + try (VectorSchemaRoot roundTrip = + readDataFile(schema, dataFile, allocator, roundTripDictionaries)) { assertEquals(root.getSchema(), roundTrip.getSchema()); assertEquals(rowCount, roundTrip.getRowCount()); @@ -119,6 +144,21 @@ private static void roundTripTest( for (int row = 0; row < rowCount; row++) { assertEquals(root.getVector(0).getObject(row), roundTrip.getVector(0).getObject(row)); } + + if (dictionaries != null) { + for (long id : dictionaries.getDictionaryIds()) { + Dictionary originalDictionary = dictionaries.lookup(id); + Dictionary roundTripDictionary = roundTripDictionaries.lookup(id); + assertEquals( + originalDictionary.getVector().getValueCount(), + roundTripDictionary.getVector().getValueCount()); + for (int j = 0; j < originalDictionary.getVector().getValueCount(); j++) { + assertEquals( + originalDictionary.getVector().getObject(j), + roundTripDictionary.getVector().getObject(j)); + } + } + } } } @@ -141,7 +181,7 @@ private static void roundTripByteArrayTest( Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields()); // Read back in and compare - try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) { + try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator, null)) { assertEquals(root.getSchema(), roundTrip.getSchema()); assertEquals(rowCount, roundTrip.getRowCount()); @@ -1603,4 +1643,58 @@ public void testRoundTripNullableStructs() throws Exception { roundTripTest(root, allocator, dataFile, rowCount); } } + + @Test + public void testRoundTripEnum() throws Exception { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + // Field definition + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector stringVector = + new VarCharVector(new Field("enumField", stringField, null), allocator); + stringVector.allocateNew(10); + stringVector.setSafe(0, "apple".getBytes()); + stringVector.setSafe(1, "banana".getBytes()); + stringVector.setSafe(2, "cherry".getBytes()); + stringVector.setSafe(3, "cherry".getBytes()); + stringVector.setSafe(4, "apple".getBytes()); + stringVector.setSafe(5, "banana".getBytes()); + stringVector.setSafe(6, "apple".getBytes()); + stringVector.setSafe(7, "cherry".getBytes()); + stringVector.setSafe(8, "banana".getBytes()); + stringVector.setSafe(9, "apple".getBytes()); + stringVector.setValueCount(10); + + TinyIntVector encodedVector = + (TinyIntVector) DictionaryEncoder.encode(stringVector, dictionary); + + // Set up VSR + List vectors = Arrays.asList(encodedVector); + int rowCount = 10; + + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { + + File dataFile = new File(TMP, "testRoundTripEnums.avro"); + + roundTripTest(root, allocator, dataFile, rowCount, dictionaries); + } + } } diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java index 864e2c8b59..37c0b4d9fe 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java @@ -21,27 +21,50 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.avro.Schema; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class RoundTripSchemaTest { private void doRoundTripTest(List fields) { + doRoundTripTest(fields, null); + } - AvroToArrowConfig config = new AvroToArrowConfig(null, 1, null, Collections.emptySet(), false); + private void doRoundTripTest(List fields, DictionaryProvider dictionaries) { - Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord"); + DictionaryProvider.MapDictionaryProvider decodeDictionaries = + new DictionaryProvider.MapDictionaryProvider(); + AvroToArrowConfig decodeConfig = + new AvroToArrowConfig(null, 1, decodeDictionaries, Collections.emptySet(), false); + + Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); org.apache.arrow.vector.types.pojo.Schema arrowSchema = - AvroToArrowUtils.createArrowSchema(avroSchema, config); + AvroToArrowUtils.createArrowSchema(avroSchema, decodeConfig); // Compare string representations - equality not defined for logical types assertEquals(fields, arrowSchema.getFields()); + + for (int i = 0; i < fields.size(); i++) { + Field field = fields.get(i); + Field rtField = arrowSchema.getFields().get(i); + if (field.getDictionary() != null) { + // Dictionary content is not decoded until the data is consumed + Assertions.assertNotNull(rtField.getDictionary()); + } + } } // Schema round trip for primitive types, nullable and non-nullable @@ -440,4 +463,38 @@ public void testRoundTripStructType() { doRoundTripTest(fields); } + + @Test + public void testRoundTripEnumType() { + + BufferAllocator allocator = new RootAllocator(); + + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + List fields = + Arrays.asList( + new Field( + "enumField", + new FieldType( + true, + new ArrowType.Int(8, true), + new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))), + null)); + + doRoundTripTest(fields, dictionaries); + } }