From ce06e17b87148916cf89c7d4731013476d59f507 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Tue, 27 Oct 2020 06:32:49 +0200 Subject: [PATCH] ARROW-8426: [Rust] [Parquet] - Add more support for converting Dicts This adds more support for: - When converting Arrow -> Parquet containing an Arrow Dictionary, materialize the Dictionary values and send to Parquet to be encoded with a dictionary or not according to the Parquet settings (deliberately not supporting converting an Arrow Dictionary directly to Parquet DictEncoding, and right now this only supports String dictionaries) - When converting Parquet -> Arrow, noticing that the Arrow schema metadata in a Parquet file has a Dictionary type and converting the data to an Arrow dictionary (right now this only supports String dictionaries) I'm not sure if this is in a good enough state to merge or not yet, please let me know @nevi-me ! Closes #8402 from carols10cents/dict Lead-authored-by: Carol (Nichols || Goulding) Co-authored-by: Neville Dipale Co-authored-by: Jake Goulding Signed-off-by: Neville Dipale --- rust/arrow/src/array/data.rs | 57 +++- rust/arrow/src/ipc/convert.rs | 30 +- rust/parquet/src/arrow/array_reader.rs | 269 ++++++++++-------- rust/parquet/src/arrow/arrow_writer.rs | 256 ++++++++++++++++- rust/parquet/src/arrow/converter.rs | 359 ++++++++---------------- rust/parquet/src/arrow/record_reader.rs | 20 -- 6 files changed, 593 insertions(+), 398 deletions(-) diff --git a/rust/arrow/src/array/data.rs b/rust/arrow/src/array/data.rs index 9589f73caf8ed..a1426a6fb8866 100644 --- a/rust/arrow/src/array/data.rs +++ b/rust/arrow/src/array/data.rs @@ -29,7 +29,7 @@ use crate::util::bit_util; /// An generic representation of Arrow array data which encapsulates common attributes and /// operations for Arrow array. Specific operations for different arrays types (e.g., /// primitive, list, struct) are implemented in `Array`. -#[derive(PartialEq, Debug, Clone)] +#[derive(Debug, Clone)] pub struct ArrayData { /// The data type for this array data data_type: DataType, @@ -209,6 +209,61 @@ impl ArrayData { } } +impl PartialEq for ArrayData { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.data_type(), + other.data_type(), + "Data types not the same" + ); + assert_eq!(self.len(), other.len(), "Lengths not the same"); + // TODO: when adding tests for this, test that we can compare with arrays that have offsets + assert_eq!(self.offset(), other.offset(), "Offsets not the same"); + assert_eq!(self.null_count(), other.null_count()); + // compare buffers excluding padding + let self_buffers = self.buffers(); + let other_buffers = other.buffers(); + assert_eq!(self_buffers.len(), other_buffers.len()); + self_buffers.iter().zip(other_buffers).for_each(|(s, o)| { + compare_buffer_regions( + s, + self.offset(), // TODO mul by data length + o, + other.offset(), // TODO mul by data len + ); + }); + // assert_eq!(self.buffers(), other.buffers()); + + assert_eq!(self.child_data(), other.child_data()); + // null arrays can skip the null bitmap, thus only compare if there are no nulls + if self.null_count() != 0 || other.null_count() != 0 { + compare_buffer_regions( + self.null_buffer().unwrap(), + self.offset(), + other.null_buffer().unwrap(), + other.offset(), + ) + } + true + } +} + +/// A helper to compare buffer regions of 2 buffers. +/// Compares the length of the shorter buffer. +fn compare_buffer_regions( + left: &Buffer, + left_offset: usize, + right: &Buffer, + right_offset: usize, +) { + // for convenience, we assume that the buffer lengths are only unequal if one has padding, + // so we take the shorter length so we can discard the padding from the longer length + let shorter_len = left.len().min(right.len()); + let s_sliced = left.bit_slice(left_offset, shorter_len); + let o_sliced = right.bit_slice(right_offset, shorter_len); + assert_eq!(s_sliced, o_sliced); +} + /// Builder for `ArrayData` type #[derive(Debug)] pub struct ArrayDataBuilder { diff --git a/rust/arrow/src/ipc/convert.rs b/rust/arrow/src/ipc/convert.rs index a02b6c44dd999..63d55f043c6e9 100644 --- a/rust/arrow/src/ipc/convert.rs +++ b/rust/arrow/src/ipc/convert.rs @@ -641,17 +641,23 @@ pub(crate) fn get_fb_dictionary<'a: 'b, 'b>( fbb: &mut FlatBufferBuilder<'a>, ) -> WIPOffset> { // We assume that the dictionary index type (as an integer) has already been - // validated elsewhere, and can safely assume we are dealing with signed - // integers + // validated elsewhere, and can safely assume we are dealing with integers let mut index_builder = ipc::IntBuilder::new(fbb); - index_builder.add_is_signed(true); + match *index_type { - Int8 => index_builder.add_bitWidth(8), - Int16 => index_builder.add_bitWidth(16), - Int32 => index_builder.add_bitWidth(32), - Int64 => index_builder.add_bitWidth(64), + Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), + UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), _ => {} } + + match *index_type { + Int8 | UInt8 => index_builder.add_bitWidth(8), + Int16 | UInt16 => index_builder.add_bitWidth(16), + Int32 | UInt32 => index_builder.add_bitWidth(32), + Int64 | UInt64 => index_builder.add_bitWidth(64), + _ => {} + } + let index_builder = index_builder.finish(); let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); @@ -773,6 +779,16 @@ mod tests { 123, true, ), + Field::new_dict( + "dictionary", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::UInt32), + ), + true, + 123, + true, + ), ], md, ); diff --git a/rust/parquet/src/arrow/array_reader.rs b/rust/parquet/src/arrow/array_reader.rs index ad2b84a3923cc..76b672bb301c7 100644 --- a/rust/parquet/src/arrow/array_reader.rs +++ b/rust/parquet/src/arrow/array_reader.rs @@ -33,20 +33,20 @@ use arrow::array::{ }; use arrow::buffer::{Buffer, MutableBuffer}; use arrow::datatypes::{ - BooleanType as ArrowBooleanType, DataType as ArrowType, - Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, DateUnit, + ArrowPrimitiveType, BooleanType as ArrowBooleanType, DataType as ArrowType, + Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, DurationMicrosecondType as ArrowDurationMicrosecondType, DurationMillisecondType as ArrowDurationMillisecondType, DurationNanosecondType as ArrowDurationNanosecondType, DurationSecondType as ArrowDurationSecondType, Field, Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, Int16Type as ArrowInt16Type, Int32Type as ArrowInt32Type, - Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, IntervalUnit, Schema, + Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, Schema, Time32MillisecondType as ArrowTime32MillisecondType, Time32SecondType as ArrowTime32SecondType, Time64MicrosecondType as ArrowTime64MicrosecondType, - Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit, - TimeUnit as ArrowTimeUnit, TimestampMicrosecondType as ArrowTimestampMicrosecondType, + Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit as ArrowTimeUnit, + TimestampMicrosecondType as ArrowTimestampMicrosecondType, TimestampMillisecondType as ArrowTimestampMillisecondType, TimestampNanosecondType as ArrowTimestampNanosecondType, TimestampSecondType as ArrowTimestampSecondType, ToByteSlice, @@ -56,15 +56,10 @@ use arrow::datatypes::{ use arrow::util::bit_util; use crate::arrow::converter::{ - BinaryArrayConverter, BinaryConverter, BoolConverter, BooleanArrayConverter, - Converter, Date32Converter, FixedLenBinaryConverter, FixedSizeArrayConverter, - Float32Converter, Float64Converter, Int16Converter, Int32Converter, Int64Converter, - Int8Converter, Int96ArrayConverter, Int96Converter, LargeBinaryArrayConverter, - LargeBinaryConverter, LargeUtf8ArrayConverter, LargeUtf8Converter, - Time32MillisecondConverter, Time32SecondConverter, Time64MicrosecondConverter, - Time64NanosecondConverter, TimestampMicrosecondConverter, - TimestampMillisecondConverter, UInt16Converter, UInt32Converter, UInt64Converter, - UInt8Converter, Utf8ArrayConverter, Utf8Converter, + BinaryArrayConverter, BinaryConverter, Converter, FixedLenBinaryConverter, + FixedSizeArrayConverter, Int96ArrayConverter, Int96Converter, + LargeBinaryArrayConverter, LargeBinaryConverter, LargeUtf8ArrayConverter, + LargeUtf8Converter, Utf8ArrayConverter, Utf8Converter, }; use crate::arrow::record_reader::RecordReader; use crate::arrow::schema::parquet_to_arrow_field; @@ -212,10 +207,15 @@ impl PrimitiveArrayReader { pub fn new( mut pages: Box, column_desc: ColumnDescPtr, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; let mut record_reader = RecordReader::::new(column_desc.clone()); if let Some(page_reader) = pages.next() { @@ -267,90 +267,79 @@ impl ArrayReader for PrimitiveArrayReader { } } - // convert to arrays - let array = - match (&self.data_type, T::get_physical_type()) { - (ArrowType::Boolean, PhysicalType::BOOLEAN) => { - BoolConverter::new(BooleanArrayConverter {}) - .convert(self.record_reader.cast::()) - } - (ArrowType::Int8, PhysicalType::INT32) => { - Int8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int16, PhysicalType::INT32) => { - Int16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int32, PhysicalType::INT32) => { - Int32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt8, PhysicalType::INT32) => { - UInt8Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt16, PhysicalType::INT32) => { - UInt16Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt32, PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Int64, PhysicalType::INT64) => { - Int64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::UInt64, PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Float32, PhysicalType::FLOAT) => Float32Converter::new() - .convert(self.record_reader.cast::()), - (ArrowType::Float64, PhysicalType::DOUBLE) => Float64Converter::new() - .convert(self.record_reader.cast::()), - (ArrowType::Timestamp(unit, _), PhysicalType::INT64) => match unit { - TimeUnit::Millisecond => TimestampMillisecondConverter::new() - .convert(self.record_reader.cast::()), - TimeUnit::Microsecond => TimestampMicrosecondConverter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for timestamp with unit {:?}", unit)), - }, - (ArrowType::Date32(unit), PhysicalType::INT32) => match unit { - DateUnit::Day => Date32Converter::new() - .convert(self.record_reader.cast::()), - _ => Err(general_err!("No conversion from parquet type to arrow type for date with unit {:?}", unit)), - } - (ArrowType::Time32(unit), PhysicalType::INT32) => { - match unit { - TimeUnit::Second => { - Time32SecondConverter::new().convert(self.record_reader.cast::()) - } - TimeUnit::Millisecond => { - Time32MillisecondConverter::new().convert(self.record_reader.cast::()) - } - _ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type())) - } - } - (ArrowType::Time64(unit), PhysicalType::INT64) => { - match unit { - TimeUnit::Microsecond => { - Time64MicrosecondConverter::new().convert(self.record_reader.cast::()) - } - TimeUnit::Nanosecond => { - Time64NanosecondConverter::new().convert(self.record_reader.cast::()) - } - _ => Err(general_err!("Invalid or unsupported arrow array with datatype {:?}", self.get_data_type())) - } - } - (ArrowType::Interval(IntervalUnit::YearMonth), PhysicalType::INT32) => { - UInt32Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Interval(IntervalUnit::DayTime), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (ArrowType::Duration(_), PhysicalType::INT64) => { - UInt64Converter::new().convert(self.record_reader.cast::()) - } - (arrow_type, physical_type) => Err(general_err!( - "Reading {:?} type from parquet {:?} is not supported yet.", - arrow_type, - physical_type - )), - }?; + let arrow_data_type = match T::get_physical_type() { + PhysicalType::BOOLEAN => ArrowBooleanType::DATA_TYPE, + PhysicalType::INT32 => ArrowInt32Type::DATA_TYPE, + PhysicalType::INT64 => ArrowInt64Type::DATA_TYPE, + PhysicalType::FLOAT => ArrowFloat32Type::DATA_TYPE, + PhysicalType::DOUBLE => ArrowFloat64Type::DATA_TYPE, + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // Convert to arrays by using the Parquet phyisical type. + // The physical types are then cast to Arrow types if necessary + + let mut record_data = self.record_reader.consume_record_data()?; + + if T::get_physical_type() == PhysicalType::BOOLEAN { + let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); + + for e in record_data.data() { + boolean_buffer.append(*e > 0)?; + } + record_data = boolean_buffer.finish(); + } + + let mut array_data = ArrayDataBuilder::new(arrow_data_type) + .len(self.record_reader.num_values()) + .add_buffer(record_data); + + if let Some(b) = self.record_reader.consume_bitmap_buffer()? { + array_data = array_data.null_bit_buffer(b); + } + + let array = match T::get_physical_type() { + PhysicalType::BOOLEAN => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT32 => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT64 => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::FLOAT => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::DOUBLE => { + Arc::new(PrimitiveArray::::from(array_data.build())) + as ArrayRef + } + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // cast to Arrow type + // TODO: we need to check if it's fine for this to be fallible. + // My assumption is that we can't get to an illegal cast as we can only + // generate types that are supported, because we'd have gotten them from + // the metadata which was written to the Parquet sink + let array = arrow::compute::cast(&array, self.get_data_type())?; // save definition and repetition buffers self.def_levels_buffer = self.record_reader.consume_def_levels()?; @@ -503,7 +492,13 @@ where data_buffer.into_iter().map(Some).collect() }; - self.converter.convert(data) + let mut array = self.converter.convert(data)?; + + if let ArrowType::Dictionary(_, _) = self.data_type { + array = arrow::compute::cast(&array, &self.data_type)?; + } + + Ok(array) } fn get_def_levels(&self) -> Option<&[i16]> { @@ -524,10 +519,14 @@ where pages: Box, column_desc: ColumnDescPtr, converter: C, + arrow_type: Option, ) -> Result { - let data_type = parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(); + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; Ok(Self { data_type, @@ -1437,12 +1436,14 @@ impl<'a> ArrayReaderBuilder { .arrow_schema .field_with_name(cur_type.name()) .ok() - .map(|f| f.data_type()); + .map(|f| f.data_type()) + .cloned(); match cur_type.get_physical_type() { PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), PhysicalType::INT32 => { if let Some(ArrowType::Null) = arrow_type { @@ -1454,12 +1455,14 @@ impl<'a> ArrayReaderBuilder { Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)) } } PhysicalType::INT64 => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), PhysicalType::INT96 => { let converter = Int96Converter::new(Int96ArrayConverter {}); @@ -1467,16 +1470,24 @@ impl<'a> ArrayReaderBuilder { Int96Type, Int96Converter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } PhysicalType::FLOAT => Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, + arrow_type, )?)), - PhysicalType::DOUBLE => Ok(Box::new( - PrimitiveArrayReader::::new(page_iterator, column_desc)?, - )), + PhysicalType::DOUBLE => { + Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)) + } PhysicalType::BYTE_ARRAY => { if cur_type.get_basic_info().logical_type() == LogicalType::UTF8 { if let Some(ArrowType::LargeUtf8) = arrow_type { @@ -1486,7 +1497,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, LargeUtf8Converter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } else { let converter = Utf8Converter::new(Utf8ArrayConverter {}); @@ -1494,7 +1508,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, Utf8Converter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } else if let Some(ArrowType::LargeBinary) = arrow_type { @@ -1504,7 +1521,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, LargeBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } else { let converter = BinaryConverter::new(BinaryArrayConverter {}); @@ -1512,7 +1532,10 @@ impl<'a> ArrayReaderBuilder { ByteArrayType, BinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -1534,7 +1557,10 @@ impl<'a> ArrayReaderBuilder { FixedLenByteArrayType, FixedLenBinaryConverter, >::new( - page_iterator, column_desc, converter + page_iterator, + column_desc, + converter, + arrow_type, )?)) } } @@ -1671,9 +1697,12 @@ mod tests { let column_desc = schema.column(0); let page_iterator = EmptyPageIterator::new(schema); - let mut array_reader = - PrimitiveArrayReader::::new(Box::new(page_iterator), column_desc) - .unwrap(); + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); // expect no values to be read let array = array_reader.next_batch(50).unwrap(); @@ -1718,6 +1747,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -1801,6 +1831,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::<$arrow_parquet_type>::new( Box::new(page_iterator), column_desc.clone(), + None, ) .expect("Unable to get array reader"); @@ -1934,6 +1965,7 @@ mod tests { let mut array_reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, + None, ) .unwrap(); @@ -2047,6 +2079,7 @@ mod tests { Box::new(page_iterator), column_desc, converter, + None, ) .unwrap(); diff --git a/rust/parquet/src/arrow/arrow_writer.rs b/rust/parquet/src/arrow/arrow_writer.rs index d5e2db40fea27..09f004107e37b 100644 --- a/rust/parquet/src/arrow/arrow_writer.rs +++ b/rust/parquet/src/arrow/arrow_writer.rs @@ -25,7 +25,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::Array; use super::schema::add_encoded_arrow_schema_to_metadata; -use crate::column::writer::ColumnWriter; +use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; use crate::errors::{ParquetError, Result}; use crate::file::properties::WriterProperties; use crate::{ @@ -176,19 +176,175 @@ fn write_leaves( } Ok(()) } + ArrowDataType::Dictionary(key_type, value_type) => { + use arrow_array::{ + Int16DictionaryArray, Int32DictionaryArray, Int64DictionaryArray, + Int8DictionaryArray, PrimitiveArray, StringArray, UInt16DictionaryArray, + UInt32DictionaryArray, UInt64DictionaryArray, UInt8DictionaryArray, + }; + use ArrowDataType::*; + use ColumnWriter::*; + + let array = &**array; + let mut col_writer = get_col_writer(&mut row_group_writer)?; + let levels = levels.pop().expect("Levels exhausted"); + + macro_rules! dispatch_dictionary { + ($($kt: pat, $vt: pat, $w: ident => $kat: ty, $vat: ty,)*) => ( + match (&**key_type, &**value_type, &mut col_writer) { + $(($kt, $vt, $w(writer)) => write_dict::<$kat, $vat, _>(array, writer, levels),)* + (kt, vt, _) => unreachable!("Shouldn't be attempting to write dictionary of <{:?}, {:?}>", kt, vt), + } + ); + } + + match (&**key_type, &**value_type, &mut col_writer) { + (UInt8, UInt32, Int32ColumnWriter(writer)) => { + let typed_array = array + .as_any() + .downcast_ref::() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let value_array = + arrow::compute::cast(&value_buffer, &ArrowDataType::Int32)?; + + let values = value_array + .as_any() + .downcast_ref::() + .unwrap(); + + use std::convert::TryFrom; + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + let materialized_values: Vec<_> = keys + .flatten() + .map(|key| { + usize::try_from(key).unwrap_or_else(|k| { + panic!("key {} does not fit in usize", k) + }) + }) + .map(|key| values.value(key)) + .collect(); + + let materialized_primitive_array = + PrimitiveArray::::from( + materialized_values, + ); + + writer.write_batch( + get_numeric_array_slice::( + &materialized_primitive_array, + ) + .as_slice(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + row_group_writer.close_column(col_writer)?; + + return Ok(()); + } + _ => {} + } + + dispatch_dictionary!( + Int8, Utf8, ByteArrayColumnWriter => Int8DictionaryArray, StringArray, + Int16, Utf8, ByteArrayColumnWriter => Int16DictionaryArray, StringArray, + Int32, Utf8, ByteArrayColumnWriter => Int32DictionaryArray, StringArray, + Int64, Utf8, ByteArrayColumnWriter => Int64DictionaryArray, StringArray, + UInt8, Utf8, ByteArrayColumnWriter => UInt8DictionaryArray, StringArray, + UInt16, Utf8, ByteArrayColumnWriter => UInt16DictionaryArray, StringArray, + UInt32, Utf8, ByteArrayColumnWriter => UInt32DictionaryArray, StringArray, + UInt64, Utf8, ByteArrayColumnWriter => UInt64DictionaryArray, StringArray, + )?; + + row_group_writer.close_column(col_writer)?; + + Ok(()) + } ArrowDataType::Float16 => Err(ParquetError::ArrowError( "Float16 arrays not supported".to_string(), )), ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Boolean | ArrowDataType::FixedSizeBinary(_) - | ArrowDataType::Union(_) - | ArrowDataType::Dictionary(_, _) => Err(ParquetError::NYI( + | ArrowDataType::Union(_) => Err(ParquetError::NYI( "Attempting to write an Arrow type that is not yet implemented".to_string(), )), } } +trait Materialize { + type Output; + + // Materialize the packed dictionary. The writer will later repack it. + fn materialize(&self) -> Vec; +} + +macro_rules! materialize_string { + ($($k:ty,)*) => { + $(impl Materialize<$k, arrow_array::StringArray> for dyn Array { + type Output = ByteArray; + + fn materialize(&self) -> Vec { + use std::convert::TryFrom; + + let typed_array = self.as_any() + .downcast_ref::<$k>() + .expect("Unable to get dictionary array"); + + let keys = typed_array.keys(); + + let value_buffer = typed_array.values(); + let values = value_buffer + .as_any() + .downcast_ref::() + .unwrap(); + + // This removes NULL values from the NullableIter, but + // they're encoded by the levels, so that's fine. + keys + .flatten() + .map(|key| usize::try_from(key).unwrap_or_else(|k| panic!("key {} does not fit in usize", k))) + .map(|key| values.value(key)) + .map(ByteArray::from) + .collect() + } + })* + }; +} + +materialize_string! { + arrow_array::Int8DictionaryArray, + arrow_array::Int16DictionaryArray, + arrow_array::Int32DictionaryArray, + arrow_array::Int64DictionaryArray, + arrow_array::UInt8DictionaryArray, + arrow_array::UInt16DictionaryArray, + arrow_array::UInt32DictionaryArray, + arrow_array::UInt64DictionaryArray, +} + +fn write_dict( + array: &(dyn Array + 'static), + writer: &mut ColumnWriterImpl, + levels: Levels, +) -> Result<()> +where + T: DataType, + dyn Array: Materialize, +{ + writer.write_batch( + &array.materialize(), + Some(levels.definition.as_slice()), + levels.repetition.as_deref(), + )?; + + Ok(()) +} + fn write_leaf( writer: &mut ColumnWriter, column: &arrow_array::ArrayRef, @@ -430,7 +586,15 @@ fn get_levels( struct_levels } ArrowDataType::Union(_) => unimplemented!(), - ArrowDataType::Dictionary(_, _) => unimplemented!(), + ArrowDataType::Dictionary(_, _) => { + // Need to check for these cases not implemented in C++: + // - "Writing DictionaryArray with nested dictionary type not yet supported" + // - "Writing DictionaryArray with null encoded in dictionary type not yet supported" + vec![Levels { + definition: get_primitive_def_levels(array, parent_def_levels), + repetition: None, + }] + } } } @@ -501,7 +665,7 @@ mod tests { use arrow::array::*; use arrow::datatypes::ToByteSlice; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type}; use arrow::record_batch::RecordBatch; use crate::arrow::{ArrowReader, ParquetFileArrowReader}; @@ -1118,4 +1282,86 @@ mod tests { let values = Arc::new(s); one_column_roundtrip("struct_single_column", values, false); } + + #[test] + fn arrow_writer_string_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: Int32DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_primitive_dictionary() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), + true, + 42, + true, + )])); + + // create some data + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(12345678).unwrap(); + builder.append_null().unwrap(); + builder.append(22345678).unwrap(); + builder.append(12345678).unwrap(); + let d = builder.finish(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_primitive_dictionary.parquet", + expected_batch, + ); + } + + #[test] + fn arrow_writer_string_dictionary_unsigned_index() { + // define schema + let schema = Arc::new(Schema::new(vec![Field::new_dict( + "dictionary", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + true, + 42, + true, + )])); + + // create some data + let d: UInt8DictionaryArray = [Some("alpha"), None, Some("beta"), Some("alpha")] + .iter() + .copied() + .collect(); + + // build a record batch + let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); + + roundtrip( + "test_arrow_writer_string_dictionary_unsigned_index.parquet", + expected_batch, + ); + } } diff --git a/rust/parquet/src/arrow/converter.rs b/rust/parquet/src/arrow/converter.rs index 1aceba2d08742..33b29c897e6b7 100644 --- a/rust/parquet/src/arrow/converter.rs +++ b/rust/parquet/src/arrow/converter.rs @@ -15,43 +15,28 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::record_reader::RecordReader; use crate::data_type::{ByteArray, DataType, Int96}; // TODO: clean up imports (best done when there are few moving parts) -use arrow::{ - array::{ - Array, ArrayRef, BinaryBuilder, BooleanArray, BooleanBufferBuilder, - BufferBuilderTrait, FixedSizeBinaryBuilder, LargeBinaryBuilder, - LargeStringBuilder, StringBuilder, TimestampNanosecondBuilder, - }, - datatypes::Time32MillisecondType, -}; -use arrow::{ - compute::cast, datatypes::Time32SecondType, datatypes::Time64MicrosecondType, - datatypes::Time64NanosecondType, +use arrow::array::{ + Array, ArrayRef, BinaryBuilder, FixedSizeBinaryBuilder, LargeBinaryBuilder, + LargeStringBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder, StringBuilder, + StringDictionaryBuilder, TimestampNanosecondBuilder, }; +use arrow::compute::cast; use std::convert::From; use std::sync::Arc; use crate::errors::Result; -use arrow::datatypes::{ArrowPrimitiveType, DataType as ArrowDataType}; +use arrow::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}; -use arrow::array::ArrayDataBuilder; use arrow::array::{ - BinaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, - PrimitiveArray, StringArray, TimestampNanosecondArray, + BinaryArray, DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, + LargeStringArray, PrimitiveArray, StringArray, TimestampNanosecondArray, }; use std::marker::PhantomData; -use crate::data_type::{ - BoolType, DoubleType as ParquetDoubleType, FloatType as ParquetFloatType, - Int32Type as ParquetInt32Type, Int64Type as ParquetInt64Type, -}; -use arrow::datatypes::{ - Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - TimestampMicrosecondType, TimestampMillisecondType, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, -}; +use crate::data_type::Int32Type as ParquetInt32Type; +use arrow::datatypes::Int32Type; /// A converter is used to consume record reader's content and convert it to arrow /// primitive array. @@ -62,83 +47,6 @@ pub trait Converter { fn convert(&self, source: S) -> Result; } -/// Cast converter first converts record reader's buffer to arrow's -/// `PrimitiveArray`, then casts it to `PrimitiveArray`. -pub struct CastConverter { - _parquet_marker: PhantomData, - _arrow_source_marker: PhantomData, - _arrow_target_marker: PhantomData, -} - -impl - CastConverter -where - ParquetType: DataType, - ArrowSourceType: ArrowPrimitiveType, - ArrowTargetType: ArrowPrimitiveType, -{ - pub fn new() -> Self { - Self { - _parquet_marker: PhantomData, - _arrow_source_marker: PhantomData, - _arrow_target_marker: PhantomData, - } - } -} - -impl - Converter<&mut RecordReader, ArrayRef> - for CastConverter -where - ParquetType: DataType, - ArrowSourceType: ArrowPrimitiveType, - ArrowTargetType: ArrowPrimitiveType, -{ - fn convert(&self, record_reader: &mut RecordReader) -> Result { - let record_data = record_reader.consume_record_data(); - - let mut array_data = ArrayDataBuilder::new(ArrowSourceType::DATA_TYPE) - .len(record_reader.num_values()) - .add_buffer(record_data?); - - if let Some(b) = record_reader.consume_bitmap_buffer()? { - array_data = array_data.null_bit_buffer(b); - } - - let primitive_array: ArrayRef = - Arc::new(PrimitiveArray::::from(array_data.build())); - - Ok(cast(&primitive_array, &ArrowTargetType::DATA_TYPE)?) - } -} - -pub struct BooleanArrayConverter {} - -impl Converter<&mut RecordReader, BooleanArray> for BooleanArrayConverter { - fn convert( - &self, - record_reader: &mut RecordReader, - ) -> Result { - let record_data = record_reader.consume_record_data()?; - - let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); - - for e in record_data.data() { - boolean_buffer.append(*e > 0)?; - } - - let mut array_data = ArrayDataBuilder::new(ArrowDataType::Boolean) - .len(record_data.len()) - .add_buffer(boolean_buffer.finish()); - - if let Some(b) = record_reader.consume_bitmap_buffer()? { - array_data = array_data.null_bit_buffer(b); - } - - Ok(BooleanArray::from(array_data.build())) - } -} - pub struct FixedSizeArrayConverter { byte_width: i32, } @@ -253,34 +161,92 @@ impl Converter>, LargeBinaryArray> for LargeBinaryArrayCon } } -pub type BoolConverter<'a> = ArrayRefConverter< - &'a mut RecordReader, - BooleanArray, - BooleanArrayConverter, ->; -pub type Int8Converter = CastConverter; -pub type UInt8Converter = CastConverter; -pub type Int16Converter = CastConverter; -pub type UInt16Converter = CastConverter; -pub type Int32Converter = CastConverter; -pub type UInt32Converter = CastConverter; -pub type Int64Converter = CastConverter; -pub type Date32Converter = CastConverter; -pub type TimestampMillisecondConverter = - CastConverter; -pub type TimestampMicrosecondConverter = - CastConverter; -pub type Time32SecondConverter = - CastConverter; -pub type Time32MillisecondConverter = - CastConverter; -pub type Time64MicrosecondConverter = - CastConverter; -pub type Time64NanosecondConverter = - CastConverter; -pub type UInt64Converter = CastConverter; -pub type Float32Converter = CastConverter; -pub type Float64Converter = CastConverter; +pub struct StringDictionaryArrayConverter {} + +impl Converter>, DictionaryArray> + for StringDictionaryArrayConverter +{ + fn convert(&self, source: Vec>) -> Result> { + let data_size = source + .iter() + .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) + .sum(); + + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = StringBuilder::with_capacity(source.len(), data_size); + + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + for v in source { + match v { + Some(array) => { + let _ = builder.append(array.as_utf8()?)?; + } + None => builder.append_null()?, + } + } + + Ok(builder.finish()) + } +} + +pub struct DictionaryArrayConverter +{ + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, +} + +impl + DictionaryArrayConverter +{ + pub fn new() -> Self { + Self { + _dict_value_source_marker: PhantomData, + _dict_value_target_marker: PhantomData, + _parquet_marker: PhantomData, + } + } +} + +impl + Converter::T>>, DictionaryArray> + for DictionaryArrayConverter +where + K: ArrowPrimitiveType, + DictValueSourceType: ArrowPrimitiveType, + DictValueTargetType: ArrowPrimitiveType, + ParquetType: DataType, + PrimitiveArray: From::T>>>, +{ + fn convert( + &self, + source: Vec::T>>, + ) -> Result> { + let keys_builder = PrimitiveBuilder::::new(source.len()); + let values_builder = PrimitiveBuilder::::new(source.len()); + + let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); + + let source_array: Arc = + Arc::new(PrimitiveArray::::from(source)); + let target_array = cast(&source_array, &DictValueTargetType::DATA_TYPE)?; + let target = target_array + .as_any() + .downcast_ref::>() + .unwrap(); + + for i in 0..target.len() { + if target.is_null(i) { + builder.append_null()?; + } else { + let _ = builder.append(target.value(i))?; + } + } + + Ok(builder.finish()) + } +} + pub type Utf8Converter = ArrayRefConverter>, StringArray, Utf8ArrayConverter>; pub type LargeUtf8Converter = @@ -292,6 +258,22 @@ pub type LargeBinaryConverter = ArrayRefConverter< LargeBinaryArray, LargeBinaryArrayConverter, >; +pub type StringDictionaryConverter = ArrayRefConverter< + Vec>, + DictionaryArray, + StringDictionaryArrayConverter, +>; +pub type DictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; +pub type PrimitiveDictionaryConverter = ArrayRefConverter< + Vec::T>>, + DictionaryArray, + DictionaryArrayConverter, +>; + pub type Int96Converter = ArrayRefConverter>, TimestampNanosecondArray, Int96ArrayConverter>; pub type FixedLenBinaryConverter = ArrayRefConverter< @@ -357,120 +339,3 @@ where .map(|array| Arc::new(array) as ArrayRef) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::arrow::converter::Int16Converter; - use crate::arrow::record_reader::RecordReader; - use crate::basic::Encoding; - use crate::schema::parser::parse_message_type; - use crate::schema::types::SchemaDescriptor; - use crate::util::test_common::page_util::InMemoryPageReader; - use crate::util::test_common::page_util::{DataPageBuilder, DataPageBuilderImpl}; - use arrow::array::ArrayEqual; - use arrow::array::PrimitiveArray; - use arrow::datatypes::{Int16Type, Int32Type}; - use std::rc::Rc; - - macro_rules! converter_arrow_source_target { - ($raw_data:expr, $physical_type:expr, $result_arrow_type:ty, $converter:ty) => {{ - // Construct record reader - let mut record_reader = { - // Construct column schema - let message_type = &format!( - " - message test_schema {{ - OPTIONAL {} leaf; - }} - ", - $physical_type - ); - - let def_levels = [1i16, 0i16, 1i16, 1i16]; - build_record_reader( - message_type, - &[1, 2, 3], - 0i16, - None, - 1i16, - Some(&def_levels), - 10, - ) - }; - - let array = <$converter>::new().convert(&mut record_reader).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert!(array.equals(&PrimitiveArray::<$result_arrow_type>::from($raw_data))); - }}; - } - - #[test] - fn test_converter_arrow_source_i16_target_i32() { - let raw_data = vec![Some(1i16), None, Some(2i16), Some(3i16)]; - converter_arrow_source_target!(raw_data, "INT32", Int16Type, Int16Converter) - } - - #[test] - fn test_converter_arrow_source_i32_target_date32() { - let raw_data = vec![Some(1i32), None, Some(2i32), Some(3i32)]; - converter_arrow_source_target!(raw_data, "INT32", Date32Type, Date32Converter) - } - - #[test] - fn test_converter_arrow_source_i32_target_i32() { - let raw_data = vec![Some(1i32), None, Some(2i32), Some(3i32)]; - converter_arrow_source_target!(raw_data, "INT32", Int32Type, Int32Converter) - } - - fn build_record_reader( - message_type: &str, - values: &[T::T], - max_rep_level: i16, - rep_levels: Option<&[i16]>, - max_def_level: i16, - def_levels: Option<&[i16]>, - num_records: usize, - ) -> RecordReader { - let desc = parse_message_type(message_type) - .map(|t| SchemaDescriptor::new(Rc::new(t))) - .map(|s| s.column(0)) - .unwrap(); - - let mut record_reader = RecordReader::::new(desc.clone()); - - // Prepare record reader - let mut pb = DataPageBuilderImpl::new(desc, 4, true); - if rep_levels.is_some() { - pb.add_rep_levels( - max_rep_level, - match rep_levels { - Some(a) => a, - _ => unreachable!(), - }, - ); - } - if def_levels.is_some() { - pb.add_def_levels( - max_def_level, - match def_levels { - Some(a) => a, - _ => unreachable!(), - }, - ); - } - pb.add_values::(Encoding::PLAIN, &values); - let page = pb.consume(); - - let page_reader = Box::new(InMemoryPageReader::new(vec![page])); - record_reader.set_page_reader(page_reader).unwrap(); - - record_reader.read_records(num_records).unwrap(); - - record_reader - } -} diff --git a/rust/parquet/src/arrow/record_reader.rs b/rust/parquet/src/arrow/record_reader.rs index b30ab7760b275..519bd15fb0c2f 100644 --- a/rust/parquet/src/arrow/record_reader.rs +++ b/rust/parquet/src/arrow/record_reader.rs @@ -124,26 +124,6 @@ impl RecordReader { } } - pub(crate) fn cast(&mut self) -> &mut RecordReader { - trait CastRecordReader { - fn cast(&mut self) -> &mut RecordReader; - } - - impl CastRecordReader for RecordReader { - default fn cast(&mut self) -> &mut RecordReader { - panic!("Attempted to cast RecordReader to the wrong type") - } - } - - impl CastRecordReader for RecordReader { - fn cast(&mut self) -> &mut RecordReader { - self - } - } - - CastRecordReader::::cast(self) - } - /// Set the current page reader. pub fn set_page_reader(&mut self, page_reader: Box) -> Result<()> { self.column_reader =