Skip to content

Commit

Permalink
Python tests for Field and Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Jun 17, 2021
1 parent 36268ee commit 7a1c845
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 44 deletions.
76 changes: 50 additions & 26 deletions arrow-pyarrow-integration-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use pyo3::{libc::uintptr_t, prelude::*};

use arrow::array::{make_array_from_raw, ArrayRef, Int64Array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi;
use arrow::ffi::FFI_ArrowSchema;
Expand Down Expand Up @@ -76,6 +76,16 @@ struct PyDataType {
inner: DataType,
}

#[pyclass]
struct PyField {
inner: Field,
}

#[pyclass]
struct PySchema {
inner: Schema,
}

#[pymethods]
impl PyDataType {
#[staticmethod]
Expand All @@ -84,7 +94,7 @@ impl PyDataType {
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
Ok(PyDataType { inner: dtype })
Ok(Self { inner: dtype })
}

fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
Expand All @@ -98,11 +108,6 @@ impl PyDataType {
}
}

#[pyclass]
struct PyField {
inner: Field,
}

#[pymethods]
impl PyField {
#[staticmethod]
Expand All @@ -111,7 +116,7 @@ impl PyField {
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
Ok(PyField { inner: field })
Ok(Self { inner: field })
}

fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
Expand All @@ -125,6 +130,29 @@ impl PyField {
}
}

#[pymethods]
impl PySchema {
#[staticmethod]
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
let c_schema = FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
Ok(Self { inner: schema })
}

fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
let c_schema =
FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?;
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
let module = py.import("pyarrow")?;
let class = module.getattr("Schema")?;
let schema =
class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
Ok(schema.into())
}
}

impl<'source> FromPyObject<'source> for PyDataType {
fn extract(value: &'source PyAny) -> PyResult<Self> {
PyDataType::from_pyarrow(value)
Expand All @@ -137,10 +165,11 @@ impl<'source> FromPyObject<'source> for PyField {
}
}

// struct PyField(Field);
// struct PySchema(Schema);

// fn type_to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
impl<'source> FromPyObject<'source> for PySchema {
fn extract(value: &'source PyAny) -> PyResult<Self> {
PySchema::from_pyarrow(value)
}
}

fn array_to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
// prepare a pointer to receive the Array struct
Expand All @@ -156,13 +185,12 @@ fn array_to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
)?;

let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
.map_err(|e| PyO3ArrowError::from(e))?;
.map_err(PyO3ArrowError::from)?;
Ok(array)
}

fn array_to_py(array: ArrayRef, py: Python) -> PyResult<PyObject> {
let (array_pointer, schema_pointer) =
array.to_raw().map_err(|e| PyO3ArrowError::from(e))?;
let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?;

let pa = py.import("pyarrow")?;

Expand All @@ -183,15 +211,10 @@ fn double(array: PyObject, py: Python) -> PyResult<PyObject> {
let array = array_to_rust(array, py)?;

// perform some operation
let array =
array
.as_any()
.downcast_ref::<Int64Array>()
.ok_or(PyO3ArrowError::ArrowError(ArrowError::ParseError(
"Expects an int64".to_string(),
)))?;
let array =
kernels::arithmetic::add(&array, &array).map_err(|e| PyO3ArrowError::from(e))?;
let array = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string()))
})?;
let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?;
let array = Arc::new(array);

// export
Expand Down Expand Up @@ -222,7 +245,7 @@ fn substring(array: PyObject, start: i64, py: Python) -> PyResult<PyObject> {

// substring
let array = kernels::substring::substring(array.as_ref(), start, &None)
.map_err(|e| PyO3ArrowError::from(e))?;
.map_err(PyO3ArrowError::from)?;

// export
array_to_py(array, py)
Expand All @@ -236,7 +259,7 @@ fn concatenate(array: PyObject, py: Python) -> PyResult<PyObject> {

// concat
let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])
.map_err(|e| PyO3ArrowError::from(e))?;
.map_err(PyO3ArrowError::from)?;

// export
array_to_py(array, py)
Expand All @@ -256,6 +279,7 @@ fn round_trip(pyarray: PyObject, py: Python) -> PyResult<PyObject> {
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDataType>()?;
m.add_class::<PyField>()?;
m.add_class::<PySchema>()?;
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_wrapped(wrap_pyfunction!(double_py))?;
m.add_wrapped(wrap_pyfunction!(substring))?;
Expand Down
28 changes: 20 additions & 8 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
# under the License.

import contextlib
import string

import pytest
import pyarrow as pa

from arrow_pyarrow_integration_testing import PyDataType, PyField
from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema
import arrow_pyarrow_integration_testing as rust


Expand Down Expand Up @@ -128,13 +129,24 @@ def test_dictionary_type_roundtrip():
assert ty.to_pyarrow() == pa.int32()


# Missing implementation in pyarrow
# @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
# def test_field_roundtrip(pyarrow_type):
# for nullable in [True, False]:
# pyarrow_field = pa.field("test", pyarrow_type, nullable=nullable)
# field = PyField.from_pyarrow(pyarrow_field)
# assert field.to_pyarrow() == pyarrow_field
@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
field = PyField.from_pyarrow(pyarrow_field)
assert field.to_pyarrow() == pyarrow_field

if pyarrow_type != pa.null():
# A null type field may not be non-nullable
pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
field = PyField.from_pyarrow(pyarrow_field)
assert field.to_pyarrow() == pyarrow_field


def test_schema_roundtrip():
pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types)
pyarrow_schema = pa.schema(pyarrow_fields)
schema = PySchema.from_pyarrow(pyarrow_schema)
assert schema.to_pyarrow() == pyarrow_schema


def test_primitive_python():
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/array/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::convert::TryFrom;
use crate::{
error::{ArrowError, Result},
ffi,
ffi::ArrowArrayRef
ffi::ArrowArrayRef,
};

use super::ArrayData;
Expand Down
12 changes: 3 additions & 9 deletions arrow/src/datatypes/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ impl TryFrom<&FFI_ArrowSchema> for Schema {
if let DataType::Struct(fields) = dtype {
Ok(Schema::new(fields))
} else {
Err(ArrowError::CDataInterface(format!(
"Unable to interpret C data struct as a Schema"
)))
Err(ArrowError::CDataInterface(
"Unable to interpret C data struct as a Schema".to_string(),
))
}
}
}
Expand Down Expand Up @@ -234,10 +234,6 @@ mod tests {
Ok(())
}

// fn roundtrip<T: TryInto<FFI_ArrowSchema>>(expected: T) {
// let c_schema: FFI_ArrowSchema = expected.try_into().unwrap();
// }

#[test]
fn test_type() -> Result<()> {
round_trip_type(DataType::Int64)?;
Expand Down Expand Up @@ -290,6 +286,4 @@ mod tests {
assert_eq!(result.is_err(), true);
Ok(())
}

// TODO
}

0 comments on commit 7a1c845

Please sign in to comment.