diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 98bf5a1b62fd..13dee1d691a3 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -29,7 +29,7 @@ use pyo3::{libc::uintptr_t, prelude::*}; use arrow::array::{make_array_from_raw, ArrayRef, Int64Array}; use arrow::compute::kernels; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; use arrow::ffi; use arrow::ffi::FFI_ArrowSchema; @@ -76,6 +76,16 @@ struct PyDataType { inner: DataType, } +#[pyclass] +struct PyField { + inner: Field, +} + +#[pyclass] +struct PySchema { + inner: Schema, +} + #[pymethods] impl PyDataType { #[staticmethod] @@ -84,7 +94,7 @@ impl PyDataType { let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(PyDataType { inner: dtype }) + Ok(Self { inner: dtype }) } fn to_pyarrow(&self, py: Python) -> PyResult { @@ -98,11 +108,6 @@ impl PyDataType { } } -#[pyclass] -struct PyField { - inner: Field, -} - #[pymethods] impl PyField { #[staticmethod] @@ -111,7 +116,7 @@ impl PyField { let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(PyField { inner: field }) + Ok(Self { inner: field }) } fn to_pyarrow(&self, py: Python) -> PyResult { @@ -125,6 +130,29 @@ impl PyField { } } +#[pymethods] +impl PySchema { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(Self { inner: schema }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Schema")?; + let schema = + class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(schema.into()) + } +} + impl<'source> FromPyObject<'source> for PyDataType { fn extract(value: &'source PyAny) -> PyResult { PyDataType::from_pyarrow(value) @@ -137,10 +165,11 @@ impl<'source> FromPyObject<'source> for PyField { } } -// struct PyField(Field); -// struct PySchema(Schema); - -// fn type_to_rust(ob: PyObject, py: Python) -> PyResult { +impl<'source> FromPyObject<'source> for PySchema { + fn extract(value: &'source PyAny) -> PyResult { + PySchema::from_pyarrow(value) + } +} fn array_to_rust(ob: PyObject, py: Python) -> PyResult { // prepare a pointer to receive the Array struct @@ -156,13 +185,12 @@ fn array_to_rust(ob: PyObject, py: Python) -> PyResult { )?; let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; Ok(array) } fn array_to_py(array: ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = - array.to_raw().map_err(|e| PyO3ArrowError::from(e))?; + let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?; let pa = py.import("pyarrow")?; @@ -183,15 +211,10 @@ fn double(array: PyObject, py: Python) -> PyResult { let array = array_to_rust(array, py)?; // perform some operation - let array = - array - .as_any() - .downcast_ref::() - .ok_or(PyO3ArrowError::ArrowError(ArrowError::ParseError( - "Expects an int64".to_string(), - )))?; - let array = - kernels::arithmetic::add(&array, &array).map_err(|e| PyO3ArrowError::from(e))?; + let array = array.as_any().downcast_ref::().ok_or_else(|| { + PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) + })?; + let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?; let array = Arc::new(array); // export @@ -222,7 +245,7 @@ fn substring(array: PyObject, start: i64, py: Python) -> PyResult { // substring let array = kernels::substring::substring(array.as_ref(), start, &None) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export array_to_py(array, py) @@ -236,7 +259,7 @@ fn concatenate(array: PyObject, py: Python) -> PyResult { // concat let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export array_to_py(array, py) @@ -256,6 +279,7 @@ fn round_trip(pyarray: PyObject, py: Python) -> PyResult { fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 02f08f7a7502..2bebc18108e2 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -17,11 +17,12 @@ # under the License. import contextlib +import string import pytest import pyarrow as pa -from arrow_pyarrow_integration_testing import PyDataType, PyField +from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema import arrow_pyarrow_integration_testing as rust @@ -128,13 +129,24 @@ def test_dictionary_type_roundtrip(): assert ty.to_pyarrow() == pa.int32() -# Missing implementation in pyarrow -# @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) -# def test_field_roundtrip(pyarrow_type): -# for nullable in [True, False]: -# pyarrow_field = pa.field("test", pyarrow_type, nullable=nullable) -# field = PyField.from_pyarrow(pyarrow_field) -# assert field.to_pyarrow() == pyarrow_field +@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) +def test_field_roundtrip(pyarrow_type): + pyarrow_field = pa.field("test", pyarrow_type, nullable=True) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + if pyarrow_type != pa.null(): + # A null type field may not be non-nullable + pyarrow_field = pa.field("test", pyarrow_type, nullable=False) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + +def test_schema_roundtrip(): + pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types) + pyarrow_schema = pa.schema(pyarrow_fields) + schema = PySchema.from_pyarrow(pyarrow_schema) + assert schema.to_pyarrow() == pyarrow_schema def test_primitive_python(): diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index a404186a8f18..847649ce1264 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -22,7 +22,7 @@ use std::convert::TryFrom; use crate::{ error::{ArrowError, Result}, ffi, - ffi::ArrowArrayRef + ffi::ArrowArrayRef, }; use super::ArrayData; diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 46d11c1dbf52..abe23b6f14fc 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -96,9 +96,9 @@ impl TryFrom<&FFI_ArrowSchema> for Schema { if let DataType::Struct(fields) = dtype { Ok(Schema::new(fields)) } else { - Err(ArrowError::CDataInterface(format!( - "Unable to interpret C data struct as a Schema" - ))) + Err(ArrowError::CDataInterface( + "Unable to interpret C data struct as a Schema".to_string(), + )) } } } @@ -234,10 +234,6 @@ mod tests { Ok(()) } - // fn roundtrip>(expected: T) { - // let c_schema: FFI_ArrowSchema = expected.try_into().unwrap(); - // } - #[test] fn test_type() -> Result<()> { round_trip_type(DataType::Int64)?; @@ -290,6 +286,4 @@ mod tests { assert_eq!(result.is_err(), true); Ok(()) } - - // TODO }