@@ -17,22 +17,22 @@
 package org.apache.arrow.adapter.avro.producers;
 
 import java.io.IOException;
-import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.BaseIntVector;
 import org.apache.avro.io.Encoder;
 
 /**
- * Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an
- * Avro encoder.
+ * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data
+ * to an Avro encoder.
  */
-public class AvroEnumProducer extends BaseAvroProducer<IntVector> {
+public class AvroEnumProducer extends BaseAvroProducer<BaseIntVector> {
 
   /** Instantiate an AvroEnumProducer. */
-  public AvroEnumProducer(IntVector vector) {
+  public AvroEnumProducer(BaseIntVector vector) {
     super(vector);
   }
 
   @Override
   public void produce(Encoder encoder) throws IOException {
-    encoder.writeEnum(vector.get(currentIndex++));
+    encoder.writeEnum((int) vector.getValueAsLong(currentIndex++));
   }
 }
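For context, a minimal standalone sketch (not part of this diff) of what the widened signature allows: with the parameter type relaxed to BaseIntVector, the producer also accepts narrower index vectors such as TinyIntVector, and getValueAsLong() reads the index regardless of the concrete width. Class and variable names below are illustrative only.

import java.io.ByteArrayOutputStream;
import org.apache.arrow.adapter.avro.producers.AvroEnumProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;

public class EnumProducerSketch {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
        TinyIntVector indices = new TinyIntVector("enumIndices", allocator)) {

      // Dictionary indices held in an 8-bit vector instead of a 32-bit IntVector
      indices.allocateNew(3);
      indices.set(0, 2);
      indices.set(1, 0);
      indices.set(2, 1);
      indices.setValueCount(3);

      ByteArrayOutputStream out = new ByteArrayOutputStream();
      BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(out, null);

      // After this change the producer accepts any BaseIntVector implementation;
      // the bytes written are the enum ordinals, to be read against a matching enum schema
      AvroEnumProducer producer = new AvroEnumProducer(indices);
      for (int i = 0; i < indices.getValueCount(); i++) {
        producer.produce(encoder);
      }
      encoder.flush();
    }
  }
}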
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.adapter.avro.producers;

import java.io.IOException;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.avro.io.Encoder;

/**
* Producer that decodes values from a dictionary-encoded {@link FieldVector}, writes the resulting
* values to an Avro encoder.
*
* @param <T> Type of the underlying dictionary vector
*/
public class DictionaryDecodingProducer<T extends FieldVector>
extends BaseAvroProducer<BaseIntVector> {

private final Producer<T> dictProducer;

/** Instantiate a DictionaryDecodingProducer. */
public DictionaryDecodingProducer(BaseIntVector indexVector, Producer<T> dictProducer) {
super(indexVector);
this.dictProducer = dictProducer;
}

@Override
public void produce(Encoder encoder) throws IOException {
int dicIndex = (int) vector.getValueAsLong(currentIndex++);
dictProducer.setPosition(dicIndex);
dictProducer.produce(encoder);
}
}
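To make the decode path above concrete, here is a minimal standalone sketch (not part of this diff) of the equivalent per-row lookup, writing the dictionary value directly as a string instead of delegating to an inner Producer<T>; names are illustrative only.

import java.io.ByteArrayOutputStream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;

public class DictionaryDecodeSketch {
  public static void main(String[] args) throws Exception {
    try (BufferAllocator allocator = new RootAllocator();
        VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
        IntVector indexVector = new IntVector("indices", allocator)) {

      // Dictionary values
      dictionaryVector.allocateNew(3);
      dictionaryVector.set(0, "apple".getBytes());
      dictionaryVector.set(1, "banana".getBytes());
      dictionaryVector.set(2, "cherry".getBytes());
      dictionaryVector.setValueCount(3);

      // Dictionary-encoded indices for the data column
      indexVector.allocateNew(4);
      indexVector.set(0, 2);
      indexVector.set(1, 0);
      indexVector.set(2, 1);
      indexVector.set(3, 0);
      indexVector.setValueCount(4);

      ByteArrayOutputStream out = new ByteArrayOutputStream();
      BinaryEncoder encoder = EncoderFactory.get().directBinaryEncoder(out, null);

      // Same lookup the producer performs per row: index vector -> dictionary position -> value,
      // without materializing a fully decoded vector
      for (int row = 0; row < indexVector.getValueCount(); row++) {
        int dictIndex = (int) indexVector.getValueAsLong(row);
        encoder.writeString(dictionaryVector.getObject(dictIndex).toString());
      }
      encoder.flush();
    }
  }
}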
@@ -76,10 +76,14 @@
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.writer.BaseWriter;
import org.apache.arrow.vector.complex.writer.FieldWriter;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.util.JsonStringArrayList;
@@ -2817,4 +2821,81 @@ record = datumReader.read(record, decoder);
}
}
}

@Test
public void testWriteDictEnumEncoded() throws Exception {

BufferAllocator allocator = new RootAllocator();

// Create a dictionary
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
VarCharVector dictionaryVector =
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);

dictionaryVector.allocateNew(3);
dictionaryVector.set(0, "apple".getBytes());
dictionaryVector.set(1, "banana".getBytes());
dictionaryVector.set(2, "cherry".getBytes());
dictionaryVector.setValueCount(3);

Dictionary dictionary =
new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);

// Field definition
FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null);
VarCharVector stringVector =
new VarCharVector(new Field("enumField", stringField, null), allocator);
stringVector.allocateNew(10);
stringVector.setSafe(0, "apple".getBytes());
stringVector.setSafe(1, "banana".getBytes());
stringVector.setSafe(2, "cherry".getBytes());
stringVector.setSafe(3, "cherry".getBytes());
stringVector.setSafe(4, "apple".getBytes());
stringVector.setSafe(5, "banana".getBytes());
stringVector.setSafe(6, "apple".getBytes());
stringVector.setSafe(7, "cherry".getBytes());
stringVector.setSafe(8, "banana".getBytes());
stringVector.setSafe(9, "apple".getBytes());
stringVector.setValueCount(10);

IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary);

// Set up VSR
List<FieldVector> vectors = Arrays.asList(encodedVector);
int rowCount = 10;

try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {

File dataFile = new File(TMP, "testWriteEnumEncoded.avro");

// Write an AVRO block using the producer classes
try (FileOutputStream fos = new FileOutputStream(dataFile)) {
BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
CompositeAvroProducer producer =
ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
for (int row = 0; row < rowCount; row++) {
producer.produce(encoder);
}
encoder.flush();
}

// Set up reading the AVRO block as a GenericRecord
Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries);
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema);

try (InputStream inputStream = new FileInputStream(dataFile)) {

BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null);
GenericRecord record = null;

// Read and check values
for (int row = 0; row < rowCount; row++) {
record = datumReader.read(record, decoder);
// Values read from Avro should be the decoded enum values
assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString());
}
}
}
}
}
@@ -20,11 +20,18 @@

import java.util.Arrays;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.UnionMode;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.avro.LogicalTypes;
@@ -1389,4 +1396,126 @@ public void testConvertUnionTypes() {
Schema.Type.STRING,
schema.getField("nullableDenseUnionField").schema().getTypes().get(3).getType());
}

@Test
public void testWriteDictEnumEncoded() {

BufferAllocator allocator = new RootAllocator();

// Create a dictionary
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
VarCharVector dictionaryVector =
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);

dictionaryVector.allocateNew(3);
dictionaryVector.set(0, "apple".getBytes());
dictionaryVector.set(1, "banana".getBytes());
dictionaryVector.set(2, "cherry".getBytes());
dictionaryVector.setValueCount(3);

Dictionary dictionary =
new Dictionary(
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);

List<Field> fields =
Arrays.asList(
new Field(
"enumField",
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
null));

Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);

assertEquals(Schema.Type.RECORD, schema.getType());
assertEquals(1, schema.getFields().size());

Schema.Field enumField = schema.getField("enumField");

assertEquals(Schema.Type.ENUM, enumField.schema().getType());
assertEquals(3, enumField.schema().getEnumSymbols().size());
assertEquals("apple", enumField.schema().getEnumSymbols().get(0));
assertEquals("banana", enumField.schema().getEnumSymbols().get(1));
assertEquals("cherry", enumField.schema().getEnumSymbols().get(2));
}
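For comparison, a sketch (not part of this diff) of the schema shape these assertions imply, built directly with Avro's SchemaBuilder; the enum type name and the absence of a namespace are assumptions made for illustration, since the assertions only check the field type and symbols.

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

public class ExpectedEnumSchemaSketch {
  public static void main(String[] args) {
    // Record "TestRecord" with a single enum field whose symbols come from the dictionary values;
    // the enum type name "enumField" is assumed here for illustration
    Schema expected =
        SchemaBuilder.record("TestRecord")
            .fields()
            .name("enumField")
            .type(SchemaBuilder.enumeration("enumField").symbols("apple", "banana", "cherry"))
            .noDefault()
            .endRecord();
    System.out.println(expected.toString(true));
  }
}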

@Test
public void testWriteDictEnumInvalid() {

BufferAllocator allocator = new RootAllocator();

// Create a dictionary
FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
VarCharVector dictionaryVector =
new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);

dictionaryVector.allocateNew(3);
dictionaryVector.set(0, "passion fruit".getBytes());
dictionaryVector.set(1, "banana".getBytes());
dictionaryVector.set(2, "cherry".getBytes());
dictionaryVector.setValueCount(3);

Dictionary dictionary =
new Dictionary(
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);

List<Field> fields =
Arrays.asList(
new Field(
"enumField",
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
null));

// Dictionary field contains values that are not valid enums
// Should be decoded and output as a string field

Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);

assertEquals(Schema.Type.RECORD, schema.getType());
assertEquals(1, schema.getFields().size());

Schema.Field enumField = schema.getField("enumField");
assertEquals(Schema.Type.STRING, enumField.schema().getType());
}

@Test
public void testWriteDictEnumInvalid2() {

BufferAllocator allocator = new RootAllocator();

// Create a dictionary
FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null);
BigIntVector dictionaryVector =
new BigIntVector(new Field("dictionary", dictionaryField, null), allocator);

dictionaryVector.allocateNew(3);
dictionaryVector.set(0, 123L);
dictionaryVector.set(1, 456L);
dictionaryVector.set(2, 789L);
dictionaryVector.setValueCount(3);

Dictionary dictionary =
new Dictionary(
dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);

List<Field> fields =
Arrays.asList(
new Field(
"enumField",
new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
null));

// Dictionary field encodes LONG values rather than STRING
// Should be decoded and output as a LONG field

Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);

assertEquals(Schema.Type.RECORD, schema.getType());
assertEquals(1, schema.getFields().size());

Schema.Field enumField = schema.getField("enumField");
assertEquals(Schema.Type.LONG, enumField.schema().getType());
}
}