diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 2da732b746a..373649472f9 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -17,6 +17,7 @@ # under the License. import unittest +import decimal import pyarrow import arrow_pyarrow_integration_testing @@ -90,6 +91,19 @@ def test_string_roundtrip(self): c = pyarrow.array(["a", None, "ccc"]) self.assertEqual(b, c) + def test_decimal_roundtrip(self): + """ + Python -> Rust -> Python + """ + data = [ + round(decimal.Decimal(722.82), 2), + round(decimal.Decimal(-934.11), 2), + None + ] + a = pyarrow.array(data, pyarrow.decimal128(5, 2)) + b = arrow_pyarrow_integration_testing.round_trip(a) + self.assertEqual(a, b) + def test_string_python(self): """ Python -> Rust -> Python diff --git a/src/array/fixed_size_list/mod.rs b/src/array/fixed_size_list/mod.rs index 2b5c392303e..df5877162db 100644 --- a/src/array/fixed_size_list/mod.rs +++ b/src/array/fixed_size_list/mod.rs @@ -133,4 +133,8 @@ unsafe impl ToFfi for FixedSizeListArray { fn offset(&self) -> usize { self.offset } + + fn children(&self) -> Vec> { + vec![self.values().clone()] + } } diff --git a/src/ffi/array.rs b/src/ffi/array.rs index 65f708cf2eb..9c2e438cbc3 100644 --- a/src/ffi/array.rs +++ b/src/ffi/array.rs @@ -81,6 +81,7 @@ pub fn try_from(array: A) -> Result> { mod tests { use super::*; use crate::array::*; + use crate::datatypes::TimeUnit; use crate::{error::Result, ffi}; use std::sync::Arc; @@ -146,6 +147,15 @@ mod tests { test_round_trip(data) } + #[test] + fn test_timestamp_tz() -> Result<()> { + let data = Int64Array::from(&vec![Some(2), None, None]).to(DataType::Timestamp( + TimeUnit::Second, + Some("UTC".to_string()), + )); + test_round_trip(data) + } + #[test] fn test_large_binary() -> Result<()> { let data = diff --git a/src/ffi/ffi.rs b/src/ffi/ffi.rs index 725dc42cbec..444a0478fb9 100644 --- a/src/ffi/ffi.rs +++ b/src/ffi/ffi.rs @@ -29,7 +29,7 @@ use crate::{ bytes::{Bytes, Deallocation}, Buffer, }, - datatypes::{DataType, Field, TimeUnit}, + datatypes::{DataType, Field, IntervalUnit, TimeUnit}, error::{ArrowError, Result}, types::NativeType, }; @@ -200,6 +200,12 @@ fn to_field(schema: &Ffi_ArrowSchema) -> Result { "ttm" => DataType::Time32(TimeUnit::Millisecond), "ttu" => DataType::Time64(TimeUnit::Microsecond), "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), + "tiM" => DataType::Interval(IntervalUnit::YearMonth), + "tiD" => DataType::Interval(IntervalUnit::DayTime), "+l" => { let child = schema.child(0); DataType::List(Box::new(to_field(child)?)) @@ -215,10 +221,43 @@ fn to_field(schema: &Ffi_ArrowSchema) -> Result { DataType::Struct(children) } other => { - return Err(ArrowError::Ffi(format!( - "The datatype \"{}\" is still not supported in Rust implementation", - other - ))) + let parts = other.split(':').collect::>(); + if parts.len() == 2 && parts[0] == "tss" { + DataType::Timestamp(TimeUnit::Second, Some(parts[1].to_string())) + } else if parts.len() == 2 && parts[0] == "tsm" { + DataType::Timestamp(TimeUnit::Millisecond, Some(parts[1].to_string())) + } else if parts.len() == 2 && parts[0] == "tsu" { + DataType::Timestamp(TimeUnit::Microsecond, Some(parts[1].to_string())) + } else if parts.len() == 2 && parts[0] == "tsn" { + DataType::Timestamp(TimeUnit::Nanosecond, Some(parts[1].to_string())) + } else if parts.len() == 2 && parts[0] == "d" { + let parts = parts[1].split(',').collect::>(); + if parts.len() < 2 || parts.len() > 3 { + return Err(ArrowError::Ffi( + "Decimal must contain 2 or 3 comma-separated values".to_string(), + )); + }; + if parts.len() == 3 { + let bit_width = parts[0].parse::().map_err(|_| { + ArrowError::Ffi("Decimal bit width is not a valid integer".to_string()) + })?; + if bit_width != 128 { + return Err(ArrowError::Ffi("Decimal256 is not supported".to_string())); + } + } + let precision = parts[0].parse::().map_err(|_| { + ArrowError::Ffi("Decimal precision is not a valid integer".to_string()) + })?; + let scale = parts[1].parse::().map_err(|_| { + ArrowError::Ffi("Decimal scale is not a valid integer".to_string()) + })?; + DataType::Decimal(precision, scale) + } else { + return Err(ArrowError::Ffi(format!( + "The datatype \"{}\" is still not supported in Rust implementation", + other + ))); + } } }; Ok(Field::new(schema.name(), data_type, schema.nullable())) @@ -250,15 +289,33 @@ fn to_format(data_type: &DataType) -> Result { DataType::Time32(TimeUnit::Millisecond) => "ttm", DataType::Time64(TimeUnit::Microsecond) => "ttu", DataType::Time64(TimeUnit::Nanosecond) => "ttn", + DataType::Duration(TimeUnit::Second) => "tDs", + DataType::Duration(TimeUnit::Millisecond) => "tDm", + DataType::Duration(TimeUnit::Microsecond) => "tDu", + DataType::Duration(TimeUnit::Nanosecond) => "tDn", + DataType::Interval(IntervalUnit::YearMonth) => "tiM", + DataType::Interval(IntervalUnit::DayTime) => "tiD", + DataType::Timestamp(unit, tz) => { + let unit = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "m", + TimeUnit::Microsecond => "u", + TimeUnit::Nanosecond => "n", + }; + return Ok(format!( + "ts{}:{}", + unit, + tz.as_ref().map(|x| x.as_ref()).unwrap_or("") + )); + } + DataType::Decimal(precision, scale) => return Ok(format!("d:{},{}", precision, scale)), DataType::List(_) => "+l", DataType::LargeList(_) => "+L", DataType::Struct(_) => "+s", - z => { - return Err(ArrowError::Ffi(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - z - ))) - } + DataType::FixedSizeBinary(size) => return Ok(format!("w{}", size)), + DataType::FixedSizeList(_, size) => return Ok(format!("+w:{}", size)), + DataType::Union(_) => todo!(), + _ => todo!(), } .to_string()) }