diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java
index 87b594af9e..e09b99f670 100644
--- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java
@@ -17,10 +17,14 @@
package org.apache.arrow.adapter.avro;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
+import java.util.regex.Pattern;
import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer;
import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer;
import org.apache.arrow.adapter.avro.producers.AvroBytesProducer;
+import org.apache.arrow.adapter.avro.producers.AvroEnumProducer;
import org.apache.arrow.adapter.avro.producers.AvroFixedSizeBinaryProducer;
import org.apache.arrow.adapter.avro.producers.AvroFixedSizeListProducer;
import org.apache.arrow.adapter.avro.producers.AvroFloat2Producer;
@@ -41,6 +45,7 @@
import org.apache.arrow.adapter.avro.producers.AvroUint8Producer;
import org.apache.arrow.adapter.avro.producers.BaseAvroProducer;
import org.apache.arrow.adapter.avro.producers.CompositeAvroProducer;
+import org.apache.arrow.adapter.avro.producers.DictionaryDecodingProducer;
import org.apache.arrow.adapter.avro.producers.Producer;
import org.apache.arrow.adapter.avro.producers.logical.AvroDateDayProducer;
import org.apache.arrow.adapter.avro.producers.logical.AvroDateMilliProducer;
@@ -59,6 +64,7 @@
import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecProducer;
import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecTzProducer;
import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
@@ -96,11 +102,14 @@
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.util.Text;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
@@ -162,17 +171,29 @@ public class ArrowToAvroUtils {
* may be nullable. Record types must contain at least one child field and cannot contain multiple
* fields with the same name
*
+ *
+ * <p>String fields that are dictionary-encoded will be represented as an Avro enum, so long as
+ * all the values meet the restrictions on Avro enum symbols (non-null, unique, valid
+ * identifiers). Other data types that are dictionary-encoded, or string fields that do not meet
+ * the Avro requirements, will be output as their decoded type.
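+ *
+ * <p>For example, an illustrative sketch of schema generation with a dictionary provider
+ * (the vector and variable names are placeholders):
+ *
+ * <pre>{@code
+ * Dictionary dictionary = new Dictionary(symbolVector, new DictionaryEncoding(1L, false, null));
+ * DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+ * Schema schema = createAvroSchema(root.getSchema().getFields(), "MyRecord", null, dictionaries);
+ * }</pre>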
+ *
* @param arrowFields The arrow fields used to generate the Avro schema
* @param typeName Name of the top level Avro record type
* @param namespace Namespace of the top level Avro record type
+ * @param dictionaries A dictionary provider, required if any fields are dictionary-encoded
* @return An Avro record schema for the given list of fields, with the specified name and
* namespace
*/
public static Schema createAvroSchema(
- List<Field> arrowFields, String typeName, String namespace) {
+ List<Field> arrowFields, String typeName, String namespace, DictionaryProvider dictionaries) {
SchemaBuilder.RecordBuilder<Schema> assembler =
SchemaBuilder.record(typeName).namespace(namespace);
- return buildRecordSchema(assembler, arrowFields, namespace);
+ return buildRecordSchema(assembler, arrowFields, namespace, dictionaries);
+ }
+
+ /** Overload provided for convenience, sets dictionaries = null. */
+ public static Schema createAvroSchema(
+ List<Field> arrowFields, String typeName, String namespace) {
+ return createAvroSchema(arrowFields, typeName, namespace, null);
}
/** Overload provided for convenience, sets namespace = null. */
@@ -185,61 +206,83 @@ public static Schema createAvroSchema(List arrowFields) {
return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME);
}
+ /**
+ * Overload provided for convenience, sets name = GENERIC_RECORD_TYPE_NAME and namespace = null.
+ */
+ public static Schema createAvroSchema(List<Field> arrowFields, DictionaryProvider dictionaries) {
+ return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME, null, dictionaries);
+ }
+
private static <T> T buildRecordSchema(
- SchemaBuilder.RecordBuilder<T> builder, List<Field> fields, String namespace) {
+ SchemaBuilder.RecordBuilder<T> builder,
+ List<Field> fields,
+ String namespace,
+ DictionaryProvider dictionaries) {
if (fields.isEmpty()) {
throw new IllegalArgumentException("Record field must have at least one child field");
}
SchemaBuilder.FieldAssembler<T> assembler = builder.namespace(namespace).fields();
for (Field field : fields) {
- assembler = buildFieldSchema(assembler, field, namespace);
+ assembler = buildFieldSchema(assembler, field, namespace, dictionaries);
}
return assembler.endRecord();
}
private static <T> SchemaBuilder.FieldAssembler<T> buildFieldSchema(
- SchemaBuilder.FieldAssembler<T> assembler, Field field, String namespace) {
+ SchemaBuilder.FieldAssembler<T> assembler,
+ Field field,
+ String namespace,
+ DictionaryProvider dictionaries) {
return assembler
.name(field.getName())
- .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace))
+ .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace, dictionaries))
.noDefault();
}
private static <T> T buildTypeSchema(
- SchemaBuilder.TypeBuilder<T> builder, Field field, String namespace) {
+ SchemaBuilder.TypeBuilder<T> builder,
+ Field field,
+ String namespace,
+ DictionaryProvider dictionaries) {
// Nullable unions need special handling, since union types cannot be directly nested
if (field.getType().getTypeID() == ArrowType.ArrowTypeID.Union) {
boolean unionNullable = field.getChildren().stream().anyMatch(Field::isNullable);
if (unionNullable) {
SchemaBuilder.UnionAccumulator<T> union = builder.unionOf().nullType();
- return addTypesToUnion(union, field.getChildren(), namespace);
+ return addTypesToUnion(union, field.getChildren(), namespace, dictionaries);
} else {
Field headType = field.getChildren().get(0);
List<Field> tailTypes = field.getChildren().subList(1, field.getChildren().size());
SchemaBuilder.UnionAccumulator<T> union =
- buildBaseTypeSchema(builder.unionOf(), headType, namespace);
- return addTypesToUnion(union, tailTypes, namespace);
+ buildBaseTypeSchema(builder.unionOf(), headType, namespace, dictionaries);
+ return addTypesToUnion(union, tailTypes, namespace, dictionaries);
}
} else if (field.isNullable()) {
- return buildBaseTypeSchema(builder.nullable(), field, namespace);
+ return buildBaseTypeSchema(builder.nullable(), field, namespace, dictionaries);
} else {
- return buildBaseTypeSchema(builder, field, namespace);
+ return buildBaseTypeSchema(builder, field, namespace, dictionaries);
}
}
private static <T> T buildArraySchema(
- SchemaBuilder.ArrayBuilder<T> builder, Field listField, String namespace) {
+ SchemaBuilder.ArrayBuilder<T> builder,
+ Field listField,
+ String namespace,
+ DictionaryProvider dictionaries) {
if (listField.getChildren().size() != 1) {
throw new IllegalArgumentException("List field must have exactly one child field");
}
Field itemField = listField.getChildren().get(0);
- return buildTypeSchema(builder.items(), itemField, namespace);
+ return buildTypeSchema(builder.items(), itemField, namespace, dictionaries);
}
private static <T> T buildMapSchema(
- SchemaBuilder.MapBuilder<T> builder, Field mapField, String namespace) {
+ SchemaBuilder.MapBuilder<T> builder,
+ Field mapField,
+ String namespace,
+ DictionaryProvider dictionaries) {
if (mapField.getChildren().size() != 1) {
throw new IllegalArgumentException("Map field must have exactly one child field");
}
@@ -253,11 +296,14 @@ private static T buildMapSchema(
throw new IllegalArgumentException(
"Map keys must be of type string and cannot be nullable for conversion to Avro");
}
- return buildTypeSchema(builder.values(), valueField, namespace);
+ return buildTypeSchema(builder.values(), valueField, namespace, dictionaries);
}
private static <T> T buildBaseTypeSchema(
- SchemaBuilder.BaseTypeBuilder<T> builder, Field field, String namespace) {
+ SchemaBuilder.BaseTypeBuilder<T> builder,
+ Field field,
+ String namespace,
+ DictionaryProvider dictionaries) {
ArrowType.ArrowTypeID typeID = field.getType().getTypeID();
@@ -269,6 +315,33 @@ private static T buildBaseTypeSchema(
return builder.booleanType();
case Int:
+ if (field.getDictionary() != null) {
+ if (dictionaries == null) {
+ throw new IllegalArgumentException(
+ "Field references a dictionary but no dictionaries were provided: "
+ + field.getName());
+ }
+ Dictionary dictionary = dictionaries.lookup(field.getDictionary().getId());
+ if (dictionary == null) {
+ throw new IllegalArgumentException(
+ "Field references a dictionary that does not exist: "
+ + field.getName()
+ + ", dictionary ID = "
+ + field.getDictionary().getId());
+ }
+ if (dictionaryIsValidEnum(dictionary)) {
+ String[] symbols = dictionarySymbols(dictionary);
+ return builder.enumeration(field.getName()).symbols(symbols);
+ } else {
+ Field decodedField =
+ new Field(
+ field.getName(),
+ dictionary.getVector().getField().getFieldType(),
+ dictionary.getVector().getField().getChildren());
+ return buildBaseTypeSchema(builder, decodedField, namespace, dictionaries);
+ }
+ }
+
ArrowType.Int intType = (ArrowType.Int) field.getType();
if (intType.getBitWidth() > 32 || (intType.getBitWidth() == 32 && !intType.getIsSigned())) {
return builder.longType();
@@ -328,7 +401,7 @@ private static T buildBaseTypeSchema(
String childNamespace =
namespace == null ? field.getName() : namespace + "." + field.getName();
return buildRecordSchema(
- builder.record(field.getName()), field.getChildren(), childNamespace);
+ builder.record(field.getName()), field.getChildren(), childNamespace, dictionaries);
case List:
case FixedSizeList:
@@ -339,13 +412,13 @@ private static T buildBaseTypeSchema(
new Field("item", itemField.getFieldType(), itemField.getChildren());
Field safeListField =
new Field(field.getName(), field.getFieldType(), List.of(safeItemField));
- return buildArraySchema(builder.array(), safeListField, namespace);
+ return buildArraySchema(builder.array(), safeListField, namespace, dictionaries);
} else {
- return buildArraySchema(builder.array(), field, namespace);
+ return buildArraySchema(builder.array(), field, namespace, dictionaries);
}
case Map:
- return buildMapSchema(builder.map(), field, namespace);
+ return buildMapSchema(builder.map(), field, namespace, dictionaries);
default:
throw new IllegalArgumentException(
@@ -354,9 +427,12 @@ private static T buildBaseTypeSchema(
}
private static <T> T addTypesToUnion(
- SchemaBuilder.UnionAccumulator<T> accumulator, List<Field> unionFields, String namespace) {
+ SchemaBuilder.UnionAccumulator<T> accumulator,
+ List<Field> unionFields,
+ String namespace,
+ DictionaryProvider dictionaries) {
for (var field : unionFields) {
- accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace);
+ accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace, dictionaries);
}
return accumulator.endUnion();
}
@@ -373,30 +449,88 @@ private static LogicalType timestampLogicalType(ArrowType.Timestamp timestampTyp
}
}
+ private static boolean dictionaryIsValidEnum(Dictionary dictionary) {
+
+ if (dictionary.getVectorType().getTypeID() != ArrowType.ArrowTypeID.Utf8) {
+ return false;
+ }
+
+ VarCharVector vector = (VarCharVector) dictionary.getVector();
+ Set<String> symbols = new HashSet<>();
+
+ for (int i = 0; i < vector.getValueCount(); i++) {
+ if (vector.isNull(i)) {
+ return false;
+ }
+ Text text = vector.getObject(i);
+ if (text == null) {
+ return false;
+ }
+ String symbol = text.toString();
+ if (!ENUM_REGEX.matcher(symbol).matches()) {
+ return false;
+ }
+ if (symbols.contains(symbol)) {
+ return false;
+ }
+ symbols.add(symbol);
+ }
+
+ return true;
+ }
+
+ private static String[] dictionarySymbols(Dictionary dictionary) {
+
+ VarCharVector vector = (VarCharVector) dictionary.getVector();
+ String[] symbols = new String[vector.getValueCount()];
+
+ for (int i = 0; i < vector.getValueCount(); i++) {
+ Text text = vector.getObject(i);
+ // This should never happen if dictionaryIsValidEnum() succeeded
+ if (text == null) {
+ throw new IllegalArgumentException("Illegal null value in enum");
+ }
+ symbols[i] = text.toString();
+ }
+
+ return symbols;
+ }
+
+ private static final Pattern ENUM_REGEX = Pattern.compile("^[A-Za-z_][A-Za-z0-9_]*$");
+
/**
* Create a composite Avro producer for a set of field vectors (typically the root set of a VSR).
*
* @param vectors The vectors that will be used to produce Avro data
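+ * @param dictionaries A dictionary provider, required if any of the vectors are dictionary-encoded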
* @return The resulting composite Avro producer
*/
- public static CompositeAvroProducer createCompositeProducer(List<FieldVector> vectors) {
+ public static CompositeAvroProducer createCompositeProducer(
+ List<FieldVector> vectors, DictionaryProvider dictionaries) {
List<Producer<? extends FieldVector>> producers = new ArrayList<>(vectors.size());
for (FieldVector vector : vectors) {
- BaseAvroProducer<? extends FieldVector> producer = createProducer(vector);
+ BaseAvroProducer<? extends FieldVector> producer = createProducer(vector, dictionaries);
producers.add(producer);
}
return new CompositeAvroProducer(producers);
}
- private static BaseAvroProducer<?> createProducer(FieldVector vector) {
+ /** Overload provided for convenience, sets dictionaries = null. */
+ public static CompositeAvroProducer createCompositeProducer(List<FieldVector> vectors) {
+
+ return createCompositeProducer(vectors, null);
+ }
+
+ private static BaseAvroProducer<?> createProducer(
+ FieldVector vector, DictionaryProvider dictionaries) {
boolean nullable = vector.getField().isNullable();
- return createProducer(vector, nullable);
+ return createProducer(vector, nullable, dictionaries);
}
- private static BaseAvroProducer<?> createProducer(FieldVector vector, boolean nullable) {
+ private static BaseAvroProducer<?> createProducer(
+ FieldVector vector, boolean nullable, DictionaryProvider dictionaries) {
Preconditions.checkNotNull(vector, "Arrow vector object can't be null");
@@ -405,10 +539,34 @@ private static BaseAvroProducer> createProducer(FieldVector vector, boolean nu
// Avro understands nullable types as a union of type | null
// Most nullable fields in a VSR will not be unions, so provide a special wrapper
if (nullable && minorType != Types.MinorType.UNION) {
- final BaseAvroProducer<?> innerProducer = createProducer(vector, false);
+ final BaseAvroProducer<?> innerProducer = createProducer(vector, false, dictionaries);
return new AvroNullableProducer<>(innerProducer);
}
+ if (vector.getField().getDictionary() != null) {
+ if (dictionaries == null) {
+ throw new IllegalArgumentException(
+ "Field references a dictionary but no dictionaries were provided: "
+ + vector.getField().getName());
+ }
+ Dictionary dictionary = dictionaries.lookup(vector.getField().getDictionary().getId());
+ if (dictionary == null) {
+ throw new IllegalArgumentException(
+ "Field references a dictionary that does not exist: "
+ + vector.getField().getName()
+ + ", dictionary ID = "
+ + vector.getField().getDictionary().getId());
+ }
+ // If a field is dictionary-encoded but cannot be represented as an Avro enum,
+ // then decode it before writing
+ if (dictionaryIsValidEnum(dictionary)) {
+ return new AvroEnumProducer((BaseIntVector) vector);
+ } else {
+ BaseAvroProducer<?> dictProducer = createProducer(dictionary.getVector(), false, null);
+ return new DictionaryDecodingProducer<>((BaseIntVector) vector, dictProducer);
+ }
+ }
+
switch (minorType) {
case NULL:
return new AvroNullProducer((NullVector) vector);
@@ -486,21 +644,23 @@ private static BaseAvroProducer> createProducer(FieldVector vector, boolean nu
Producer<?>[] childProducers = new Producer<?>[childVectors.size()];
for (int i = 0; i < childVectors.size(); i++) {
FieldVector childVector = childVectors.get(i);
- childProducers[i] = createProducer(childVector, childVector.getField().isNullable());
+ childProducers[i] =
+ createProducer(childVector, childVector.getField().isNullable(), dictionaries);
}
return new AvroStructProducer(structVector, childProducers);
case LIST:
ListVector listVector = (ListVector) vector;
FieldVector itemVector = listVector.getDataVector();
- Producer<?> itemProducer = createProducer(itemVector, itemVector.getField().isNullable());
+ Producer<?> itemProducer =
+ createProducer(itemVector, itemVector.getField().isNullable(), dictionaries);
return new AvroListProducer(listVector, itemProducer);
case FIXED_SIZE_LIST:
FixedSizeListVector fixedListVector = (FixedSizeListVector) vector;
FieldVector fixedItemVector = fixedListVector.getDataVector();
Producer<?> fixedItemProducer =
- createProducer(fixedItemVector, fixedItemVector.getField().isNullable());
+ createProducer(fixedItemVector, fixedItemVector.getField().isNullable(), dictionaries);
return new AvroFixedSizeListProducer(fixedListVector, fixedItemProducer);
case MAP:
@@ -514,7 +674,7 @@ private static BaseAvroProducer> createProducer(FieldVector vector, boolean nu
FieldVector valueVector = entryVector.getChildrenFromFields().get(1);
Producer<?> keyProducer = new AvroStringProducer(keyVector);
Producer<?> valueProducer =
- createProducer(valueVector, valueVector.getField().isNullable());
+ createProducer(valueVector, valueVector.getField().isNullable(), dictionaries);
Producer<?> entryProducer =
new AvroStructProducer(entryVector, new Producer<?>[] {keyProducer, valueProducer});
return new AvroMapProducer(mapVector, entryProducer);
diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
index 068566493e..eebfb7d241 100644
--- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
@@ -17,22 +17,22 @@
package org.apache.arrow.adapter.avro.producers;
import java.io.IOException;
-import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.BaseIntVector;
import org.apache.avro.io.Encoder;
/**
- * Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an
- * Avro encoder.
+ * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data
+ * to an Avro encoder.
*/
-public class AvroEnumProducer extends BaseAvroProducer<IntVector> {
+public class AvroEnumProducer extends BaseAvroProducer<BaseIntVector> {
/** Instantiate an AvroEnumProducer. */
- public AvroEnumProducer(IntVector vector) {
+ public AvroEnumProducer(BaseIntVector vector) {
super(vector);
}
@Override
public void produce(Encoder encoder) throws IOException {
- encoder.writeEnum(vector.get(currentIndex++));
+ encoder.writeEnum((int) vector.getValueAsLong(currentIndex++));
}
}
diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
new file mode 100644
index 0000000000..afeba08511
--- /dev/null
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adapter.avro.producers;
+
+import java.io.IOException;
+import org.apache.arrow.vector.BaseIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.avro.io.Encoder;
+
+/**
+ * Producer that decodes values from a dictionary-encoded {@link FieldVector}, writes the resulting
+ * values to an Avro encoder.
+ *
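+ * <p>A minimal sketch of how {@code ArrowToAvroUtils} wires this producer for a
+ * dictionary-encoded field (illustrative; the variable names are placeholders):
+ *
+ * <pre>{@code
+ * // indexVector holds dictionary indices, dictProducer produces the decoded dictionary values
+ * BaseAvroProducer<?> dictProducer = createProducer(dictionary.getVector(), false, null);
+ * Producer<?> decoded = new DictionaryDecodingProducer<>(indexVector, dictProducer);
+ * }</pre>
+ *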
+ * @param <T> Type of the underlying dictionary vector
+ */
+public class DictionaryDecodingProducer<T extends FieldVector>
+ extends BaseAvroProducer<BaseIntVector> {
+
+ private final Producer<T> dictProducer;
+
+ /** Instantiate a DictionaryDecodingProducer. */
+ public DictionaryDecodingProducer(BaseIntVector indexVector, Producer<T> dictProducer) {
+ super(indexVector);
+ this.dictProducer = dictProducer;
+ }
+
+ @Override
+ public void produce(Encoder encoder) throws IOException {
+ int dicIndex = (int) vector.getValueAsLong(currentIndex++);
+ dictProducer.setPosition(dicIndex);
+ dictProducer.produce(encoder);
+ }
+}
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
index 2d70b45021..6d66ee9d45 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
@@ -76,10 +76,14 @@
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.writer.BaseWriter;
import org.apache.arrow.vector.complex.writer.FieldWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.util.JsonStringArrayList;
@@ -2817,4 +2821,81 @@ record = datumReader.read(record, decoder);
}
}
}
+
+ @Test
+ public void testWriteDictEnumEncoded() throws Exception {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ // Create a dictionary
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary =
+ new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ // Field definition
+ FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector stringVector =
+ new VarCharVector(new Field("enumField", stringField, null), allocator);
+ stringVector.allocateNew(10);
+ stringVector.setSafe(0, "apple".getBytes());
+ stringVector.setSafe(1, "banana".getBytes());
+ stringVector.setSafe(2, "cherry".getBytes());
+ stringVector.setSafe(3, "cherry".getBytes());
+ stringVector.setSafe(4, "apple".getBytes());
+ stringVector.setSafe(5, "banana".getBytes());
+ stringVector.setSafe(6, "apple".getBytes());
+ stringVector.setSafe(7, "cherry".getBytes());
+ stringVector.setSafe(8, "banana".getBytes());
+ stringVector.setSafe(9, "apple".getBytes());
+ stringVector.setValueCount(10);
+
+ IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);
+
+ // Set up VSR
+ List<FieldVector> vectors = Arrays.asList(encodedVector);
+ int rowCount = 10;
+
+ try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+
+ File dataFile = new File(TMP, "testWriteEnumEncoded.avro");
+
+ // Write an AVRO block using the producer classes
+ try (FileOutputStream fos = new FileOutputStream(dataFile)) {
+ BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
+ CompositeAvroProducer producer =
+ ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
+ for (int row = 0; row < rowCount; row++) {
+ producer.produce(encoder);
+ }
+ encoder.flush();
+ }
+
+ // Set up reading the AVRO block as a GenericRecord
+ Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
+ GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);
+
+ try (InputStream inputStream = new FileInputStream(dataFile)) {
+
+ BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
+ GenericRecord record = null;
+
+ // Read and check values
+ for (int row = 0; row < rowCount; row++) {
+ record = datumReader.read(record, decoder);
+ // Values read from Avro should be the decoded enum values
+ assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
+ }
+ }
+ }
+ }
}
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
index d3e12e763a..d5e0357a8c 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
@@ -20,11 +20,18 @@
import java.util.Arrays;
import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.UnionMode;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.avro.LogicalTypes;
@@ -1389,4 +1396,126 @@ public void testConvertUnionTypes() {
Schema.Type.STRING,
schema.getField("nullableDenseUnionField").schema().getTypes().get(3).getType());
}
+
+ @Test
+ public void testWriteDictEnumEncoded() {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ // Create a dictionary
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ List<Field> fields =
+ Arrays.asList(
+ new Field(
+ "enumField",
+ new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
+ null));
+
+ Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
+
+ assertEquals(Schema.Type.RECORD, schema.getType());
+ assertEquals(1, schema.getFields().size());
+
+ Schema.Field enumField = schema.getField("enumField");
+
+ assertEquals(Schema.Type.ENUM, enumField.schema().getType());
+ assertEquals(3, enumField.schema().getEnumSymbols().size());
+ assertEquals("apple", enumField.schema().getEnumSymbols().get(0));
+ assertEquals("banana", enumField.schema().getEnumSymbols().get(1));
+ assertEquals("cherry", enumField.schema().getEnumSymbols().get(2));
+ }
+
+ @Test
+ public void testWriteDictEnumInvalid() {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ // Create a dictionary
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "passion fruit".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ List<Field> fields =
+ Arrays.asList(
+ new Field(
+ "enumField",
+ new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
+ null));
+
+ // Dictionary field contains values that are not valid enum symbols
+ // Should be decoded and output as a string field
+
+ Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
+
+ assertEquals(Schema.Type.RECORD, schema.getType());
+ assertEquals(1, schema.getFields().size());
+
+ Schema.Field enumField = schema.getField("enumField");
+ assertEquals(Schema.Type.STRING, enumField.schema().getType());
+ }
+
+ @Test
+ public void testWriteDictEnumInvalid2() {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ // Create a dictionary
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null);
+ BigIntVector dictionaryVector =
+ new BigIntVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, 123L);
+ dictionaryVector.set(1, 456L);
+ dictionaryVector.set(2, 789L);
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ List<Field> fields =
+ Arrays.asList(
+ new Field(
+ "enumField",
+ new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
+ null));
+
+ // Dictionary field encodes LONG values rather than STRING
+ // Should be decoded and output as a LONG field
+
+ Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
+
+ assertEquals(Schema.Type.RECORD, schema.getType());
+ assertEquals(1, schema.getFields().size());
+
+ Schema.Field enumField = schema.getField("enumField");
+ assertEquals(Schema.Type.LONG, enumField.schema().getType());
+ }
}
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
index 85e6a960b0..ceaf59aa72 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
@@ -52,6 +52,7 @@
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampNanoTZVector;
import org.apache.arrow.vector.TimeStampNanoVector;
+import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -60,10 +61,14 @@
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.writer.BaseWriter;
import org.apache.arrow.vector.complex.writer.FieldWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.avro.Schema;
@@ -78,16 +83,21 @@ public class RoundTripDataTest {
@TempDir public static File TMP;
- private static AvroToArrowConfig basicConfig(BufferAllocator allocator) {
- return new AvroToArrowConfig(allocator, 1000, null, Collections.emptySet(), false);
+ private static AvroToArrowConfig basicConfig(
+ BufferAllocator allocator, DictionaryProvider.MapDictionaryProvider dictionaries) {
+ return new AvroToArrowConfig(allocator, 1000, dictionaries, Collections.emptySet(), false);
}
private static VectorSchemaRoot readDataFile(
- Schema schema, File dataFile, BufferAllocator allocator) throws Exception {
+ Schema schema,
+ File dataFile,
+ BufferAllocator allocator,
+ DictionaryProvider.MapDictionaryProvider dictionaries)
+ throws Exception {
try (FileInputStream fis = new FileInputStream(dataFile)) {
BinaryDecoder decoder = new DecoderFactory().directBinaryDecoder(fis, null);
- return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator));
+ return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator, dictionaries));
}
}
@@ -95,11 +105,22 @@ private static void roundTripTest(
VectorSchemaRoot root, BufferAllocator allocator, File dataFile, int rowCount)
throws Exception {
+ roundTripTest(root, allocator, dataFile, rowCount, null);
+ }
+
+ private static void roundTripTest(
+ VectorSchemaRoot root,
+ BufferAllocator allocator,
+ File dataFile,
+ int rowCount,
+ DictionaryProvider dictionaries)
+ throws Exception {
+
// Write an AVRO block using the producer classes
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
CompositeAvroProducer producer =
- ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors());
+ ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors(), dictionaries);
for (int row = 0; row < rowCount; row++) {
producer.produce(encoder);
}
@@ -107,10 +128,14 @@ private static void roundTripTest(
}
// Generate AVRO schema
- Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields());
+ Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
+
+ DictionaryProvider.MapDictionaryProvider roundTripDictionaries =
+ new DictionaryProvider.MapDictionaryProvider();
// Read back in and compare
- try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) {
+ try (VectorSchemaRoot roundTrip =
+ readDataFile(schema, dataFile, allocator, roundTripDictionaries)) {
assertEquals(root.getSchema(), roundTrip.getSchema());
assertEquals(rowCount, roundTrip.getRowCount());
@@ -119,6 +144,21 @@ private static void roundTripTest(
for (int row = 0; row < rowCount; row++) {
assertEquals(root.getVector(0).getObject(row), roundTrip.getVector(0).getObject(row));
}
+
+ if (dictionaries != null) {
+ for (long id : dictionaries.getDictionaryIds()) {
+ Dictionary originalDictionary = dictionaries.lookup(id);
+ Dictionary roundTripDictionary = roundTripDictionaries.lookup(id);
+ assertEquals(
+ originalDictionary.getVector().getValueCount(),
+ roundTripDictionary.getVector().getValueCount());
+ for (int j = 0; j < originalDictionary.getVector().getValueCount(); j++) {
+ assertEquals(
+ originalDictionary.getVector().getObject(j),
+ roundTripDictionary.getVector().getObject(j));
+ }
+ }
+ }
}
}
@@ -141,7 +181,7 @@ private static void roundTripByteArrayTest(
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields());
// Read back in and compare
- try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) {
+ try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator, null)) {
assertEquals(root.getSchema(), roundTrip.getSchema());
assertEquals(rowCount, roundTrip.getRowCount());
@@ -1603,4 +1643,58 @@ public void testRoundTripNullableStructs() throws Exception {
roundTripTest(root, allocator, dataFile, rowCount);
}
}
+
+ @Test
+ public void testRoundTripEnum() throws Exception {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ // Create a dictionary
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ // For simplicity, use an index type that matches what Avro enum decoding will produce
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ // Field definition
+ FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector stringVector =
+ new VarCharVector(new Field("enumField", stringField, null), allocator);
+ stringVector.allocateNew(10);
+ stringVector.setSafe(0, "apple".getBytes());
+ stringVector.setSafe(1, "banana".getBytes());
+ stringVector.setSafe(2, "cherry".getBytes());
+ stringVector.setSafe(3, "cherry".getBytes());
+ stringVector.setSafe(4, "apple".getBytes());
+ stringVector.setSafe(5, "banana".getBytes());
+ stringVector.setSafe(6, "apple".getBytes());
+ stringVector.setSafe(7, "cherry".getBytes());
+ stringVector.setSafe(8, "banana".getBytes());
+ stringVector.setSafe(9, "apple".getBytes());
+ stringVector.setValueCount(10);
+
+ TinyIntVector encodedVector =
+ (TinyIntVector) DictionaryEncoder.encode(stringVector, dictionary);
+
+ // Set up VSR
+ List<FieldVector> vectors = Arrays.asList(encodedVector);
+ int rowCount = 10;
+
+ try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+
+ File dataFile = new File(TMP, "testRoundTripEnums.avro");
+
+ roundTripTest(root, allocator, dataFile, rowCount, dictionaries);
+ }
+ }
}
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
index 864e2c8b59..37c0b4d9fe 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
@@ -21,27 +21,50 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.avro.Schema;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class RoundTripSchemaTest {
private void doRoundTripTest(List<Field> fields) {
+ doRoundTripTest(fields, null);
+ }
- AvroToArrowConfig config = new AvroToArrowConfig(null, 1, null, Collections.emptySet(), false);
+ private void doRoundTripTest(List<Field> fields, DictionaryProvider dictionaries) {
- Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord");
+ DictionaryProvider.MapDictionaryProvider decodeDictionaries =
+ new DictionaryProvider.MapDictionaryProvider();
+ AvroToArrowConfig decodeConfig =
+ new AvroToArrowConfig(null, 1, decodeDictionaries, Collections.emptySet(), false);
+
+ Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
org.apache.arrow.vector.types.pojo.Schema arrowSchema =
- AvroToArrowUtils.createArrowSchema(avroSchema, config);
+ AvroToArrowUtils.createArrowSchema(avroSchema, decodeConfig);
// Compare string representations - equality not defined for logical types
assertEquals(fields, arrowSchema.getFields());
+
+ for (int i = 0; i < fields.size(); i++) {
+ Field field = fields.get(i);
+ Field rtField = arrowSchema.getFields().get(i);
+ if (field.getDictionary() != null) {
+ // Dictionary content is not decoded until the data is consumed
+ Assertions.assertNotNull(rtField.getDictionary());
+ }
+ }
}
// Schema round trip for primitive types, nullable and non-nullable
@@ -440,4 +463,38 @@ public void testRoundTripStructType() {
doRoundTripTest(fields);
}
+
+ @Test
+ public void testRoundTripEnumType() {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ // For simplicity, use an index type that matches what Avro enum decoding will produce
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ List<Field> fields =
+ Arrays.asList(
+ new Field(
+ "enumField",
+ new FieldType(
+ true,
+ new ArrowType.Int(8, true),
+ new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))),
+ null));
+
+ doRoundTripTest(fields, dictionaries);
+ }
}