From e27dfcb0a0a71c901739d2a0a23a8a6db03ac94c Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Tue, 19 Jul 2022 13:41:46 -0400
Subject: [PATCH 1/3] Implement PyArrow Dataset TableProvider and
 register_dataset context functions.

---
 Cargo.lock                       |   2 +
 Cargo.toml                       |   2 +
 datafusion/tests/test_context.py |  17 ++
 datafusion/tests/test_sql.py     |  12 ++
 src/context.rs                   |  13 ++
 src/dataset.rs                   | 115 ++++++++++++++
 src/dataset_exec.rs              | 256 +++++++++++++++++++++++++++++++
 src/errors.rs                    |  16 +-
 src/lib.rs                       |   3 +
 src/pyarrow_filter_expression.rs | 196 +++++++++++++++++++++++
 10 files changed, 631 insertions(+), 1 deletion(-)
 create mode 100644 src/dataset.rs
 create mode 100644 src/dataset_exec.rs
 create mode 100644 src/pyarrow_filter_expression.rs

diff --git a/Cargo.lock b/Cargo.lock
index 37f6b15..bde0b7d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -379,9 +379,11 @@ dependencies = [
 name = "datafusion-python"
 version = "0.6.0"
 dependencies = [
+ "async-trait",
 "datafusion",
 "datafusion-common",
 "datafusion-expr",
+ "futures",
 "mimalloc",
 "pyo3",
 "rand 0.7.3",
diff --git a/Cargo.toml b/Cargo.toml
index 5673e38..ac26ca6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,8 @@ datafusion-expr = { version = "^8.0.0" }
 datafusion-common = { version = "^8.0.0", features = ["pyarrow"] }
 uuid = { version = "0.8", features = ["v4"] }
 mimalloc = { version = "*", optional = true, default-features = false }
+async-trait = "0.1"
+futures = "0.3"
 
 [lib]
 name = "datafusion_python"
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 4d4a38c..a32253d 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import pyarrow as pa
+import pyarrow.dataset as ds
 
 
 def test_register_record_batches(ctx):
@@ -72,3 +73,19 @@ def test_deregister_table(ctx, database):
     ctx.deregister_table("csv")
 
     assert public.names() == {"csv1", "csv2"}
+
+def test_register_dataset(ctx):
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT a+b, a-b FROM t").collect()
+
+    assert result[0].column(0) == pa.array([5, 7, 9])
+    assert result[0].column(1) == pa.array([-3, -3, -3])
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 38f38ab..fae3a29 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import pyarrow as pa
+import pyarrow.dataset as ds
 import pytest
 
 from datafusion import udf
@@ -121,6 +122,17 @@ def test_register_parquet_partitioned(ctx, tmp_path):
     rd = result.to_pydict()
     assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
 
+def test_register_dataset(ctx, tmp_path):
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    dataset = ds.dataset(path, format="parquet")
+
+    ctx.register_dataset("t", dataset)
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()
+    result = pa.Table.from_batches(result)
+    assert result.to_pydict() == {"cnt": [100]}
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
diff --git a/src/context.rs b/src/context.rs
index 213703f..d2c17ad 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -25,12 +25,14 @@ use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
 
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
+use crate::dataset::Dataset;
 use crate::errors::DataFusionError;
 use crate::udf::PyScalarUDF;
 use crate::utils::wait_for_future;
@@ -173,6 +175,17 @@ impl PySessionContext {
         Ok(())
     }
 
+    // Registers a PyArrow.Dataset
+    fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
+        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
+
+        self.ctx
+            .register_table(name, table)
+            .map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
         self.ctx.register_udf(udf.function);
         Ok(())
diff --git a/src/dataset.rs b/src/dataset.rs
new file mode 100644
index 0000000..c8be42d
--- /dev/null
+++ b/src/dataset.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::exceptions::PyValueError;
+/// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
+/// This allows us to use PyArrow Datasets as Datafusion tables while pushing down projections and filters
+use pyo3::prelude::*;
+use pyo3::types::PyType;
+
+use std::any::Any;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::datasource::datasource::TableProviderFilterPushDown;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::error::{DataFusionError, Result as DFResult};
+use datafusion::logical_plan::*;
+use datafusion::physical_plan::ExecutionPlan;
+
+use crate::dataset_exec::DatasetExec;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion TableProvider around it
+#[derive(Debug, Clone)]
+pub(crate) struct Dataset {
+    dataset: PyObject,
+}
+
+impl Dataset {
+    // Creates a Python PyArrow.Dataset
+    pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
+        // Ensure that we were passed an instance of pyarrow.dataset.Dataset
+        let ds = PyModule::import(py, "pyarrow.dataset")?;
+        let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
+        match dataset.is_instance(ds_type)? {
+            true => Ok(Dataset {
+                dataset: dataset.into(),
+            }),
+            false => Err(PyValueError::new_err(
+                "dataset argument must be a pyarrow.dataset.Dataset object",
+            )),
+        }
+    }
+}
+
+#[async_trait]
+impl TableProvider for Dataset {
+    /// Returns the table provider as [`Any`](std::any::Any) so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get a reference to the schema for this table
+    fn schema(&self) -> SchemaRef {
+        Python::with_gil(|py| {
+            let dataset = self.dataset.as_ref(py);
+            // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never panic
+            Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
+        })
+    }
+
+    /// Get the type of this table for metadata/catalog purposes.
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    /// Create an ExecutionPlan that will scan the table.
+    /// The table provider will usually be responsible for grouping
+    /// the source data into partitions that can be efficiently
+    /// parallelized or distributed.
+    async fn scan(
+        &self,
+        projection: &Option<Vec<usize>>,
+        filters: &[Expr],
+        // limit can be used to reduce the amount scanned
+        // from the datasource as a performance optimization.
+        // If set, it contains the amount of rows needed by the `LogicalPlan`,
+        // The datasource should return *at least* this number of rows if available.
+        _limit: Option<usize>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Python::with_gil(|py| {
+            let plan: Arc<dyn ExecutionPlan> = Arc::new(
+                DatasetExec::new(py, self.dataset.as_ref(py), projection.clone(), filters)
+                    .map_err(|err| DataFusionError::External(Box::new(err)))?,
+            );
+            Ok(plan)
+        })
+    }
+
+    /// Tests whether the table provider can make use of a filter expression
+    /// to optimise data retrieval.
+    fn supports_filter_pushdown(&self, filter: &Expr) -> DFResult<TableProviderFilterPushDown> {
+        match PyArrowFilterExpression::try_from(filter) {
+            Ok(_) => Ok(TableProviderFilterPushDown::Exact),
+            _ => Ok(TableProviderFilterPushDown::Unsupported),
+        }
+    }
+}
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
new file mode 100644
index 0000000..acd3ffd
--- /dev/null
+++ b/src/dataset_exec.rs
@@ -0,0 +1,256 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Implements a Datafusion physical ExecutionPlan that delegates to a PyArrow Dataset
+/// This actually performs the projection, filtering and scanning of a Dataset
+use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyIterator, PyList};
+
+use std::any::Any;
+use std::sync::Arc;
+
+use futures::stream;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::ArrowError;
+use datafusion::arrow::error::Result as ArrowResult;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::error::{DataFusionError as InnerDataFusionError, Result as DFResult};
+use datafusion::execution::context::TaskContext;
+use datafusion::logical_plan::{combine_filters, Expr};
+use datafusion::physical_expr::PhysicalSortExpr;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
+};
+
+use crate::errors::DataFusionError;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+struct PyArrowBatchesAdapter {
+    batches: Py<PyIterator>,
+}
+
+impl Iterator for PyArrowBatchesAdapter {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        Python::with_gil(|py| {
+            let mut batches: &PyIterator = self.batches.as_ref(py);
+            Some(
+                batches
+                    .next()?
+                    .and_then(|batch| batch.extract())
+                    .map_err(|err| ArrowError::ExternalError(Box::new(err))),
+            )
+        })
+    }
+}
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion ExecutionPlan around it
+#[derive(Debug, Clone)]
+pub(crate) struct DatasetExec {
+    dataset: PyObject,
+    schema: SchemaRef,
+    fragments: Py<PyList>,
+    columns: Option<Vec<String>>,
+    filter_expr: Option<PyObject>,
+    projected_statistics: Statistics,
+}
+
+impl DatasetExec {
+    pub fn new(
+        py: Python,
+        dataset: &PyAny,
+        projection: Option<Vec<usize>>,
+        filters: &[Expr],
+    ) -> Result<Self, DataFusionError> {
+        let columns: Option<Result<Vec<String>, DataFusionError>> = projection.map(|p| {
+            p.iter()
+                .map(|index| {
+                    let name: String = dataset
+                        .getattr("schema")?
+                        .call_method1("field", (*index,))?
+                        .getattr("name")?
+                        .extract()?;
+                    Ok(name)
+                })
+                .collect()
+        });
+        let columns: Option<Vec<String>> = columns.transpose()?;
+        let filter_expr: Option<PyObject> = combine_filters(filters)
+            .map(|filters| {
+                PyArrowFilterExpression::try_from(&filters)
+                    .map(|filter_expr| filter_expr.inner().clone_ref(py))
+            })
+            .transpose()?;
+
+        let kwargs = PyDict::new(py);
+
+        kwargs.set_item("columns", columns.clone())?;
+        kwargs.set_item(
+            "filter",
+            filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+        )?;
+
+        let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
+
+        let schema = Arc::new(scanner.getattr("projected_schema")?.extract()?);
+
+        let builtins = Python::import(py, "builtins")?;
+        let pylist = builtins.getattr("list")?;
+
+        // Get the fragments or partitions of the dataset
+        let fragments_iterator: &PyAny = dataset.call_method1(
+            "get_fragments",
+            (filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
+        )?;
+
+        let fragments: &PyList = pylist
+            .call1((fragments_iterator,))?
+            .downcast()
+            .map_err(PyErr::from)?;
+
+        Ok(DatasetExec {
+            dataset: dataset.into(),
+            schema,
+            fragments: fragments.into(),
+            columns,
+            filter_expr,
+            projected_statistics: Default::default(),
+        })
+    }
+}
+
+impl ExecutionPlan for DatasetExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get the schema for this execution plan
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    /// Get the output partitioning of this plan
+    fn output_partitioning(&self) -> Partitioning {
+        Python::with_gil(|py| {
+            let fragments = self.fragments.as_ref(py);
+            Partitioning::UnknownPartitioning(fragments.len())
+        })
+    }
+
+    fn relies_on_input_order(&self) -> bool {
+        false
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size;
+        Python::with_gil(|py| {
+            let dataset = self.dataset.as_ref(py);
+            let fragments = self.fragments.as_ref(py);
+            let fragment = fragments
+                .get_item(partition)
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+
+            // We need to pass the dataset schema to unify the fragment and dataset schema per PyArrow docs
+            let dataset_schema = dataset
+                .getattr("schema")
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+            let kwargs = PyDict::new(py);
+            kwargs
+                .set_item("columns", self.columns.clone())
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+            kwargs
+                .set_item(
+                    "filter",
+                    self.filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+                )
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+            kwargs
+                .set_item("batch_size", batch_size)
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+            let scanner = fragment
+                .call_method("scanner", (dataset_schema,), Some(kwargs))
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+            let schema: SchemaRef = Arc::new(
+                scanner
+                    .getattr("projected_schema")
+                    .and_then(|schema| schema.extract())
+                    .map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
+            );
+            let record_batches: &PyIterator = scanner
+                .call_method0("to_batches")
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?
+                .iter()
+                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+
+            let record_batches = PyArrowBatchesAdapter {
+                batches: record_batches.into(),
+            };
+
+            let record_batch_stream = stream::iter(record_batches);
+            let record_batch_stream: SendableRecordBatchStream =
+                Box::pin(RecordBatchStreamAdapter::new(schema, record_batch_stream));
+            Ok(record_batch_stream)
+        })
+    }
+
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        Python::with_gil(|py| {
+            let fragments = self.fragments.as_ref(py);
+            let files: Result<Vec<String>, PyErr> = fragments
+                .iter()
+                .map(|fragment| -> Result<String, PyErr> { fragment.extract() })
+                .collect();
+            match t {
+                DisplayFormatType::Default => {
+                    write!(
+                        f,
+                        "DatasetExec: files={:?}, projection={:?}",
+                        files, self.columns,
+                    )
+                }
+            }
+        })
+    }
+
+    fn statistics(&self) -> Statistics {
+        self.projected_statistics.clone()
+    }
+}
diff --git a/src/errors.rs b/src/errors.rs
index 655ed84..29d3e8f 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use core::fmt;
+use std::error::Error;
 
 use datafusion::arrow::error::ArrowError;
 use datafusion::error::DataFusionError as InnerDataFusionError;
@@ -26,6 +27,7 @@ pub enum DataFusionError {
     ExecutionError(InnerDataFusionError),
     ArrowError(ArrowError),
     Common(String),
+    PythonError(PyErr),
 }
 
 impl fmt::Display for DataFusionError {
@@ -33,6 +35,7 @@
         match self {
             DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {:?}", e),
             DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}", e),
+            DataFusionError::PythonError(e) => write!(f, "Python error {:?}", e),
             DataFusionError::Common(e) => write!(f, "{}", e),
         }
     }
@@ -50,8 +53,19 @@ impl From<ArrowError> for DataFusionError {
     fn from(err: ArrowError) -> DataFusionError {
         DataFusionError::ArrowError(err)
     }
 }
 
+impl From<PyErr> for DataFusionError {
+    fn from(err: PyErr) -> DataFusionError {
+        DataFusionError::PythonError(err)
+    }
+}
+
 impl From<DataFusionError> for PyErr {
     fn from(err: DataFusionError) -> PyErr {
-        PyException::new_err(err.to_string())
+        match err {
+            DataFusionError::PythonError(py_err) => py_err,
+            _ => PyException::new_err(err.to_string()),
+        }
     }
 }
+
+impl Error for DataFusionError {}
diff --git a/src/lib.rs b/src/lib.rs
index 25b63e8..c6ab58e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,9 +22,12 @@ use pyo3::prelude::*;
 pub mod catalog;
 mod context;
 mod dataframe;
+mod dataset;
+mod dataset_exec;
 pub mod errors;
 mod expression;
 mod functions;
+mod pyarrow_filter_expression;
 mod udaf;
 mod udf;
 pub mod utils;
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
new file mode 100644
index 0000000..69eafa2
--- /dev/null
+++ b/src/pyarrow_filter_expression.rs
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Converts a Datafusion logical plan expression (Expr) into a PyArrow compute expression
+use pyo3::prelude::*;
+
+use std::convert::TryFrom;
+use std::result::Result;
+
+use datafusion::logical_plan::*;
+use datafusion_common::ScalarValue;
+
+use crate::errors::DataFusionError;
+
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub(crate) struct PyArrowFilterExpression(PyObject);
+
+fn operator_to_py<'py>(
+    operator: &Operator,
+    op: &'py PyModule,
+) -> Result<&'py PyAny, DataFusionError> {
+    let py_op: &PyAny = match operator {
+        Operator::Eq => op.getattr("eq")?,
+        Operator::NotEq => op.getattr("ne")?,
+        Operator::Lt => op.getattr("lt")?,
+        Operator::LtEq => op.getattr("le")?,
+        Operator::Gt => op.getattr("gt")?,
+        Operator::GtEq => op.getattr("ge")?,
+        Operator::And => op.getattr("and_")?,
+        Operator::Or => op.getattr("or_")?,
+        _ => {
+            return Err(DataFusionError::Common(format!(
+                "Unsupported operator {:?}",
+                operator
+            )))
+        }
+    };
+    Ok(py_op)
+}
+
+fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, DataFusionError> {
+    let ret: Result<Vec<PyObject>, DataFusionError> = exprs
+        .iter()
+        .map(|expr| match expr {
+            Expr::Literal(v) => match v {
+                ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
+                ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int16(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int32(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int64(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt8(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt16(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt32(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt64(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Float32(Some(f)) => Ok(f.into_py(py)),
+                ScalarValue::Float64(Some(f)) => Ok(f.into_py(py)),
+                ScalarValue::Utf8(Some(s)) => Ok(s.into_py(py)),
+                _ => Err(DataFusionError::Common(format!(
+                    "PyArrow can't handle ScalarValue: {:?}",
+                    v
+                ))),
+            },
+            _ => Err(DataFusionError::Common(format!(
+                "Only a list of Literals are supported got {:?}",
+                expr
+            ))),
+        })
+        .collect();
+    ret
+}
+
+impl PyArrowFilterExpression {
+    pub fn inner(&self) -> &PyObject {
+        &self.0
+    }
+}
+
+impl TryFrom<&Expr> for PyArrowFilterExpression {
+    type Error = DataFusionError;
+
+    // Converts a Datafusion filter Expr into an expression string that can be evaluated by Python
+    // Note that pyarrow.compute.{field,scalar} are put into Python globals() when evaluated
+    // isin, is_null, and is_valid (~is_null) are methods of pyarrow.dataset.Expression
+    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression
+    fn try_from(expr: &Expr) -> Result<Self, Self::Error> {
+        Python::with_gil(|py| {
+            let pc = Python::import(py, "pyarrow.compute")?;
+            let op_module = Python::import(py, "operator")?;
+            let pc_expr: Result<&PyAny, DataFusionError> = match expr {
+                Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
+                Expr::Literal(v) => match v {
+                    ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
+                    ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
+                    ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
+                    ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
+                    _ => Err(DataFusionError::Common(format!(
+                        "PyArrow can't handle ScalarValue: {:?}",
+                        v
+                    ))),
+                },
+                Expr::BinaryExpr { left, op, right } => {
+                    let operator = operator_to_py(op, op_module)?;
+                    let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
+                    let right = PyArrowFilterExpression::try_from(right.as_ref())?.0;
+                    Ok(operator.call1((left, right))?)
+                }
+                Expr::Not(expr) => {
+                    let operator = op_module.getattr("invert")?;
+                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+                    Ok(operator.call1((py_expr,))?)
+                }
+                Expr::IsNotNull(expr) => {
+                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    Ok(py_expr.call_method0("is_valid")?)
+                }
+                Expr::IsNull(expr) => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    Ok(expr.call_method0("is_null")?)
+                }
+                Expr::Between {
+                    expr,
+                    negated,
+                    low,
+                    high,
+                } => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+                    let low = PyArrowFilterExpression::try_from(low.as_ref())?.0;
+                    let high = PyArrowFilterExpression::try_from(high.as_ref())?.0;
+                    let and = op_module.getattr("and_")?;
+                    let le = op_module.getattr("le")?;
+                    let invert = op_module.getattr("invert")?;
+
+                    // You can't do scalar <= field() <= scalar in PyArrow, no idea why
+                    let ret = and.call1((
+                        le.call1((low, expr.clone_ref(py)))?,
+                        le.call1((expr, high))?,
+                    ))?;
+
+                    Ok(match negated {
+                        true => invert.call1((ret,))?,
+                        false => ret,
+                    })
+                }
+                Expr::InList {
+                    expr,
+                    list,
+                    negated,
+                } => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    let scalars = extract_scalar_list(list, py)?;
+                    let ret = expr.call_method1("isin", (scalars,))?;
+                    let invert = op_module.getattr("invert")?;
+
+                    Ok(match negated {
+                        true => invert.call1((ret,))?,
+                        false => ret,
+                    })
+                }
+                _ => Err(DataFusionError::Common(format!(
+                    "Unsupported Datafusion expression {:?}",
+                    expr
+                ))),
+            };
+            Ok(PyArrowFilterExpression(pc_expr?.into()))
+        })
+    }
+}

From 128f47830640cf302ca27a02f2643d1e771463b9 Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Wed, 20 Jul 2022 15:23:29 -0400
Subject: [PATCH 2/3] Add dataset filter test.

---
 datafusion/tests/test_context.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index a32253d..8f9f45a 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -89,3 +89,19 @@ def test_register_dataset(ctx):
 
     assert result[0].column(0) == pa.array([5, 7, 9])
     assert result[0].column(1) == pa.array([-3, -3, -3])
+
+def test_dataset_filter(ctx):
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5").collect()
+
+    assert result[0].column(0) == pa.array([9])
+    assert result[0].column(1) == pa.array([-3])

From 007e4d2c91b2c05e11c52d87bdb0879ad6621179 Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Thu, 21 Jul 2022 11:18:36 -0400
Subject: [PATCH 3/3] Change match on booleans to if else.

---
 src/dataset.rs                   | 11 ++++++-----
 src/pyarrow_filter_expression.rs | 12 +++---------
 2 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/dataset.rs b/src/dataset.rs
index c8be42d..f03bc7b 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -48,13 +48,14 @@ impl Dataset {
         // Ensure that we were passed an instance of pyarrow.dataset.Dataset
         let ds = PyModule::import(py, "pyarrow.dataset")?;
         let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
-        match dataset.is_instance(ds_type)? {
-            true => Ok(Dataset {
+        if dataset.is_instance(ds_type)? {
+            Ok(Dataset {
                 dataset: dataset.into(),
-            }),
-            false => Err(PyValueError::new_err(
+            })
+        } else {
+            Err(PyValueError::new_err(
                 "dataset argument must be a pyarrow.dataset.Dataset object",
-            )),
+            ))
         }
     }
 }
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index 69eafa2..3807553 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -157,16 +157,13 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
                     let le = op_module.getattr("le")?;
                     let invert = op_module.getattr("invert")?;
 
-                    // You can't do scalar <= field() <= scalar in PyArrow, no idea why
+                    // scalar <= field() returns a boolean expression so we need to use and to combine these
                     let ret = and.call1((
                         le.call1((low, expr.clone_ref(py)))?,
                         le.call1((expr, high))?,
                     ))?;
 
-                    Ok(match negated {
-                        true => invert.call1((ret,))?,
-                        false => ret,
-                    })
+                    Ok(if *negated { invert.call1((ret,))? } else { ret })
                 }
                 Expr::InList {
                     expr,
                     list,
                     negated,
                 } => {
                     let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                         .0
                         .into_ref(py);
                     let scalars = extract_scalar_list(list, py)?;
                     let ret = expr.call_method1("isin", (scalars,))?;
                     let invert = op_module.getattr("invert")?;
 
-                    Ok(match negated {
-                        true => invert.call1((ret,))?,
-                        false => ret,
-                    })
+                    Ok(if *negated { invert.call1((ret,))? } else { ret })
                 }
                 _ => Err(DataFusionError::Common(format!(
                     "Unsupported Datafusion expression {:?}",