diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index 186d0a2f678a..5ab9097e1ca7 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -17,6 +17,7 @@ use crate::array::{ArrayAccessor, Decimal128Iter, Decimal256Iter}; use num::BigInt; +use num::FromPrimitive; use std::borrow::Borrow; use std::convert::From; use std::fmt; @@ -30,7 +31,8 @@ use super::{BooleanBufferBuilder, FixedSizeBinaryArray}; pub use crate::array::DecimalIter; use crate::buffer::{Buffer, MutableBuffer}; use crate::datatypes::{ - validate_decimal_precision, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, + validate_decimal256_precision, validate_decimal_precision, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, }; use crate::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE}; use crate::error::{ArrowError, Result}; @@ -95,6 +97,8 @@ pub trait BasicDecimalArray>: { const VALUE_LENGTH: i32; const DEFAULT_TYPE: DataType; + const MAX_PRECISION: usize; + const MAX_SCALE: usize; fn data(&self) -> &ArrayData; @@ -246,12 +250,72 @@ pub trait BasicDecimalArray>: fn default_type() -> DataType { Self::DEFAULT_TYPE } + + /// Returns a Decimal array with the same data as self, with the + /// specified precision. + /// + /// Returns an Error if: + /// 1. `precision` is larger than [`Self::MAX_PRECISION`] + /// 2. `scale` is larger than [`Self::MAX_SCALE`]; + /// 3. `scale` is > `precision` + fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result + where + Self: Sized, + { + if precision > Self::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + Self::MAX_PRECISION + ))); + } + if scale > Self::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + Self::MAX_SCALE + ))); + } + if scale > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than precision {}", + scale, precision + ))); + } + + // Ensure that all values are within the requested + // precision. For performance, only check if the precision is + // decreased + self.validate_decimal_precision(precision)?; + + let data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal128(self.precision(), self.scale()) + } else { + DataType::Decimal256(self.precision(), self.scale()) + }; + assert_eq!(self.data().data_type(), &data_type); + + // safety: self.data is valid DataType::Decimal as checked above + let new_data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal128(precision, scale) + } else { + DataType::Decimal256(precision, scale) + }; + + Ok(self.data().clone().with_data_type(new_data_type).into()) + } + + /// Validates decimal values in this array can be properly interpreted + /// with the specified precision. + fn validate_decimal_precision(&self, precision: usize) -> Result<()>; } impl BasicDecimalArray for Decimal128Array { const VALUE_LENGTH: i32 = 16; const DEFAULT_TYPE: DataType = DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const MAX_PRECISION: usize = DECIMAL128_MAX_PRECISION; + const MAX_SCALE: usize = DECIMAL128_MAX_SCALE; fn data(&self) -> &ArrayData { &self.data @@ -264,12 +328,23 @@ impl BasicDecimalArray for Decimal128Array { fn scale(&self) -> usize { self.scale } + + fn validate_decimal_precision(&self, precision: usize) -> Result<()> { + if precision < self.precision { + for v in self.iter().flatten() { + validate_decimal_precision(v.as_i128(), precision)?; + } + } + Ok(()) + } } impl BasicDecimalArray for Decimal256Array { const VALUE_LENGTH: i32 = 32; const DEFAULT_TYPE: DataType = DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); + const MAX_PRECISION: usize = DECIMAL256_MAX_PRECISION; + const MAX_SCALE: usize = DECIMAL256_MAX_SCALE; fn data(&self) -> &ArrayData { &self.data @@ -282,6 +357,15 @@ impl BasicDecimalArray for Decimal256Array { fn scale(&self) -> usize { self.scale } + + fn validate_decimal_precision(&self, precision: usize) -> Result<()> { + if precision < self.precision { + for v in self.iter().flatten() { + validate_decimal256_precision(&v.to_string(), precision)?; + } + } + Ok(()) + } } impl Decimal128Array { @@ -302,59 +386,6 @@ impl Decimal128Array { }; Decimal128Array::from(data) } - - /// Returns a Decimal128Array with the same data as self, with the - /// specified precision. - /// - /// Returns an Error if: - /// 1. `precision` is larger than [`DECIMAL128_MAX_PRECISION`] - /// 2. `scale` is larger than [`DECIMAL128_MAX_SCALE`]; - /// 3. `scale` is > `precision` - pub fn with_precision_and_scale( - mut self, - precision: usize, - scale: usize, - ) -> Result { - if precision > DECIMAL128_MAX_PRECISION { - return Err(ArrowError::InvalidArgumentError(format!( - "precision {} is greater than max {}", - precision, DECIMAL128_MAX_PRECISION - ))); - } - if scale > DECIMAL128_MAX_SCALE { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than max {}", - scale, DECIMAL128_MAX_SCALE - ))); - } - if scale > precision { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than precision {}", - scale, precision - ))); - } - - // Ensure that all values are within the requested - // precision. For performance, only check if the precision is - // decreased - if precision < self.precision { - for v in self.iter().flatten() { - validate_decimal_precision(v.as_i128(), precision)?; - } - } - - assert_eq!( - self.data.data_type(), - &DataType::Decimal128(self.precision, self.scale) - ); - - // safety: self.data is valid DataType::Decimal as checked above - let new_data_type = DataType::Decimal128(precision, scale); - self.precision = precision; - self.scale = scale; - self.data = self.data.with_data_type(new_data_type); - Ok(self) - } } impl From for Decimal128Array { @@ -438,6 +469,13 @@ where U::from(data) } +/// Useful converter for usages like cast kernel +impl From for Decimal256 { + fn from(integer: i128) -> Self { + Decimal256::from(BigInt::from_i128(integer).unwrap()) + } +} + impl> FromIterator> for Decimal256Array { fn from_iter>>(iter: I) -> Self { let iter = iter.into_iter(); diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 6fdc06f837c0..3387e2842264 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -262,6 +262,7 @@ mod tests { use std::convert::TryFrom; use std::sync::Arc; + use crate::array::BasicDecimalArray; use crate::array::{ array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BooleanArray, FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder, diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index 3664a2055210..564ef444a1dd 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -675,6 +675,7 @@ mod tests { use super::*; + use crate::array::BasicDecimalArray; use crate::array::Decimal128Array; use crate::{ array::{ @@ -708,7 +709,7 @@ mod tests { fn test_decimal() { let decimal_array = create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); - let arrays = vec![decimal_array.data()]; + let arrays = vec![Array::data(&decimal_array)]; let mut a = MutableArrayData::new(arrays, true, 3); a.extend(0, 0, 3); a.extend(0, 2, 3); diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 2d529317801d..4073e7183175 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -908,6 +908,7 @@ impl<'a> ArrowArrayChild<'a> { #[cfg(test)] mod tests { use super::*; + use crate::array::BasicDecimalArray; use crate::array::{ export_array_into_raw, make_array, Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, @@ -953,7 +954,7 @@ mod tests { .unwrap(); // export it - let array = ArrowArray::try_from(original_array.data().clone())?; + let array = ArrowArray::try_from(Array::data(&original_array).clone())?; // (simulate consumer) import it let data = ArrayData::try_from(array)?; diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs index e92b0366ae1e..84d445e9a1f8 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow/src/util/pretty.rs @@ -19,9 +19,8 @@ //! available unless `feature = "prettyprint"` is enabled. use crate::{array::ArrayRef, record_batch::RecordBatch}; -use std::fmt::Display; - use comfy_table::{Cell, Table}; +use std::fmt::Display; use crate::error::Result; @@ -108,7 +107,7 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result { mod tests { use crate::{ array::{ - self, new_null_array, Array, Date32Array, Date64Array, + self, new_null_array, Array, BasicDecimalArray, Date32Array, Date64Array, FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder, StringArray, StringBuilder, StringDictionaryBuilder, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index 89f2ce51bef8..aa5d564eee4a 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -25,8 +25,8 @@ use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use arrow::array::{ - ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array, - Float32Array, Float64Array, Int32Array, Int64Array, + ArrayDataBuilder, ArrayRef, BasicDecimalArray, BooleanArray, BooleanBufferBuilder, + Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array, }; use arrow::buffer::Buffer; use arrow::datatypes::DataType as ArrowType; diff --git a/parquet/src/arrow/buffer/converter.rs b/parquet/src/arrow/buffer/converter.rs index 93609308d2ba..4cd0589424fc 100644 --- a/parquet/src/arrow/buffer/converter.rs +++ b/parquet/src/arrow/buffer/converter.rs @@ -17,11 +17,11 @@ use crate::data_type::{ByteArray, FixedLenByteArray, Int96}; use arrow::array::{ - Array, ArrayRef, BinaryArray, BinaryBuilder, Decimal128Array, FixedSizeBinaryArray, - FixedSizeBinaryBuilder, IntervalDayTimeArray, IntervalDayTimeBuilder, - IntervalYearMonthArray, IntervalYearMonthBuilder, LargeBinaryArray, - LargeBinaryBuilder, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, - TimestampNanosecondArray, + Array, ArrayRef, BasicDecimalArray, BinaryArray, BinaryBuilder, Decimal128Array, + FixedSizeBinaryArray, FixedSizeBinaryBuilder, IntervalDayTimeArray, + IntervalDayTimeBuilder, IntervalYearMonthArray, IntervalYearMonthBuilder, + LargeBinaryArray, LargeBinaryBuilder, LargeStringArray, LargeStringBuilder, + StringArray, StringBuilder, TimestampNanosecondArray, }; use std::convert::{From, TryInto}; use std::sync::Arc;