From 5dafa4d0267f52a0cfa7b9384a7f77e53cd2cd6e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jan 2023 00:08:06 -0800 Subject: [PATCH] Support Decimal256 in ffi (#3453) --- arrow-pyarrow-integration-testing/tests/test_sql.py | 2 +- arrow/src/datatypes/ffi.rs | 13 ++++++++++--- arrow/src/ffi.rs | 3 +++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index c97dad77ea1d..98564408d937 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -63,6 +63,7 @@ def assert_pyarrow_leak(): pa.float32(), pa.float64(), pa.decimal128(19, 4), + pa.decimal256(76, 38), pa.string(), pa.binary(), pa.binary(10), @@ -110,7 +111,6 @@ def assert_pyarrow_leak(): ] _unsupported_pyarrow_types = [ - pa.decimal256(76, 38), ] diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 58fc8858ad75..37fa85fcf5dd 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -112,8 +112,8 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { DataType::Decimal128(parsed_precision, parsed_scale) }, [precision, scale, bits] => { - if *bits != "128" { - return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); + if *bits != "128" && *bits != "256" { + return Err(ArrowError::CDataInterface("Only 128/256 bit wide decimal is supported in the Rust implementation".to_string())); } let parsed_precision = precision.parse::().map_err(|_| { ArrowError::CDataInterface( @@ -125,7 +125,11 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "The decimal type requires an integer scale".to_string(), ) })?; - DataType::Decimal128(parsed_precision, parsed_scale) + if *bits == "128" { + DataType::Decimal128(parsed_precision, parsed_scale) + } else { + DataType::Decimal256(parsed_precision, parsed_scale) + } } _ => { return Err(ArrowError::CDataInterface(format!( @@ -305,6 +309,9 @@ fn get_format_string(dtype: &DataType) -> Result { DataType::Decimal128(precision, scale) => { Ok(format!("d:{},{}", precision, scale)) } + DataType::Decimal256(precision, scale) => { + Ok(format!("d:{},{},256", precision, scale)) + } DataType::Date32 => Ok("tdD".to_string()), DataType::Date64 => Ok("tdm".to_string()), DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 5e9b01b5c6b0..4111b858d050 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -120,6 +120,7 @@ use std::{ sync::Arc, }; +use arrow_buffer::i256; use arrow_schema::UnionMode; use bitflags::bitflags; @@ -324,6 +325,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Float32, 1) => size_of::() * 8, (DataType::Float64, 1) => size_of::() * 8, (DataType::Decimal128(..), 1) => size_of::() * 8, + (DataType::Decimal256(..), 1) => size_of::() * 8, (DataType::Timestamp(..), 1) => size_of::() * 8, (DataType::Duration(..), 1) => size_of::() * 8, // primitive types have a single buffer @@ -339,6 +341,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Float32, _) | (DataType::Float64, _) | (DataType::Decimal128(..), _) | + (DataType::Decimal256(..), _) | (DataType::Timestamp(..), _) | (DataType::Duration(..), _) => { return Err(ArrowError::CDataInterface(format!(