From 34131f4e835f16530dc491fbc301258a476bce18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 21 Sep 2021 19:25:15 +0200 Subject: [PATCH 01/21] Squash --- .../core/src/serde/logical_plan/from_proto.rs | 21 +- datafusion/Cargo.toml | 2 + datafusion/src/execution/context.rs | 9 + datafusion/src/lib.rs | 4 + datafusion/src/logical_plan/builder.rs | 25 +- datafusion/src/logical_plan/expr.rs | 12 +- .../src/optimizer/projection_push_down.rs | 4 +- datafusion/src/pyarrow.rs | 67 +++ datafusion/src/sql/planner.rs | 3 +- python/Cargo.lock | 83 ++-- python/Cargo.toml | 10 +- python/datafusion/__init__.py | 18 + python/{ => datafusion}/tests/__init__.py | 0 python/{ => datafusion}/tests/generic.py | 0 python/datafusion/tests/test_catalog.py | 72 +++ .../tests/test_dataframe.py} | 29 +- .../tests/test_math_functions.py | 1 + .../{ => datafusion}/tests/test_pa_types.py | 0 python/{ => datafusion}/tests/test_sql.py | 28 +- .../tests/test_string_functions.py | 1 + python/{ => datafusion}/tests/test_udaf.py | 1 + python/pyproject.toml | 3 + python/src/catalog.rs | 123 +++++ python/src/context.rs | 151 +++--- python/src/dataframe.rs | 196 +++----- python/src/errors.rs | 19 +- python/src/expression.rs | 203 ++++---- python/src/functions.rs | 442 +++++++++--------- python/src/lib.rs | 17 +- python/src/scalar.rs | 36 -- python/src/to_py.rs | 75 --- python/src/to_rust.rs | 122 ----- python/src/types.rs | 65 --- python/src/udaf.rs | 42 +- python/src/udf.rs | 18 +- 35 files changed, 901 insertions(+), 1001 deletions(-) create mode 100644 datafusion/src/pyarrow.rs create mode 100644 python/datafusion/__init__.py rename python/{ => datafusion}/tests/__init__.py (100%) rename python/{ => datafusion}/tests/generic.py (100%) create mode 100644 python/datafusion/tests/test_catalog.py rename python/{tests/test_df.py => datafusion/tests/test_dataframe.py} (82%) rename python/{ => datafusion}/tests/test_math_functions.py (99%) rename python/{ => datafusion}/tests/test_pa_types.py (100%) rename python/{ => datafusion}/tests/test_sql.py (88%) rename python/{ => datafusion}/tests/test_string_functions.py (99%) rename python/{ => datafusion}/tests/test_udaf.py (99%) create mode 100644 python/src/catalog.rs delete mode 100644 python/src/scalar.rs delete mode 100644 python/src/to_py.rs delete mode 100644 python/src/to_rust.rs delete mode 100644 python/src/types.rs diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 259fcb3482a7..66bec782e609 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -106,14 +106,15 @@ impl TryInto for &protobuf::LogicalPlanNode { } LogicalPlanType::Selection(selection) => { let input: LogicalPlan = convert_box_required!(selection.input)?; + let expr: Expr = selection + .expr + .as_ref() + .ok_or_else(|| { + BallistaError::General("expression required".to_string()) + })? + .try_into()?; LogicalPlanBuilder::from(input) - .filter( - selection - .expr - .as_ref() - .expect("expression required") - .try_into()?, - )? + .filter(expr)? .build() .map_err(|e| e.into()) } @@ -123,7 +124,7 @@ impl TryInto for &protobuf::LogicalPlanNode { .window_expr .iter() .map(|expr| expr.try_into()) - .collect::, _>>()?; + .collect::, _>>()?; LogicalPlanBuilder::from(input) .window(window_expr)? 
.build() @@ -135,12 +136,12 @@ impl TryInto for &protobuf::LogicalPlanNode { .group_expr .iter() .map(|expr| expr.try_into()) - .collect::, _>>()?; + .collect::, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() .map(|expr| expr.try_into()) - .collect::, _>>()?; + .collect::, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? .build() diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 793262a031bd..8aac711facc7 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -43,6 +43,7 @@ simd = ["arrow/simd"] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] +pyarrow = ["pyo3", "arrow/pyarrow"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format @@ -75,6 +76,7 @@ smallvec = { version = "1.6", features = ["union"] } rand = "0.8" avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } +pyo3 = { version = "0.14", optional = true } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 9be5038f47c9..23ebba28d92e 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -329,6 +329,14 @@ impl ExecutionContext { ))) } + /// Creates an empty DataFrame. + pub fn read_empty(&self) -> Result> { + Ok(Arc::new(DataFrameImpl::new( + self.state.clone(), + &LogicalPlanBuilder::empty(true).build()?, + ))) + } + /// Creates a DataFrame for reading a CSV data source. pub async fn read_csv( &mut self, @@ -565,6 +573,7 @@ impl ExecutionContext { /// register_table function. /// /// Returns an error if no table has been registered with the provided reference. + /// NOTE(kszucs): perhaps it should be called dataframe() instead? pub fn table<'a>( &self, table_ref: impl Into>, diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 2159864d10fd..4f4cd664fd41 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -232,6 +232,10 @@ pub use arrow; pub use parquet; pub(crate) mod field_util; + +#[cfg(feature = "pyarrow")] +mod pyarrow; + #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 693bf78fbe0e..dcbddca89cd7 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -426,14 +426,14 @@ impl LogicalPlanBuilder { Ok(plan) } /// Apply a projection without alias. 
- pub fn project(&self, expr: impl IntoIterator) -> Result { + pub fn project(&self, expr: impl IntoIterator>) -> Result { self.project_with_alias(expr, None) } /// Apply a projection with alias pub fn project_with_alias( &self, - expr: impl IntoIterator, + expr: impl IntoIterator>, alias: Option, ) -> Result { Ok(Self::from(project_with_alias( @@ -444,8 +444,8 @@ impl LogicalPlanBuilder { } /// Apply a filter - pub fn filter(&self, expr: Expr) -> Result { - let expr = normalize_col(expr, &self.plan)?; + pub fn filter(&self, expr: impl Into) -> Result { + let expr = normalize_col(expr.into(), &self.plan)?; Ok(Self::from(LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -461,7 +461,7 @@ impl LogicalPlanBuilder { } /// Apply a sort - pub fn sort(&self, exprs: impl IntoIterator) -> Result { + pub fn sort(&self, exprs: impl IntoIterator>) -> Result { Ok(Self::from(LogicalPlan::Sort { expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), @@ -629,8 +629,11 @@ impl LogicalPlanBuilder { } /// Apply a window functions to extend the schema - pub fn window(&self, window_expr: impl IntoIterator) -> Result { - let window_expr = window_expr.into_iter().collect::>(); + pub fn window( + &self, + window_expr: impl IntoIterator>, + ) -> Result { + let window_expr = normalize_cols(window_expr, &self.plan)?; let all_expr = window_expr.iter(); validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; let mut window_fields: Vec = @@ -648,8 +651,8 @@ impl LogicalPlanBuilder { /// value of the `group_expr`; pub fn aggregate( &self, - group_expr: impl IntoIterator, - aggr_expr: impl IntoIterator, + group_expr: impl IntoIterator>, + aggr_expr: impl IntoIterator>, ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; @@ -796,13 +799,13 @@ pub fn union_with_alias( /// * An invalid expression is used (e.g. a `sort` expression) pub fn project_with_alias( plan: LogicalPlan, - expr: impl IntoIterator, + expr: impl IntoIterator>, alias: Option, ) -> Result { let input_schema = plan.schema(); let mut projected_expr = vec![]; for e in expr { - match e { + match e.into() { Expr::Wildcard => { projected_expr.extend(expand_wildcard(input_schema, &plan)?) } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 499a8c720dba..19e6fe36c7d6 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -540,6 +540,9 @@ impl Expr { /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + // TODO(kszucs): most of the operations do not validate the type correctness + // like all of the binary expressions below. Perhaps Expr should track the + // type of the expression? 
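+        // Note: casting an expression to its current type is a no-op and
+        // simply returns the expression unchanged (see below).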
let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self) @@ -1305,10 +1308,13 @@ fn normalize_col_with_schemas( /// Recursively normalize all Column expressions in a list of expression trees #[inline] pub fn normalize_cols( - exprs: impl IntoIterator, + exprs: impl IntoIterator>, plan: &LogicalPlan, ) -> Result> { - exprs.into_iter().map(|e| normalize_col(e, plan)).collect() + exprs + .into_iter() + .map(|e| normalize_col(e.into(), plan)) + .collect() } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -1544,6 +1550,8 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many +// varying arity functions /// Create an convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 2d66c5321acf..4fabc4f08f09 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -475,7 +475,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![], vec![max(col("b"))])? + .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ @@ -508,7 +508,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("c"))? - .aggregate(vec![], vec![max(col("b"))])? + .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs new file mode 100644 index 000000000000..da05d63d8c2c --- /dev/null +++ b/datafusion/src/pyarrow.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
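+
+//! Conversions between DataFusion types and pyarrow objects, enabled by the
+//! optional `pyarrow` feature.
+//!
+//! A `ScalarValue` is read from Python by building a one-element pyarrow
+//! array of the value's declared type and extracting it back through the
+//! Arrow C data interface; the reverse direction (`to_pyarrow`) is not
+//! implemented yet. `DataFusionError` is raised in Python as a plain
+//! `Exception`.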
+ +use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::prelude::*; +use pyo3::types::PyList; +use pyo3::PyNativeType; + +use crate::arrow::array::ArrayData; +use crate::arrow::pyarrow::PyArrowConvert; +use crate::error::DataFusionError; +use crate::scalar::ScalarValue; + +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +impl PyArrowConvert for ScalarValue { + fn from_pyarrow(value: &PyAny) -> PyResult { + let py = value.py(); + let typ = value.getattr("type")?; + let val = value.call_method0("as_py")?; + + // construct pyarrow array from the python value and pyarrow type + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, &[val]); + let array = factory.call1((args, typ))?; + + // convert the pyarrow array to rust array using C data interface + let array = array.extract::()?; + let scalar = ScalarValue::try_from_array(&array.into(), 0)?; + + Ok(scalar) + } + + fn to_pyarrow(&self, _py: Python) -> PyResult { + Err(PyNotImplementedError::new_err("Not implemented")) + } +} + +impl<'source> FromPyObject<'source> for ScalarValue { + fn extract(value: &'source PyAny) -> PyResult { + Self::from_pyarrow(value) + } +} + +impl<'a> IntoPy for ScalarValue { + fn into_py(self, py: Python) -> PyObject { + self.to_pyarrow(py).unwrap() + } +} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 60d2da8be2c7..1653cb5d5ac5 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -18,6 +18,7 @@ //! SQL Query Planner (produces logical plan from SQL AST) use std::collections::HashSet; +use std::iter; use std::str::FromStr; use std::sync::Arc; use std::{convert::TryInto, vec}; @@ -822,7 +823,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = if select.distinct { return LogicalPlanBuilder::from(plan) - .aggregate(select_exprs_post_aggr, vec![])? + .aggregate(select_exprs_post_aggr, iter::empty::())? 
.build(); } else { plan diff --git a/python/Cargo.lock b/python/Cargo.lock index 6ae27021e61c..fa84a54ced7b 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ "getrandom 0.2.3", "once_cell", @@ -72,6 +72,7 @@ dependencies = [ "lexical-core", "multiversion", "num", + "pyo3", "rand 0.8.4", "regex", "serde", @@ -121,9 +122,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcd555c66291d5f836dbb6883b48660ece810fe25a31f3bdfb911945dff2691f" +checksum = "2607a74355ce2e252d0c483b2d8a348e1bba36036e786ccc2dcd777213c86ffd" dependencies = [ "arrayref", "arrayvec", @@ -165,9 +166,9 @@ dependencies = [ [[package]] name = "bstr" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" +checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" dependencies = [ "lazy_static", "memchr", @@ -183,9 +184,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "cc" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26a6ce4b6a484fa3edb70f7efa6fc430fd2b87285fe8b84304fd0936faa0dc0" +checksum = "79c2681d6594606957bbb8631c4b90a7fcaaa72cdb714743a437b156d6a7eedd" dependencies = [ "jobserver", ] @@ -296,6 +297,7 @@ dependencies = [ "parquet", "paste 1.0.5", "pin-project-lite", + "pyo3", "rand 0.8.4", "regex", "sha2", @@ -311,7 +313,6 @@ name = "datafusion-python" version = "0.3.0" dependencies = [ "datafusion", - "libc", "pyo3", "rand 0.7.3", "tokio", @@ -340,9 +341,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80edafed416a46fb378521624fab1cfa2eb514784fd8921adbe8a8d8321da811" +checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" dependencies = [ "cfg-if", "crc32fast", @@ -544,9 +545,9 @@ dependencies = [ [[package]] name = "instant" -version = "0.1.10" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee0328b1209d157ef001c94dd85b4f8f64139adb0eac2659f4b08382b2f474d" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" dependencies = [ "cfg-if", ] @@ -644,9 +645,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.101" +version = "0.2.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" +checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" [[package]] name = "lock_api" @@ -943,9 +944,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "ppv-lite86" -version = "0.2.10" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +checksum = 
"ed0cfbc8191465bed66e1718596ee0b0b35d5ee1f41c5df2189d0fe8bde535ba" [[package]] name = "proc-macro-hack" @@ -961,9 +962,9 @@ checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" [[package]] name = "proc-macro2" -version = "1.0.28" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" +checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70" dependencies = [ "unicode-xid", ] @@ -1018,9 +1019,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05" dependencies = [ "proc-macro2", ] @@ -1169,9 +1170,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.67" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950" +checksum = "0f690853975602e1bfe1ccbf50504d67174e3bcf340f23b5ea9992e0587a52d8" dependencies = [ "indexmap", "itoa", @@ -1181,9 +1182,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.6" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" +checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" dependencies = [ "block-buffer", "cfg-if", @@ -1194,15 +1195,15 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" [[package]] name = "smallvec" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" [[package]] name = "snap" @@ -1251,9 +1252,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.76" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6f107db402c2c2055242dbf4d2af0e69197202e9faacbef9571bbe47f5a1b84" +checksum = "d010a1623fbd906d51d650a9916aaefc05ffa0e4053ff7fe601167f3e715d194" dependencies = [ "proc-macro2", "quote", @@ -1262,18 +1263,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", @@ -1314,9 +1315,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.11.0" +version = "1.12.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4efe6fc2395938c8155973d7be49fe8d03a843726e285e100a8a383cc0154ce" +checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" dependencies = [ "autocfg", "num_cpus", @@ -1326,9 +1327,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.3.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" dependencies = [ "proc-macro2", "quote", @@ -1360,9 +1361,9 @@ checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b" [[package]] name = "unicode-width" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" +checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" [[package]] name = "unicode-xid" diff --git a/python/Cargo.toml b/python/Cargo.toml index c0645a152078..d3f13b498916 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -28,17 +28,19 @@ edition = "2021" rust-version = "1.56" [dependencies] -libc = "0.2" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" -pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { path = "../datafusion", version = "5.1.0" } +pyo3 = { version = "0.14", features = ["extension-module", "abi3", "abi3-py36"] } +datafusion = { path = "../datafusion", version = "5.1.0", features = ["pyarrow"] } uuid = { version = "0.8", features = ["v4"] } [lib] -name = "datafusion" +name = "internals" crate-type = ["cdylib"] +[package.metadata.maturin] +name = "datafusion.internals" + [profile.release] lto = true codegen-units = 1 diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py new file mode 100644 index 000000000000..761381cb26b3 --- /dev/null +++ b/python/datafusion/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from .internals import * # noqa diff --git a/python/tests/__init__.py b/python/datafusion/tests/__init__.py similarity index 100% rename from python/tests/__init__.py rename to python/datafusion/tests/__init__.py diff --git a/python/tests/generic.py b/python/datafusion/tests/generic.py similarity index 100% rename from python/tests/generic.py rename to python/datafusion/tests/generic.py diff --git a/python/datafusion/tests/test_catalog.py b/python/datafusion/tests/test_catalog.py new file mode 100644 index 000000000000..5ae81d5521e1 --- /dev/null +++ b/python/datafusion/tests/test_catalog.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow as pa +import pytest + +from datafusion import ExecutionContext + + +@pytest.fixture +def ctx(): + return ExecutionContext() + + +@pytest.fixture +def database(ctx, tmp_path): + path = tmp_path / "test.csv" + + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + pa.csv.write_csv(table, path) + + ctx.register_csv("csv", path) + ctx.register_csv("csv1", str(path)) + ctx.register_csv( + "csv2", + path, + has_header=True, + delimiter=",", + schema_infer_max_records=10, + ) + + +def test_basic(ctx, database): + with pytest.raises(KeyError): + ctx.catalog("non-existent") + + default = ctx.catalog() + assert default.names() == ["public"] + + database = default.database("public") + assert database.names() == {"csv1", "csv", "csv2"} + + table = database.table("csv") + assert table.kind == "physical" + assert table.schema == pa.schema( + [ + pa.field("int", pa.int64(), nullable=False), + pa.field("str", pa.string(), nullable=False), + pa.field("float", pa.float64(), nullable=False), + ] + ) diff --git a/python/tests/test_df.py b/python/datafusion/tests/test_dataframe.py similarity index 82% rename from python/tests/test_df.py rename to python/datafusion/tests/test_dataframe.py index 9bbdb5a30077..236cd7c03ae3 100644 --- a/python/tests/test_df.py +++ b/python/datafusion/tests/test_dataframe.py @@ -17,7 +17,8 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext + +from datafusion import DataFrame, ExecutionContext from datafusion import functions as f @@ -61,7 +62,7 @@ def test_filter(df): def test_sort(df): - df = df.sort([f.col("b").sort(ascending=False)]) + df = df.sort(f.col("b").sort(ascending=False)) table = pa.Table.from_batches(df.collect()) expected = {"a": [3, 2, 1], "b": [6, 5, 4]} @@ -109,9 +110,29 @@ def test_join(): ) df1 = ctx.create_dataframe([[batch]]) - df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") - df = df.sort([f.col("a").sort(ascending=True)]) + df = df.join(df1, on="a", how="inner") + df = df.sort(f.col("a").sort(ascending=True)) table = 
pa.Table.from_batches(df.collect()) expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} assert table.to_pydict() == expected + + +def test_get_dataframe(tmp_path): + ctx = ExecutionContext() + + path = tmp_path / "test.csv" + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + pa.csv.write_csv(table, path) + + ctx.register_csv("csv", path) + + df = ctx.table("csv") + assert isinstance(df, DataFrame) diff --git a/python/tests/test_math_functions.py b/python/datafusion/tests/test_math_functions.py similarity index 99% rename from python/tests/test_math_functions.py rename to python/datafusion/tests/test_math_functions.py index 98656b8c4f42..4e473c3de16a 100644 --- a/python/tests/test_math_functions.py +++ b/python/datafusion/tests/test_math_functions.py @@ -18,6 +18,7 @@ import numpy as np import pyarrow as pa import pytest + from datafusion import ExecutionContext from datafusion import functions as f diff --git a/python/tests/test_pa_types.py b/python/datafusion/tests/test_pa_types.py similarity index 100% rename from python/tests/test_pa_types.py rename to python/datafusion/tests/test_pa_types.py diff --git a/python/tests/test_sql.py b/python/datafusion/tests/test_sql.py similarity index 88% rename from python/tests/test_sql.py rename to python/datafusion/tests/test_sql.py index f309a85104b2..e9fb49c0e33d 100644 --- a/python/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -68,9 +68,9 @@ def test_register_csv(ctx, tmp_path): assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"} for table in ["csv", "csv1", "csv2"]: - result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect() + result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {f"COUNT({table}.int)": [4]} + assert result.to_pydict() == {"cnt": [4]} result = ctx.sql("SELECT * FROM csv3").collect() result = pa.Table.from_batches(result) @@ -87,9 +87,9 @@ def test_register_parquet(ctx, tmp_path): ctx.register_parquet("t", path) assert ctx.tables() == {"t"} - result = ctx.sql("SELECT COUNT(a) FROM t").collect() + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {"COUNT(t.a)": [100]} + assert result.to_pydict() == {"cnt": [100]} def test_execute(ctx, tmp_path): @@ -102,21 +102,21 @@ def test_execute(ctx, tmp_path): assert ctx.tables() == {"t"} # count - result = ctx.sql("SELECT COUNT(a) FROM t").collect() + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() expected = pa.array([7], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] + expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] assert result == expected # where expected = pa.array([2], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] - result = ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect() + expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a > 10").collect() assert result == expected # group by results = ctx.sql( - "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)" + "SELECT CAST(a as int) AS a, COUNT(a) AS cnt FROM t GROUP BY a" ).collect() # group by returns batches @@ -124,8 +124,8 @@ def test_execute(ctx, tmp_path): result_values = [] for result in results: pydict = result.to_pydict() - result_keys.extend(pydict["CAST(t.a AS Int32)"]) - 
result_values.extend(pydict["COUNT(t.a)"]) + result_keys.extend(pydict["a"]) + result_values.extend(pydict["cnt"]) result_keys, result_values = ( list(t) for t in zip(*sorted(zip(result_keys, result_values))) @@ -136,14 +136,12 @@ def test_execute(ctx, tmp_path): # order by result = ctx.sql( - "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2" + "SELECT a, CAST(a AS int) AS a_int FROM t ORDER BY a DESC LIMIT 2" ).collect() expected_a = pa.array([50.0219, 50.0152], pa.float64()) expected_cast = pa.array([50, 50], pa.int32()) expected = [ - pa.RecordBatch.from_arrays( - [expected_a, expected_cast], ["a", "CAST(t.a AS Int32)"] - ) + pa.RecordBatch.from_arrays([expected_a, expected_cast], ["a", "a_int"]) ] np.testing.assert_equal(expected[0].column(1), expected[0].column(1)) diff --git a/python/tests/test_string_functions.py b/python/datafusion/tests/test_string_functions.py similarity index 99% rename from python/tests/test_string_functions.py rename to python/datafusion/tests/test_string_functions.py index 965f08707285..3d6c380c55a6 100644 --- a/python/tests/test_string_functions.py +++ b/python/datafusion/tests/test_string_functions.py @@ -17,6 +17,7 @@ import pyarrow as pa import pytest + from datafusion import ExecutionContext from datafusion import functions as f diff --git a/python/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py similarity index 99% rename from python/tests/test_udaf.py rename to python/datafusion/tests/test_udaf.py index 7ff622330ccc..0eb93f6c9876 100644 --- a/python/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -20,6 +20,7 @@ import pyarrow as pa import pyarrow.compute as pc import pytest + from datafusion import ExecutionContext from datafusion import functions as f diff --git a/python/pyproject.toml b/python/pyproject.toml index f366aa94ddf4..c6ee363497d7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,3 +50,6 @@ dependencies = [ [project.urls] documentation = "https://arrow.apache.org/datafusion/python" repository = "https://github.com/apache/arrow-datafusion" + +[tool.isort] +profile = "black" diff --git a/python/src/catalog.rs b/python/src/catalog.rs new file mode 100644 index 000000000000..a2d382fbe7b7 --- /dev/null +++ b/python/src/catalog.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
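+
+//! Python wrappers around DataFusion's catalog API: `PyCatalog` wraps a
+//! `CatalogProvider`, `PyDatabase` a `SchemaProvider`, and `PyTable` a
+//! `TableProvider`. Lookups by name raise `KeyError` on the Python side
+//! when the requested entity does not exist.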
+ +use std::collections::HashSet; +use std::sync::Arc; + +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +// use pyo3::{PyIterProtocol, PyMappingProtocol}; + +use datafusion::{ + arrow::pyarrow::PyArrowConvert, + catalog::{catalog::CatalogProvider, schema::SchemaProvider}, + datasource::{TableProvider, TableType}, +}; + +#[pyclass(name = "Catalog", subclass)] +pub(crate) struct PyCatalog { + catalog: Arc, +} + +#[pyclass(name = "Database", subclass)] +pub(crate) struct PyDatabase { + database: Arc, +} + +#[pyclass(name = "Table", subclass)] +pub(crate) struct PyTable { + table: Arc, +} + +impl PyCatalog { + pub fn new(catalog: Arc) -> Self { + Self { catalog } + } +} + +impl PyDatabase { + pub fn new(database: Arc) -> Self { + Self { database } + } +} + +impl PyTable { + pub fn new(table: Arc) -> Self { + Self { table } + } +} + +#[pymethods] +impl PyCatalog { + fn names(&self) -> Vec { + self.catalog.schema_names() + } + + fn database(&self, name: &str) -> PyResult { + match self.catalog.schema(name) { + Some(database) => Ok(PyDatabase::new(database)), + None => Err(PyKeyError::new_err(format!( + "Database with name {} doesn't exist.", + name + ))), + } + } +} + +#[pymethods] +impl PyDatabase { + fn names(&self) -> HashSet { + self.database.table_names().into_iter().collect() + } + + fn table(&self, name: &str) -> PyResult { + match self.database.table(name) { + Some(table) => Ok(PyTable::new(table)), + None => Err(PyKeyError::new_err(format!( + "Table with name {} doesn't exist.", + name + ))), + } + } + + // register_table + // deregister_table +} + +#[pymethods] +impl PyTable { + /// Get a reference to the schema for this table + #[getter] + fn schema(&self, py: Python) -> PyResult { + self.table.schema().to_pyarrow(py) + } + + /// Get the type of this table for metadata/catalog purposes. + #[getter] + fn kind(&self) -> &str { + match self.table.table_type() { + TableType::Base => "physical", + TableType::View => "view", + TableType::Temporary => "temporary", + } + } + + // fn scan + // fn statistics + // fn has_exact_statistics + // fn supports_filter_pushdown +} diff --git a/python/src/context.rs b/python/src/context.rs index b813f27a73c9..b8fb0dbd118d 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -22,71 +22,53 @@ use uuid::Uuid; use tokio::runtime::Runtime; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; +use datafusion::arrow::datatypes::{DataType, Schema}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext as _ExecutionContext; +use datafusion::execution::context::ExecutionContext; use datafusion::prelude::CsvReadOptions; -use crate::dataframe; -use crate::errors; +use crate::{dataframe, errors, functions}; use crate::functions::{self, PyVolatility}; -use crate::to_rust; use crate::types::PyDataType; +use crate::catalog::PyCatalog; +use crate::dataframe::PyDataFrame; +use crate::errors::DataFusionError; -/// `ExecutionContext` is able to plan and execute DataFusion plans. + +/// `PyExecutionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a /// multi-threaded execution engine to perform the execution. 
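+/// It is exposed to Python as `datafusion.ExecutionContext`.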
-#[pyclass(unsendable)] -pub(crate) struct ExecutionContext { - ctx: _ExecutionContext, +#[pyclass(name = "ExecutionContext", unsendable)] +pub(crate) struct PyExecutionContext { + ctx: ExecutionContext, } #[pymethods] -impl ExecutionContext { +impl PyExecutionContext { + // TODO(kszucs): should expose the configuration options as keyword arguments #[new] fn new() -> Self { - ExecutionContext { - ctx: _ExecutionContext::new(), + PyExecutionContext { + ctx: ExecutionContext::new(), } } - /// Returns a DataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str, py: Python) -> PyResult { - let rt = Runtime::new().unwrap(); - let df = py.allow_threads(|| { - rt.block_on(async { - self.ctx - .sql(query) - .await - .map_err(|e| -> errors::DataFusionError { e.into() }) - }) - })?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - df.to_logical_plan(), - )) + /// Returns a PyDataFrame whose plan corresponds to the SQL statement. + fn sql(&mut self, query: &str) -> PyResult { + let df = self.ctx.sql(query).map_err(DataFusionError::from)?; + Ok(PyDataFrame::new(df)) } fn create_dataframe( &mut self, - partitions: Vec>, - py: Python, - ) -> PyResult { - let partitions: Vec> = partitions - .iter() - .map(|batches| { - batches - .iter() - .map(|batch| to_rust::to_rust_batch(batch.as_ref(py))) - .collect() - }) - .collect::>()?; - - let table = - errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?; + partitions: Vec>, + ) -> PyResult { + let table = MemTable::try_new(partitions[0][0].schema(), partitions) + .map_err(DataFusionError::from)?; // generate a random (unique) name for this table // table name cannot start with numeric digit @@ -95,43 +77,32 @@ impl ExecutionContext { .to_simple() .encode_lower(&mut Uuid::encode_buffer()); - errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - errors::wrap(self.ctx.table(&*name))?.to_logical_plan(), - )) + self.ctx + .register_table(&*name, Arc::new(table)) + .map_err(DataFusionError::from)?; + let table = self.ctx.table(&*name).map_err(DataFusionError::from)?; + + let df = PyDataFrame::new(table); + Ok(df) } fn register_record_batches( &mut self, name: &str, - partitions: Vec>, - py: Python, + partitions: Vec>, ) -> PyResult<()> { - let partitions: Vec> = partitions - .iter() - .map(|batches| { - batches - .iter() - .map(|batch| to_rust::to_rust_batch(batch.as_ref(py))) - .collect() - }) - .collect::>()?; - - let table = - errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?; - - errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?; + let schema = partitions[0][0].schema(); + let table = MemTable::try_new(schema, partitions)?; + self.ctx + .register_table(name, Arc::new(table)) + .map_err(DataFusionError::from)?; Ok(()) } - fn register_parquet(&mut self, name: &str, path: &str, py: Python) -> PyResult<()> { - let rt = Runtime::new().unwrap(); - py.allow_threads(|| { - rt.block_on(async { - errors::wrap(self.ctx.register_parquet(name, path).await) - }) - })?; + fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> { + self.ctx + .register_parquet(name, path) + .map_err(DataFusionError::from)?; Ok(()) } @@ -146,7 +117,7 @@ impl ExecutionContext { &mut self, name: &str, path: PathBuf, - schema: Option<&PyAny>, + schema: Option, has_header: bool, delimiter: &str, schema_infer_max_records: usize, @@ -156,10 +127,6 @@ impl ExecutionContext { let path = path .to_str() 
.ok_or(PyValueError::new_err("Unable to convert path to a string"))?; - let schema = match schema { - Some(s) => Some(to_rust::to_rust_schema(s)?), - None => None, - }; let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { return Err(PyValueError::new_err( @@ -174,12 +141,9 @@ impl ExecutionContext { .file_extension(file_extension); options.schema = schema.as_ref(); - let rt = Runtime::new().unwrap(); - py.allow_threads(|| { - rt.block_on(async { - errors::wrap(self.ctx.register_csv(name, path, options).await) - }) - })?; + self.ctx + .register_csv(name, path, options) + .map_err(DataFusionError::from)?; Ok(()) } @@ -190,14 +154,33 @@ impl ExecutionContext { args_types: Vec, return_type: PyDataType, volatility: PyVolatility, - ) { + ) -> PyResult<()> { let function = - functions::create_udf(func, args_types, return_type, volatility, name); - + functions::create_udf(func, args_types, return_type, volatility, name)?; self.ctx.register_udf(function.function); + Ok(()) + } + + #[args(name = "\"datafusion\"")] + fn catalog(&self, name: &str) -> PyResult { + match self.ctx.catalog(name) { + Some(catalog) => Ok(PyCatalog::new(catalog)), + None => Err(PyKeyError::new_err(format!( + "Catalog with name {} doesn't exist.", + &name + ))), + } } fn tables(&self) -> HashSet { self.ctx.tables().unwrap() } + + fn table(&self, name: &str) -> PyResult { + Ok(PyDataFrame::new(self.ctx.table(name)?)) + } + + fn empty_table(&self) -> PyResult { + Ok(PyDataFrame::new(self.ctx.read_empty()?)) + } } diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 48da234fc23e..236abfdf9909 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -15,174 +15,94 @@ // specific language governing permissions and limitations // under the License. -use std::sync::{Arc, Mutex}; +use std::sync::Arc; -use logical_plan::LogicalPlan; -use pyo3::{prelude::*, types::PyTuple}; +use pyo3::prelude::*; use tokio::runtime::Runtime; -use datafusion::execution::context::ExecutionContext as _ExecutionContext; -use datafusion::logical_plan::{JoinType, LogicalPlanBuilder}; -use datafusion::physical_plan::collect; -use datafusion::{execution::context::ExecutionContextState, logical_plan}; - -use crate::{errors, to_py}; -use crate::{errors::DataFusionError, expression}; +use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::arrow::util::pretty; +use datafusion::dataframe::DataFrame; +use datafusion::logical_plan::JoinType; + +use crate::{errors::DataFusionError, expression::PyExpr}; -/// A DataFrame is a representation of a logical plan and an API to compose statements. +/// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. -#[pyclass] -pub(crate) struct DataFrame { - ctx_state: Arc>, - plan: LogicalPlan, +#[pyclass(name = "DataFrame")] +#[derive(Clone)] +pub(crate) struct PyDataFrame { + df: Arc, } -impl DataFrame { - /// creates a new DataFrame - pub fn new(ctx_state: Arc>, plan: LogicalPlan) -> Self { - Self { ctx_state, plan } +impl PyDataFrame { + /// creates a new PyDataFrame + pub fn new(df: Arc) -> Self { + Self { df } } } #[pymethods] -impl DataFrame { - /// Select `expressions` from the existing DataFrame. 
- #[args(args = "*")] - fn select(&self, args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = - errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) +impl PyDataFrame { + /// Returns the schema from the logical plan + fn schema(&self) -> Schema { + self.df.schema().into() } - /// Filter according to the `predicate` expression - fn filter(&self, predicate: expression::Expression) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.filter(predicate.expr))?; - let plan = errors::wrap(builder.build())?; + #[args(args = "*")] + fn select(&self, args: Vec) -> PyResult { + let expr = args.into_iter().map(|e| e.into()).collect(); + let df = self.df.select(expr)?; + Ok(Self::new(df)) + } - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + fn filter(&self, predicate: PyExpr) -> PyResult { + let df = self.df.filter(predicate.into())?; + Ok(Self::new(df)) } - /// Aggregates using expressions - fn aggregate( - &self, - group_by: Vec, - aggs: Vec, - ) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.aggregate( - group_by.into_iter().map(|e| e.expr), - aggs.into_iter().map(|e| e.expr), - ))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyResult { + let group_by = group_by.into_iter().map(|e| e.into()).collect(); + let aggs = aggs.into_iter().map(|e| e.into()).collect(); + let df = self.df.aggregate(group_by, aggs)?; + Ok(Self::new(df)) } - /// Sort by specified sorting expressions - fn sort(&self, exprs: Vec) -> PyResult { - let exprs = exprs.into_iter().map(|e| e.expr); - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.sort(exprs))?; - let plan = errors::wrap(builder.build())?; - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + #[args(exprs = "*")] + fn sort(&self, exprs: Vec) -> PyResult { + let exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.sort(exprs)?; + Ok(Self::new(df)) } - /// Limits the plan to return at most `count` rows fn limit(&self, count: usize) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.limit(count))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + let df = self.df.limit(count)?; + Ok(Self::new(df)) } /// Executes the plan, returning a list of `RecordBatch`es. - /// Unless some order is specified in the plan, there is no guarantee of the order of the result - fn collect(&self, py: Python) -> PyResult { - let ctx = _ExecutionContext::from(self.ctx_state.clone()); + /// Unless some order is specified in the plan, there is no + /// guarantee of the order of the result. 
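+    ///
+    /// Execution runs on a temporary Tokio runtime with the GIL released;
+    /// each resulting `RecordBatch` is converted to a `pyarrow.RecordBatch`
+    /// before being returned to Python.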
+ fn collect(&self, py: Python) -> PyResult> { let rt = Runtime::new().unwrap(); - let plan = ctx - .optimize(&self.plan) - .map_err(|e| -> errors::DataFusionError { e.into() })?; - - let plan = py.allow_threads(|| { - rt.block_on(async { - ctx.create_physical_plan(&plan) - .await - .map_err(|e| -> errors::DataFusionError { e.into() }) - }) - })?; - - let batches = py.allow_threads(|| { - rt.block_on(async { - collect(plan) - .await - .map_err(|e| -> errors::DataFusionError { e.into() }) - }) - })?; - to_py::to_py(&batches) + let batches = py.allow_threads(|| rt.block_on(self.df.collect()))?; + // cannot use PyResult> return type due to + // https://github.com/PyO3/pyo3/issues/1813 + batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() } /// Print the result, 20 lines by default #[args(num = "20")] fn show(&self, py: Python, num: usize) -> PyResult<()> { - let ctx = _ExecutionContext::from(self.ctx_state.clone()); let rt = Runtime::new().unwrap(); - let plan = py.allow_threads(|| { - rt.block_on(async { - let l_plan = ctx - .optimize(&self.limit(num)?.plan) - .map_err(|e| -> errors::DataFusionError { e.into() })?; - let p_plan = ctx - .create_physical_plan(&l_plan) - .await - .map_err(|e| -> errors::DataFusionError { e.into() })?; - Ok::<_, PyErr>(p_plan) - }) - })?; - - let batches = py.allow_threads(|| { - rt.block_on(async { - collect(plan) - .await - .map_err(|e| -> errors::DataFusionError { e.into() }) - }) - })?; - - Ok(pretty::print_batches(&batches).unwrap()) + let df = self.df.limit(num)?; + let batches = py.allow_threads(|| rt.block_on(df.collect()))?; + Ok(pretty::print_batches(&batches)?) } - /// Returns the join of two DataFrames `on`. - fn join( - &self, - right: &DataFrame, - join_keys: (Vec<&str>, Vec<&str>), - how: &str, - ) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - + fn join(&self, right: PyDataFrame, on: Vec<&str>, how: &str) -> PyResult { let join_type = match how { "inner" => JoinType::Inner, "left" => JoinType::Left, @@ -199,13 +119,7 @@ impl DataFrame { } }; - let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?; - - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + let df = self.df.join(right.df, join_type, &on, &on)?; + Ok(Self::new(df)) } } diff --git a/python/src/errors.rs b/python/src/errors.rs index fbe98037a030..85b4a77cf4db 100644 --- a/python/src/errors.rs +++ b/python/src/errors.rs @@ -16,10 +16,11 @@ // under the License. use core::fmt; +//use std::result::Result; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; -use pyo3::{exceptions, PyErr}; +use pyo3::{exceptions::PyException, PyErr}; #[derive(Debug)] pub enum DataFusionError { @@ -38,9 +39,9 @@ impl fmt::Display for DataFusionError { } } -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - exceptions::PyException::new_err(err.to_string()) +impl From for DataFusionError { + fn from(err: ArrowError) -> DataFusionError { + DataFusionError::ArrowError(err) } } @@ -50,12 +51,8 @@ impl From for DataFusionError { } } -impl From for DataFusionError { - fn from(err: ArrowError) -> DataFusionError { - DataFusionError::ArrowError(err) +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) } } - -pub(crate) fn wrap(a: Result) -> Result { - Ok(a?) 
-} diff --git a/python/src/expression.rs b/python/src/expression.rs index 4320b1d14c8b..63e9d7b1d665 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -15,156 +15,161 @@ // specific language governing permissions and limitations // under the License. -use pyo3::{ - basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol, -}; +use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; +use std::convert::{From, Into}; +use std::vec::Vec; -use datafusion::logical_plan::Expr as _Expr; -use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF; -use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF; +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_plan::{col, lit, Expr}; +use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; +use datafusion::scalar::ScalarValue; -/// An expression that can be used on a DataFrame -#[pyclass] +/// An PyExpr that can be used on a DataFrame +#[pyclass(name = "Expr")] #[derive(Debug, Clone)] -pub(crate) struct Expression { - pub(crate) expr: _Expr, +pub(crate) struct PyExpr { + pub(crate) expr: Expr, } -/// converts a tuple of expressions into a vector of Expressions -pub(crate) fn from_tuple(value: &PyTuple) -> PyResult> { - value - .iter() - .map(|e| e.extract::()) - .collect::>() +impl From for Expr { + fn from(expr: PyExpr) -> Expr { + expr.expr + } +} + +impl Into for Expr { + fn into(self) -> PyExpr { + PyExpr { expr: self } + } } #[pyproto] -impl PyNumberProtocol for Expression { - fn __add__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr + rhs.expr, - }) +impl PyNumberProtocol for PyExpr { + fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok((lhs.expr + rhs.expr).into()) } - fn __sub__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr - rhs.expr, - }) + fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok((lhs.expr - rhs.expr).into()) } - fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr / rhs.expr, - }) + fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok((lhs.expr / rhs.expr).into()) } - fn __mul__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr * rhs.expr, - }) + fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok((lhs.expr * rhs.expr).into()) } - fn __and__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr.and(rhs.expr), - }) + fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok(lhs.expr.clone().modulus(rhs.expr).into()) } - fn __or__(lhs: Expression, rhs: Expression) -> PyResult { - Ok(Expression { - expr: lhs.expr.or(rhs.expr), - }) + fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok(lhs.expr.clone().and(rhs.expr).into()) } - fn __invert__(&self) -> PyResult { - Ok(Expression { - expr: self.expr.clone().not(), - }) + fn __or__(lhs: PyExpr, rhs: PyExpr) -> PyResult { + Ok(lhs.expr.clone().or(rhs.expr).into()) + } + + fn __invert__(&self) -> PyResult { + Ok(self.expr.clone().not().into()) } } #[pyproto] -impl PyObjectProtocol for Expression { - fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression { - match op { - CompareOp::Lt => Expression { - expr: self.expr.clone().lt(other.expr), - }, - CompareOp::Le => Expression { - expr: self.expr.clone().lt_eq(other.expr), - }, - CompareOp::Eq => Expression { - expr: self.expr.clone().eq(other.expr), - }, - CompareOp::Ne => Expression { - expr: 
self.expr.clone().not_eq(other.expr), - }, - CompareOp::Gt => Expression { - expr: self.expr.clone().gt(other.expr), - }, - CompareOp::Ge => Expression { - expr: self.expr.clone().gt_eq(other.expr), - }, - } +impl PyObjectProtocol for PyExpr { + fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr { + let expr = match op { + CompareOp::Lt => self.expr.clone().lt(other.expr), + CompareOp::Le => self.expr.clone().lt_eq(other.expr), + CompareOp::Eq => self.expr.clone().eq(other.expr), + CompareOp::Ne => self.expr.clone().not_eq(other.expr), + CompareOp::Gt => self.expr.clone().gt(other.expr), + CompareOp::Ge => self.expr.clone().gt_eq(other.expr), + }; + expr.into() } } #[pymethods] -impl Expression { - /// assign a name to the expression - pub fn alias(&self, name: &str) -> PyResult { - Ok(Expression { - expr: self.expr.clone().alias(name), - }) +impl PyExpr { + /// assign a name to the PyExpr + pub fn alias(&self, name: &str) -> PyExpr { + self.expr.clone().alias(name).into() } - /// Create a sort expression from an existing expression. + /// Create a sort PyExpr from an existing PyExpr. #[args(ascending = true, nulls_first = true)] - pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult { - Ok(Expression { - expr: self.expr.clone().sort(ascending, nulls_first), - }) + pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { + self.expr.clone().sort(ascending, nulls_first).into() + } + + pub fn is_null(&self) -> PyExpr { + self.expr.clone().is_null().into() + } + + pub fn cast(&self, to: DataType) -> PyExpr { + // self.expr.cast_to() requires DFSchema to validate that the cast + // is supported, omit that for now + let expr = Expr::Cast { + expr: Box::new(self.expr.clone()), + data_type: to, + }; + expr.into() } } -/// Represents a ScalarUDF +/// Represents a PyScalarUDF #[pyclass] #[derive(Debug, Clone)] -pub struct ScalarUDF { - pub(crate) function: _ScalarUDF, +pub struct PyScalarUDF { + pub(crate) function: ScalarUDF, } #[pymethods] -impl ScalarUDF { - /// creates a new expression with the call of the udf +impl PyScalarUDF { + /// creates a new PyExpr with the call of the udf #[call] #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); - - Ok(Expression { - expr: self.function.call(args), - }) + fn __call__(&self, args: Vec) -> PyResult { + let args = args.iter().map(|e| e.expr.clone()).collect(); + Ok(self.function.call(args).into()) } } /// Represents a AggregateUDF #[pyclass] #[derive(Debug, Clone)] -pub struct AggregateUDF { - pub(crate) function: _AggregateUDF, +pub struct PyAggregateUDF { + pub(crate) function: AggregateUDF, } #[pymethods] -impl AggregateUDF { - /// creates a new expression with the call of the udf +impl PyAggregateUDF { + /// creates a new PyExpr with the call of the udf #[call] #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); - - Ok(Expression { - expr: self.function.call(args), - }) + fn __call__(&self, args: Vec) -> PyResult { + let args = args.iter().map(|e| e.expr.clone()).collect(); + Ok(self.function.call(args).into()) } } + +#[pyfunction] +fn literal(value: ScalarValue) -> PyExpr { + lit(value).into() +} + +#[pyfunction] +fn column(value: &str) -> PyExpr { + col(value).into() +} + +pub fn init(m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_wrapped(wrap_pyfunction!(literal))?; + m.add_wrapped(wrap_pyfunction!(column))?; + Ok(()) 
+} diff --git a/python/src/functions.rs b/python/src/functions.rs index e7d141d61171..282d47a627e3 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -15,87 +15,64 @@ // specific language governing permissions and limitations // under the License. -use crate::udaf; -use crate::udf; -use crate::{expression, types::PyDataType}; +use std::sync::Arc; + +use pyo3::{prelude::*, wrap_pyfunction}; + use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan::{self, Literal}; use datafusion::physical_plan::functions::Volatility; use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python}; use std::sync::Arc; +use datafusion::logical_plan; +use datafusion::logical_plan::Expr; +use datafusion::physical_plan::{ + aggregates::AggregateFunction, functions::BuiltinScalarFunction, +}; -/// Expression representing a column on the existing plan. +use crate::{ + expression::{PyAggregateUDF, PyExpr, PyScalarUDF}, + udaf, udf, +}; + +/// PyExpr representing a column on the existing plan. +/// TODO(kszucs): remove col and lit #[pyfunction] #[pyo3(text_signature = "(name)")] -fn col(name: &str) -> expression::Expression { - expression::Expression { +fn col(name: &str) -> PyExpr { + PyExpr { expr: logical_plan::col(name), } } -/// # A bridge type that converts PyAny data into datafusion literal -/// -/// Note that the ordering here matters because it has to be from -/// narrow to wider values because Python has duck typing so putting -/// Int before Boolean results in a premature match. -#[derive(FromPyObject)] -enum PythonLiteral<'a> { - Boolean(bool), - Int(i64), - UInt(u64), - Float(f64), - Str(&'a str), - Binary(&'a [u8]), -} - -impl<'a> Literal for PythonLiteral<'a> { - fn lit(&self) -> logical_plan::Expr { - match self { - PythonLiteral::Boolean(val) => val.lit(), - PythonLiteral::Int(val) => val.lit(), - PythonLiteral::UInt(val) => val.lit(), - PythonLiteral::Float(val) => val.lit(), - PythonLiteral::Str(val) => val.lit(), - PythonLiteral::Binary(val) => val.lit(), - } - } -} - -/// Expression representing a constant value +/// PyExpr representing a constant value #[pyfunction] #[pyo3(text_signature = "(value)")] -fn lit(value: &PyAny) -> PyResult { - let py_lit = value.extract::()?; - let expr = py_lit.lit(); - Ok(expression::Expression { expr }) +fn lit(value: i32) -> PyExpr { + logical_plan::lit(value).into() } #[pyfunction] -fn array(value: Vec) -> expression::Expression { - expression::Expression { +fn array(value: Vec) -> PyExpr { + PyExpr { expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::>()), } } #[pyfunction] -fn in_list( - expr: expression::Expression, - value: Vec, - negated: bool, -) -> expression::Expression { - expression::Expression { - expr: logical_plan::in_list( - expr.expr, - value.into_iter().map(|x| x.expr).collect::>(), - negated, - ), - } +fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { + logical_plan::in_list( + expr.expr, + value.into_iter().map(|x| x.expr).collect::>(), + negated, + ) + .into() } /// Current date and time #[pyfunction] -fn now() -> expression::Expression { - expression::Expression { +fn now() -> PyExpr { + PyExpr { // here lit(0) is a stub for conform to arity expr: logical_plan::now(logical_plan::lit(0)), } @@ -103,8 +80,8 @@ fn now() -> expression::Expression { /// Returns a random value in the range 0.0 <= x < 1.0 #[pyfunction] -fn random() -> expression::Expression { - expression::Expression { +fn random() -> PyExpr { + PyExpr { expr: logical_plan::random(), } } @@ -124,123 +101,147 @@ 
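For orientation, the intended Python-side usage of the expression API introduced above looks roughly like this (a minimal sketch patterned on the tests later in this series; it assumes a DataFrame `df` with numeric columns "a" and "b" already exists, and `pa.scalar` is used because `literal` extracts a pyarrow scalar):

    import pyarrow as pa
    from datafusion import column, literal

    # __richcmp__ maps Python comparisons onto lt/lt_eq/eq/not_eq/gt/gt_eq
    df = df.filter(column("a") > literal(pa.scalar(2)))
    df = df.select((column("a") + column("b")).alias("sum"))
    df = df.sort(column("b").sort(ascending=False, nulls_first=True))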
fn digest( /// Concatenates the text representations of all the arguments. /// NULL arguments are ignored. #[pyfunction(args = "*")] -fn concat(args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let args = expressions.into_iter().map(|e| e.expr).collect::>(); - Ok(expression::Expression { - expr: logical_plan::concat(&args), - }) +fn concat(args: Vec) -> PyResult { + let args = args.into_iter().map(|e| e.expr).collect::>(); + Ok(logical_plan::concat(&args).into()) } /// Concatenates all but the first argument, with separators. /// The first argument is used as the separator string, and should not be NULL. /// Other NULL arguments are ignored. #[pyfunction(sep, args = "*")] -fn concat_ws(sep: String, args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let args = expressions.into_iter().map(|e| e.expr).collect::>(); - Ok(expression::Expression { - expr: logical_plan::concat_ws(sep, &args), - }) +fn concat_ws(sep: String, args: Vec) -> PyResult { + let args = args.into_iter().map(|e| e.expr).collect::>(); + Ok(logical_plan::concat_ws(sep, &args).into()) } -macro_rules! define_unary_function { - ($NAME: ident) => { - #[doc = "This function is not documented yet"] - #[pyfunction] - fn $NAME(value: expression::Expression) -> expression::Expression { - expression::Expression { - expr: logical_plan::$NAME(value.expr), - } - } +macro_rules! scalar_function { + ($NAME: ident, $FUNC: ident) => { + scalar_function!($NAME, $FUNC, stringify!($NAME)); }; - ($NAME: ident, $DOC: expr) => { + ($NAME: ident, $FUNC: ident, $DOC: expr) => { #[doc = $DOC] - #[pyfunction] - fn $NAME(value: expression::Expression) -> expression::Expression { - expression::Expression { - expr: logical_plan::$NAME(value.expr), - } + #[pyfunction(args = "*")] + fn $NAME(args: Vec) -> PyExpr { + let expr = Expr::ScalarFunction { + fun: BuiltinScalarFunction::$FUNC, + args: args.into_iter().map(|e| e.into()).collect(), + }; + expr.into() } }; } -define_unary_function!(sqrt, "sqrt"); -define_unary_function!(sin, "sin"); -define_unary_function!(cos, "cos"); -define_unary_function!(tan, "tan"); -define_unary_function!(asin, "asin"); -define_unary_function!(acos, "acos"); -define_unary_function!(atan, "atan"); -define_unary_function!(floor, "floor"); -define_unary_function!(ceil, "ceil"); -define_unary_function!(round, "round"); -define_unary_function!(trunc, "trunc"); -define_unary_function!(abs, "abs"); -define_unary_function!(signum, "signum"); -define_unary_function!(exp, "exp"); -define_unary_function!(ln, "ln"); -define_unary_function!(log2, "log2"); -define_unary_function!(log10, "log10"); +macro_rules! aggregate_function { + ($NAME: ident, $FUNC: ident) => { + aggregate_function!($NAME, $FUNC, stringify!($NAME)); + }; + ($NAME: ident, $FUNC: ident, $DOC: expr) => { + #[doc = $DOC] + #[pyfunction(args = "*", distinct = "false")] + fn $NAME(args: Vec, distinct: bool) -> PyExpr { + let expr = Expr::AggregateFunction { + fun: AggregateFunction::$FUNC, + args: args.into_iter().map(|e| e.into()).collect(), + distinct, + }; + expr.into() + } + }; +} -define_unary_function!(ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. 
In other multibyte encodings, the argument must be an ASCII character."); -define_unary_function!(sum); -define_unary_function!( +scalar_function!(abs, Abs); +scalar_function!(acos, Acos); +scalar_function!(ascii, Ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character."); +scalar_function!(asin, Asin); +scalar_function!(atan, Atan); +scalar_function!( bit_length, + BitLength, "Returns number of bits in the string (8 times the octet_length)." ); -define_unary_function!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); -define_unary_function!( +scalar_function!(btrim, Btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); +scalar_function!(ceil, Ceil); +scalar_function!( character_length, + CharacterLength, "Returns number of characters in the string." ); -define_unary_function!(chr, "Returns the character with the given code."); -define_unary_function!(initcap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); -define_unary_function!(left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); -define_unary_function!(lower, "Converts the string to all lower case"); -define_unary_function!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); -define_unary_function!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); -define_unary_function!( +scalar_function!(chr, Chr, "Returns the character with the given code."); +scalar_function!(cos, Cos); +scalar_function!(exp, Exp); +scalar_function!(floor, Floor); +scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); +scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); +scalar_function!(ln, Ln); +scalar_function!(log10, Log10); +scalar_function!(log2, Log2); +scalar_function!(lower, Lower, "Converts the string to all lower case"); +scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); +scalar_function!(ltrim, Ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); +scalar_function!( md5, + MD5, "Computes the MD5 hash of the argument, with the result written in hexadecimal." ); -define_unary_function!(octet_length, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); -define_unary_function!( - replace, - "Replaces all occurrences in string of substring from with substring to." 
-); -define_unary_function!(repeat, "Repeats string the specified number of times."); -define_unary_function!( +scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); +scalar_function!(regexp_match, RegexpMatch); +scalar_function!( regexp_replace, + RegexpReplace, "Replaces substring(s) matching a POSIX regular expression" ); -define_unary_function!( +scalar_function!( + repeat, + Repeat, + "Repeats string the specified number of times." +); +scalar_function!( + replace, + Replace, + "Replaces all occurrences in string of substring from with substring to." +); +scalar_function!( reverse, + Reverse, "Reverses the order of the characters in the string." ); -define_unary_function!(right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); -define_unary_function!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); -define_unary_function!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); -define_unary_function!(sha224); -define_unary_function!(sha256); -define_unary_function!(sha384); -define_unary_function!(sha512); -define_unary_function!(split_part, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."); -define_unary_function!(starts_with, "Returns true if string starts with prefix."); -define_unary_function!(strpos,"Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); -define_unary_function!(substr); -define_unary_function!( +scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); +scalar_function!(round, Round); +scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); +scalar_function!(rtrim, Rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); +scalar_function!(sha224, SHA224); +scalar_function!(sha256, SHA256); +scalar_function!(sha384, SHA384); +scalar_function!(sha512, SHA512); +scalar_function!(signum, Signum); +scalar_function!(sin, Sin); +scalar_function!(split_part, SplitPart, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."); +scalar_function!(sqrt, Sqrt); +scalar_function!( + starts_with, + StartsWith, + "Returns true if string starts with prefix." +); +scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); +scalar_function!(substr, Substr); +scalar_function!(tan, Tan); +scalar_function!( to_hex, + ToHex, "Converts the number to its equivalent hexadecimal representation." ); -define_unary_function!(translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. 
If from is longer than to, occurrences of the extra characters in from are deleted."); -define_unary_function!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); -define_unary_function!(upper, "Converts the string to all upper case."); -define_unary_function!(avg); -define_unary_function!(min); -define_unary_function!(max); -define_unary_function!(count); -define_unary_function!(approx_distinct); +scalar_function!(to_timestamp, ToTimestamp); +scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted."); +scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); +scalar_function!(trunc, Trunc); +scalar_function!(upper, Upper, "Converts the string to all upper case."); + +aggregate_function!(avg, Avg); +aggregate_function!(count, Count); +aggregate_function!(max, Max); +aggregate_function!(min, Min); +aggregate_function!(sum, Sum); #[pyclass(name = "Volatility", module = "datafusion.functions")] #[derive(Clone)] @@ -276,12 +277,8 @@ pub(crate) fn create_udf( return_type: PyDataType, volatility: PyVolatility, name: &str, -) -> expression::ScalarUDF { - let input_types: Vec = - input_types.iter().map(|d| d.data_type.clone()).collect(); - let return_type = Arc::new(return_type.data_type); - - expression::ScalarUDF { +) -> PyResult { + Ok(PyScalarUDF { function: logical_plan::create_udf( name, input_types, @@ -289,7 +286,7 @@ pub(crate) fn create_udf( volatility.volatility, udf::array_udf(fun), ), - } + }) } /// Creates a new UDF (User Defined Function). @@ -300,10 +297,9 @@ fn udf( return_type: PyDataType, volatility: PyVolatility, py: Python, -) -> PyResult { +) -> PyResult { let name = fun.getattr(py, "__qualname__")?.extract::(py)?; - - Ok(create_udf(fun, input_types, return_type, volatility, &name)) + create_udf(fun, input_types, return_type, volatility, &name) } /// Creates a new UDAF (User Defined Aggregate Function). @@ -315,94 +311,88 @@ fn udaf( state_type: Vec, volatility: PyVolatility, py: Python, -) -> PyResult { +) -> PyResult { let name = accumulator .getattr(py, "__qualname__")? 
.extract::(py)?; - let input_type = input_type.data_type; - let return_type = Arc::new(return_type.data_type); - let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect()); - - Ok(expression::AggregateUDF { + Ok(PyAggregateUDF { function: logical_plan::create_udaf( &name, input_type, return_type, volatility.volatility, udaf::array_udaf(accumulator), - state_type, + Arc::new(state_type), ), }) } -pub fn init(module: &PyModule) -> PyResult<()> { - module.add_class::()?; - module.add_function(wrap_pyfunction!(abs, module)?)?; - module.add_function(wrap_pyfunction!(acos, module)?)?; - module.add_function(wrap_pyfunction!(approx_distinct, module)?)?; - module.add_function(wrap_pyfunction!(array, module)?)?; - module.add_function(wrap_pyfunction!(ascii, module)?)?; - module.add_function(wrap_pyfunction!(asin, module)?)?; - module.add_function(wrap_pyfunction!(atan, module)?)?; - module.add_function(wrap_pyfunction!(avg, module)?)?; - module.add_function(wrap_pyfunction!(bit_length, module)?)?; - module.add_function(wrap_pyfunction!(btrim, module)?)?; - module.add_function(wrap_pyfunction!(ceil, module)?)?; - module.add_function(wrap_pyfunction!(character_length, module)?)?; - module.add_function(wrap_pyfunction!(chr, module)?)?; - module.add_function(wrap_pyfunction!(col, module)?)?; - module.add_function(wrap_pyfunction!(concat_ws, module)?)?; - module.add_function(wrap_pyfunction!(concat, module)?)?; - module.add_function(wrap_pyfunction!(cos, module)?)?; - module.add_function(wrap_pyfunction!(count, module)?)?; - module.add_function(wrap_pyfunction!(exp, module)?)?; - module.add_function(wrap_pyfunction!(floor, module)?)?; - module.add_function(wrap_pyfunction!(in_list, module)?)?; - module.add_function(wrap_pyfunction!(initcap, module)?)?; - module.add_function(wrap_pyfunction!(left, module)?)?; - module.add_function(wrap_pyfunction!(lit, module)?)?; - module.add_function(wrap_pyfunction!(ln, module)?)?; - module.add_function(wrap_pyfunction!(log10, module)?)?; - module.add_function(wrap_pyfunction!(log2, module)?)?; - module.add_function(wrap_pyfunction!(lower, module)?)?; - module.add_function(wrap_pyfunction!(lpad, module)?)?; - module.add_function(wrap_pyfunction!(ltrim, module)?)?; - module.add_function(wrap_pyfunction!(max, module)?)?; - module.add_function(wrap_pyfunction!(md5, module)?)?; - module.add_function(wrap_pyfunction!(digest, module)?)?; - module.add_function(wrap_pyfunction!(min, module)?)?; - module.add_function(wrap_pyfunction!(now, module)?)?; - module.add_function(wrap_pyfunction!(octet_length, module)?)?; - module.add_function(wrap_pyfunction!(random, module)?)?; - module.add_function(wrap_pyfunction!(regexp_replace, module)?)?; - module.add_function(wrap_pyfunction!(repeat, module)?)?; - module.add_function(wrap_pyfunction!(replace, module)?)?; - module.add_function(wrap_pyfunction!(reverse, module)?)?; - module.add_function(wrap_pyfunction!(right, module)?)?; - module.add_function(wrap_pyfunction!(round, module)?)?; - module.add_function(wrap_pyfunction!(rpad, module)?)?; - module.add_function(wrap_pyfunction!(rtrim, module)?)?; - module.add_function(wrap_pyfunction!(sha224, module)?)?; - module.add_function(wrap_pyfunction!(sha256, module)?)?; - module.add_function(wrap_pyfunction!(sha384, module)?)?; - module.add_function(wrap_pyfunction!(sha512, module)?)?; - module.add_function(wrap_pyfunction!(signum, module)?)?; - module.add_function(wrap_pyfunction!(sin, module)?)?; - module.add_function(wrap_pyfunction!(split_part, module)?)?; - 
module.add_function(wrap_pyfunction!(sqrt, module)?)?; - module.add_function(wrap_pyfunction!(starts_with, module)?)?; - module.add_function(wrap_pyfunction!(strpos, module)?)?; - module.add_function(wrap_pyfunction!(substr, module)?)?; - module.add_function(wrap_pyfunction!(sum, module)?)?; - module.add_function(wrap_pyfunction!(tan, module)?)?; - module.add_function(wrap_pyfunction!(to_hex, module)?)?; - module.add_function(wrap_pyfunction!(translate, module)?)?; - module.add_function(wrap_pyfunction!(trim, module)?)?; - module.add_function(wrap_pyfunction!(trunc, module)?)?; - module.add_function(wrap_pyfunction!(udaf, module)?)?; - module.add_function(wrap_pyfunction!(udf, module)?)?; - module.add_function(wrap_pyfunction!(upper, module)?)?; - +pub fn init(m: &PyModule) -> PyResult<()> { + m.add_wrapped(wrap_pyfunction!(abs))?; + m.add_wrapped(wrap_pyfunction!(acos))?; + m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(ascii))?; + m.add_wrapped(wrap_pyfunction!(asin))?; + m.add_wrapped(wrap_pyfunction!(atan))?; + m.add_wrapped(wrap_pyfunction!(avg))?; + m.add_wrapped(wrap_pyfunction!(bit_length))?; + m.add_wrapped(wrap_pyfunction!(btrim))?; + m.add_wrapped(wrap_pyfunction!(ceil))?; + m.add_wrapped(wrap_pyfunction!(character_length))?; + m.add_wrapped(wrap_pyfunction!(chr))?; + m.add_wrapped(wrap_pyfunction!(col))?; + m.add_wrapped(wrap_pyfunction!(concat_ws))?; + m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(cos))?; + m.add_wrapped(wrap_pyfunction!(count))?; + m.add_wrapped(wrap_pyfunction!(exp))?; + m.add_wrapped(wrap_pyfunction!(floor))?; + m.add_wrapped(wrap_pyfunction!(in_list))?; + m.add_wrapped(wrap_pyfunction!(initcap))?; + m.add_wrapped(wrap_pyfunction!(left))?; + m.add_wrapped(wrap_pyfunction!(lit))?; + m.add_wrapped(wrap_pyfunction!(ln))?; + m.add_wrapped(wrap_pyfunction!(log10))?; + m.add_wrapped(wrap_pyfunction!(log2))?; + m.add_wrapped(wrap_pyfunction!(lower))?; + m.add_wrapped(wrap_pyfunction!(lpad))?; + m.add_wrapped(wrap_pyfunction!(ltrim))?; + m.add_wrapped(wrap_pyfunction!(max))?; + m.add_wrapped(wrap_pyfunction!(md5))?; + m.add_wrapped(wrap_pyfunction!(min))?; + m.add_wrapped(wrap_pyfunction!(now))?; + m.add_wrapped(wrap_pyfunction!(octet_length))?; + m.add_wrapped(wrap_pyfunction!(random))?; + m.add_wrapped(wrap_pyfunction!(regexp_match))?; + m.add_wrapped(wrap_pyfunction!(regexp_replace))?; + m.add_wrapped(wrap_pyfunction!(repeat))?; + m.add_wrapped(wrap_pyfunction!(replace))?; + m.add_wrapped(wrap_pyfunction!(reverse))?; + m.add_wrapped(wrap_pyfunction!(right))?; + m.add_wrapped(wrap_pyfunction!(round))?; + m.add_wrapped(wrap_pyfunction!(rpad))?; + m.add_wrapped(wrap_pyfunction!(rtrim))?; + m.add_wrapped(wrap_pyfunction!(sha224))?; + m.add_wrapped(wrap_pyfunction!(sha256))?; + m.add_wrapped(wrap_pyfunction!(sha384))?; + m.add_wrapped(wrap_pyfunction!(sha512))?; + m.add_wrapped(wrap_pyfunction!(signum))?; + m.add_wrapped(wrap_pyfunction!(sin))?; + m.add_wrapped(wrap_pyfunction!(split_part))?; + m.add_wrapped(wrap_pyfunction!(sqrt))?; + m.add_wrapped(wrap_pyfunction!(starts_with))?; + m.add_wrapped(wrap_pyfunction!(strpos))?; + m.add_wrapped(wrap_pyfunction!(substr))?; + m.add_wrapped(wrap_pyfunction!(sum))?; + m.add_wrapped(wrap_pyfunction!(tan))?; + m.add_wrapped(wrap_pyfunction!(to_hex))?; + m.add_wrapped(wrap_pyfunction!(to_timestamp))?; + m.add_wrapped(wrap_pyfunction!(translate))?; + m.add_wrapped(wrap_pyfunction!(trim))?; + m.add_wrapped(wrap_pyfunction!(trunc))?; + m.add_wrapped(wrap_pyfunction!(udaf))?; 
+ m.add_wrapped(wrap_pyfunction!(udf))?; + m.add_wrapped(wrap_pyfunction!(upper))?; Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index 4436781bec36..2a725aaaa804 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -17,15 +17,12 @@ use pyo3::prelude::*; +mod catalog; mod context; mod dataframe; mod errors; mod expression; mod functions; -mod scalar; -mod to_py; -mod to_rust; -mod types; mod udaf; mod udf; @@ -44,15 +41,19 @@ fn register_module_package(py: Python, package_name: &str, module: &PyModule) { /// DataFusion. #[pymodule] -fn datafusion(py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; +fn internals(py: Python, m: &PyModule) -> PyResult<()> { + expression::init(m)?; let functions = PyModule::new(py, "functions")?; functions::init(functions)?; register_module_package(py, "datafusion.functions", functions); m.add_submodule(functions)?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) } diff --git a/python/src/scalar.rs b/python/src/scalar.rs deleted file mode 100644 index 0c562a940361..000000000000 --- a/python/src/scalar.rs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use pyo3::prelude::*; - -use datafusion::scalar::ScalarValue as _Scalar; - -use crate::to_rust::to_rust_scalar; - -/// An expression that can be used on a DataFrame -#[derive(Debug, Clone)] -pub(crate) struct Scalar { - pub(crate) scalar: _Scalar, -} - -impl<'source> FromPyObject<'source> for Scalar { - fn extract(ob: &'source PyAny) -> PyResult { - Ok(Self { - scalar: to_rust_scalar(ob)?, - }) - } -} diff --git a/python/src/to_py.rs b/python/src/to_py.rs deleted file mode 100644 index 6bc0581c8c70..000000000000 --- a/python/src/to_py.rs +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
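The hand-written conversion modules deleted below (scalar.rs, to_py.rs, to_rust.rs, types.rs) are superseded by the `arrow/pyarrow` feature, so Python code keeps exchanging plain pyarrow values with DataFusion. A minimal sketch of the UDF round trip from Python (patterned on the updated test_dataframe.py; the lambda body and column name are illustrative):

    import pyarrow as pa
    from datafusion import column
    from datafusion import functions as f

    # the wrapped function receives pyarrow Arrays and must return a pyarrow Array
    is_null = f.udf(
        lambda arr: arr.is_null(),
        [pa.int64()],
        pa.bool_(),
        f.Volatility.immutable(),
    )
    df = df.select(is_null(column("a")))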
- -use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::record_batch::RecordBatch; -use libc::uintptr_t; -use pyo3::prelude::*; -use pyo3::types::PyList; -use pyo3::PyErr; -use std::convert::From; - -use crate::errors; - -pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = - array.to_raw().map_err(errors::DataFusionError::from)?; - - let pa = py.import("pyarrow")?; - - let array = pa.getattr("Array")?.call_method1( - "_import_from_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - Ok(array.to_object(py)) -} - -fn to_py_batch<'a>( - batch: &RecordBatch, - py: Python, - pyarrow: &'a PyModule, -) -> Result { - let mut py_arrays = vec![]; - let mut py_names = vec![]; - - let schema = batch.schema(); - for (array, field) in batch.columns().iter().zip(schema.fields().iter()) { - let array = to_py_array(array, py)?; - - py_arrays.push(array); - py_names.push(field.name()); - } - - let record = pyarrow - .getattr("RecordBatch")? - .call_method1("from_arrays", (py_arrays, py_names))?; - - Ok(PyObject::from(record)) -} - -/// Converts a &[RecordBatch] into a Vec represented in PyArrow -pub fn to_py(batches: &[RecordBatch]) -> PyResult { - Python::with_gil(|py| { - let pyarrow = PyModule::import(py, "pyarrow")?; - let mut py_batches = vec![]; - for batch in batches { - py_batches.push(to_py_batch(batch, py, pyarrow)?); - } - let list = PyList::new(py, py_batches); - Ok(PyObject::from(list)) - }) -} diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs deleted file mode 100644 index 7977fe4ff8ce..000000000000 --- a/python/src/to_rust.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::convert::TryFrom; -use std::sync::Arc; - -use datafusion::arrow::{ - array::{make_array_from_raw, ArrayRef}, - datatypes::Field, - datatypes::Schema, - ffi, - record_batch::RecordBatch, -}; -use datafusion::scalar::ScalarValue; -use libc::uintptr_t; -use pyo3::prelude::*; - -use crate::{errors, types::PyDataType}; - -/// converts a pyarrow Array into a Rust Array -pub fn to_rust(ob: &PyAny) -> PyResult { - // prepare a pointer to receive the Array struct - let (array_pointer, schema_pointer) = - ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. 
In particular, `_export_to_c` can go out of bounds - ob.call_method1( - "_export_to_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - - let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(errors::DataFusionError::from)?; - Ok(array) -} - -/// converts a pyarrow batch into a RecordBatch -pub fn to_rust_batch(batch: &PyAny) -> PyResult { - let schema = batch.getattr("schema")?; - let names = schema.getattr("names")?.extract::>()?; - - let fields = names - .iter() - .enumerate() - .map(|(i, name)| { - let field = schema.call_method1("field", (i,))?; - let nullable = field.getattr("nullable")?.extract::()?; - let py_data_type = field.getattr("type")?; - let data_type = py_data_type.extract::()?.data_type; - Ok(Field::new(name, data_type, nullable)) - }) - .collect::>()?; - - let schema = Arc::new(Schema::new(fields)); - - let arrays = (0..names.len()) - .map(|i| { - let array = batch.call_method1("column", (i,))?; - to_rust(array) - }) - .collect::>()?; - - let batch = - RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?; - Ok(batch) -} - -/// converts a pyarrow Scalar into a Rust Scalar -pub fn to_rust_scalar(ob: &PyAny) -> PyResult { - let t = ob - .getattr("__class__")? - .getattr("__name__")? - .extract::<&str>()?; - - let p = ob.call_method0("as_py")?; - - Ok(match t { - "Int8Scalar" => ScalarValue::Int8(Some(p.extract::()?)), - "Int16Scalar" => ScalarValue::Int16(Some(p.extract::()?)), - "Int32Scalar" => ScalarValue::Int32(Some(p.extract::()?)), - "Int64Scalar" => ScalarValue::Int64(Some(p.extract::()?)), - "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::()?)), - "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::()?)), - "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::()?)), - "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::()?)), - "FloatScalar" => ScalarValue::Float32(Some(p.extract::()?)), - "DoubleScalar" => ScalarValue::Float64(Some(p.extract::()?)), - "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::()?)), - "StringScalar" => ScalarValue::Utf8(Some(p.extract::()?)), - "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::()?)), - other => { - return Err(errors::DataFusionError::Common(format!( - "Type \"{}\"not yet implemented", - other - )) - .into()) - } - }) -} - -pub fn to_rust_schema(ob: &PyAny) -> PyResult { - let c_schema = ffi::FFI_ArrowSchema::empty(); - let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema; - ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; - let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?; - Ok(schema) -} diff --git a/python/src/types.rs b/python/src/types.rs deleted file mode 100644 index bd6ef0d376e6..000000000000 --- a/python/src/types.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::datatypes::DataType; -use pyo3::{FromPyObject, PyAny, PyResult}; - -use crate::errors; - -/// utility struct to convert PyObj to native DataType -#[derive(Debug, Clone)] -pub struct PyDataType { - pub data_type: DataType, -} - -impl<'source> FromPyObject<'source> for PyDataType { - fn extract(ob: &'source PyAny) -> PyResult { - let id = ob.getattr("id")?.extract::()?; - let data_type = data_type_id(&id)?; - Ok(PyDataType { data_type }) - } -} - -fn data_type_id(id: &i32) -> Result { - // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd - // this is not ideal as it does not generalize for non-basic types - // Find a way to get a unique name from the pyarrow.DataType - Ok(match id { - 1 => DataType::Boolean, - 2 => DataType::UInt8, - 3 => DataType::Int8, - 4 => DataType::UInt16, - 5 => DataType::Int16, - 6 => DataType::UInt32, - 7 => DataType::Int32, - 8 => DataType::UInt64, - 9 => DataType::Int64, - 10 => DataType::Float16, - 11 => DataType::Float32, - 12 => DataType::Float64, - 13 => DataType::Utf8, - 14 => DataType::Binary, - 34 => DataType::LargeUtf8, - 35 => DataType::LargeBinary, - other => { - return Err(errors::DataFusionError::Common(format!( - "The type {} is not valid", - other - ))) - } - }) -} diff --git a/python/src/udaf.rs b/python/src/udaf.rs index 83e8be05db60..756afe68c31e 100644 --- a/python/src/udaf.rs +++ b/python/src/udaf.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::Result; use datafusion::{ @@ -27,10 +28,6 @@ use datafusion::{ scalar::ScalarValue, }; -use crate::scalar::Scalar; -use crate::to_py::to_py_array; -use crate::to_rust::to_rust_scalar; - #[derive(Debug)] struct PyAccumulator { accum: PyObject, @@ -43,18 +40,9 @@ impl PyAccumulator { } impl Accumulator for PyAccumulator { - fn state(&self) -> Result> { - Python::with_gil(|py| { - let state = self - .accum - .as_ref(py) - .call_method0("to_scalars") - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))? - .extract::>() - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - Ok(state.into_iter().map(|v| v.scalar).collect::>()) - }) + fn state(&self) -> Result> { + Python::with_gil(|py| self.accum.as_ref(py).call_method0("to_scalars")?.extract()) + .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) } fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { @@ -67,17 +55,9 @@ impl Accumulator for PyAccumulator { todo!() } - fn evaluate(&self) -> Result { - Python::with_gil(|py| { - let value = self - .accum - .as_ref(py) - .call_method0("evaluate") - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - to_rust_scalar(value) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) - }) + fn evaluate(&self) -> Result { + Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) + .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -88,10 +68,7 @@ impl Accumulator for PyAccumulator { // 1. 
let py_args = values .iter() - .map(|arg| { - // remove unwrap - to_py_array(arg, py).unwrap() - }) + .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args); @@ -111,7 +88,8 @@ impl Accumulator for PyAccumulator { // 2. merge let state = &states[0]; - let state = to_py_array(state, py) + let state = state + .to_pyarrow(py) .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; // 2. diff --git a/python/src/udf.rs b/python/src/udf.rs index 49a18d993241..fa77e4ab3257 100644 --- a/python/src/udf.rs +++ b/python/src/udf.rs @@ -15,15 +15,12 @@ // specific language governing permissions and limitations // under the License. -use pyo3::{prelude::*, types::PyTuple}; - -use datafusion::{arrow::array, physical_plan::functions::make_scalar_function}; - +use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::DataFusionError; use datafusion::physical_plan::functions::ScalarFunctionImplementation; - -use crate::to_py::to_py_array; -use crate::to_rust::to_rust; +use datafusion::{arrow::array, physical_plan::functions::make_scalar_function}; +use pyo3::{prelude::*, types::PyTuple}; /// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays /// This is more efficient as it performs a zero-copy of the contents. @@ -38,10 +35,7 @@ pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation { // 1. let py_args = args .iter() - .map(|arg| { - // remove unwrap - to_py_array(arg, py).unwrap() - }) + .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args); @@ -52,7 +46,7 @@ pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation { Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))), }?; - let array = to_rust(value).unwrap(); + let array = ArrayRef::from_pyarrow(value).unwrap(); Ok(array) }) }, From d1052535947699835236565f1508d3948ffd1374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 21 Sep 2021 20:48:50 +0200 Subject: [PATCH 02/21] Deps --- datafusion-cli/Cargo.toml | 4 ++ .../src/optimizer/common_subexpr_eliminate.rs | 5 +- python/datafusion/tests/test_pa_types.py | 50 ------------------- 3 files changed, 7 insertions(+), 52 deletions(-) delete mode 100644 python/datafusion/tests/test_pa_types.py diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b424f498ac5f..b9ec3bdc2fd7 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,5 +31,9 @@ clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "5.1.0" } +<<<<<<< HEAD arrow = { version = "6.0.0" } +======= +arrow = { git = "https://github.com/apache/arrow-rs" } +>>>>>>> 96dde89f (Deps) ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index ea60286b902f..0e97663b5fef 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -631,6 +631,7 @@ mod test { avg, binary_expr, col, lit, sum, LogicalPlanBuilder, Operator, }; use crate::test::*; + use std::iter; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let optimizer = CommonSubexprEliminate {}; @@ -688,7 +689,7 @@ mod test { let plan = 
LogicalPlanBuilder::from(table_scan) .aggregate( - vec![], + iter::empty::(), vec![ sum(binary_expr( col("a"), @@ -723,7 +724,7 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( - vec![], + iter::empty::(), vec![ binary_expr(lit(1), Operator::Plus, avg(col("a"))), binary_expr(lit(1), Operator::Minus, avg(col("a"))), diff --git a/python/datafusion/tests/test_pa_types.py b/python/datafusion/tests/test_pa_types.py deleted file mode 100644 index 04f6110e3a42..000000000000 --- a/python/datafusion/tests/test_pa_types.py +++ /dev/null @@ -1,50 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import pyarrow as pa - - -def test_type_ids(): - # Having this fixed is very important because internally we rely on this id - # to parse from python - for idx, arrow_type in [ - (0, pa.null()), - (1, pa.bool_()), - (2, pa.uint8()), - (3, pa.int8()), - (4, pa.uint16()), - (5, pa.int16()), - (6, pa.uint32()), - (7, pa.int32()), - (8, pa.uint64()), - (9, pa.int64()), - (10, pa.float16()), - (11, pa.float32()), - (12, pa.float64()), - (13, pa.string()), - (13, pa.utf8()), - (14, pa.binary()), - (16, pa.date32()), - (17, pa.date64()), - (18, pa.timestamp("us")), - (19, pa.time32("s")), - (20, pa.time64("us")), - (23, pa.decimal128(8, 1)), - (34, pa.large_utf8()), - (35, pa.large_binary()), - ]: - assert idx == arrow_type.id From 85f7be5d5ecf2a91fcace262bcc4e594f8a4fc64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 25 Oct 2021 17:26:08 +0200 Subject: [PATCH 03/21] Resolve post-rebase issues --- datafusion-cli/Cargo.toml | 6 +--- datafusion/src/logical_plan/builder.rs | 11 +++++-- python/src/context.rs | 14 +++------ python/src/functions.rs | 41 ++++++++++++-------------- 4 files changed, 32 insertions(+), 40 deletions(-) diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b9ec3bdc2fd7..360e873c6ed7 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,9 +31,5 @@ clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "5.1.0" } -<<<<<<< HEAD -arrow = { version = "6.0.0" } -======= -arrow = { git = "https://github.com/apache/arrow-rs" } ->>>>>>> 96dde89f (Deps) +arrow = { version = "6.0.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index dcbddca89cd7..a9d814f66eb0 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -33,6 +33,7 @@ use arrow::{ record_batch::RecordBatch, }; use std::convert::TryFrom; +use std::iter; use std::{ collections::{HashMap, HashSet}, sync::Arc, @@ -426,7 +427,10 @@ impl 
LogicalPlanBuilder { Ok(plan) } /// Apply a projection without alias. - pub fn project(&self, expr: impl IntoIterator>) -> Result { + pub fn project( + &self, + expr: impl IntoIterator>, + ) -> Result { self.project_with_alias(expr, None) } @@ -477,7 +481,7 @@ impl LogicalPlanBuilder { pub fn distinct(&self) -> Result { let projection_expr = expand_wildcard(self.plan.schema(), &self.plan)?; let plan = LogicalPlanBuilder::from(self.plan.clone()) - .aggregate(projection_expr, vec![])? + .aggregate(projection_expr, iter::empty::())? .build()?; Self::from(plan).project(vec![Expr::Wildcard]) } @@ -805,7 +809,8 @@ pub fn project_with_alias( let input_schema = plan.schema(); let mut projected_expr = vec![]; for e in expr { - match e.into() { + let e = e.into(); + match e { Expr::Wildcard => { projected_expr.extend(expand_wildcard(input_schema, &plan)?) } diff --git a/python/src/context.rs b/python/src/context.rs index b8fb0dbd118d..ed73a5353c4c 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -20,8 +20,6 @@ use std::{collections::HashSet, sync::Arc}; use uuid::Uuid; -use tokio::runtime::Runtime; - use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; @@ -31,13 +29,10 @@ use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; use datafusion::prelude::CsvReadOptions; -use crate::{dataframe, errors, functions}; -use crate::functions::{self, PyVolatility}; -use crate::types::PyDataType; use crate::catalog::PyCatalog; use crate::dataframe::PyDataFrame; use crate::errors::DataFusionError; - +use crate::functions::{PyVolatility, create_udf}; /// `PyExecutionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a @@ -151,12 +146,11 @@ impl PyExecutionContext { &mut self, name: &str, func: PyObject, - args_types: Vec, - return_type: PyDataType, + args_types: Vec, + return_type: DataType, volatility: PyVolatility, ) -> PyResult<()> { - let function = - functions::create_udf(func, args_types, return_type, volatility, name)?; + let function = create_udf(func, args_types, return_type, volatility, name)?; self.ctx.register_udf(function.function); Ok(()) } diff --git a/python/src/functions.rs b/python/src/functions.rs index 282d47a627e3..e74579ff27fd 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -17,15 +17,12 @@ use std::sync::Arc; -use pyo3::{prelude::*, wrap_pyfunction}; +use pyo3::{prelude::*, wrap_pyfunction, Python}; use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{self, Literal}; -use datafusion::physical_plan::functions::Volatility; -use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python}; -use std::sync::Arc; use datafusion::logical_plan; -use datafusion::logical_plan::Expr; +//use datafusion::logical_plan::Expr; +use datafusion::physical_plan::functions::Volatility; use datafusion::physical_plan::{ aggregates::AggregateFunction, functions::BuiltinScalarFunction, }; @@ -90,11 +87,11 @@ fn random() -> PyExpr { /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. #[pyfunction(value, method)] fn digest( - value: expression::Expression, - method: expression::Expression, -) -> expression::Expression { - expression::Expression { - expr: logical_plan::digest(value.expr, method.expr), + value: PyExpr, + method: PyExpr, +) -> PyExpr { + PyExpr { + expr: logical_plan::digest(value.expr, method.expr) } } @@ -123,7 +120,7 @@ macro_rules! 
scalar_function {
     #[doc = $DOC]
     #[pyfunction(args = "*")]
     fn $NAME(args: Vec<PyExpr>) -> PyExpr {
-        let expr = Expr::ScalarFunction {
+        let expr = logical_plan::Expr::ScalarFunction {
             fun: BuiltinScalarFunction::$FUNC,
             args: args.into_iter().map(|e| e.into()).collect(),
         };
@@ -140,7 +137,7 @@ macro_rules! aggregate_function {
     #[doc = $DOC]
     #[pyfunction(args = "*", distinct = "false")]
     fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
-        let expr = Expr::AggregateFunction {
+        let expr = logical_plan::Expr::AggregateFunction {
             fun: AggregateFunction::$FUNC,
             args: args.into_iter().map(|e| e.into()).collect(),
             distinct,
@@ -273,8 +270,8 @@ impl PyVolatility {
 
 pub(crate) fn create_udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<DataType>,
+    return_type: DataType,
     volatility: PyVolatility,
     name: &str,
 ) -> PyResult<PyScalarUDF> {
@@ -282,7 +279,7 @@ pub(crate) fn create_udf(
         function: logical_plan::create_udf(
             name,
             input_types,
-            return_type,
+            Arc::new(return_type),
             volatility.volatility,
             udf::array_udf(fun),
         ),
@@ -293,8 +290,8 @@ pub(crate) fn create_udf(
 #[pyfunction]
 fn udf(
     fun: PyObject,
-    input_types: Vec<PyDataType>,
-    return_type: PyDataType,
+    input_types: Vec<DataType>,
+    return_type: DataType,
     volatility: PyVolatility,
     py: Python,
 ) -> PyResult<PyScalarUDF> {
@@ -306,9 +303,9 @@ fn udf(
 #[pyfunction]
 fn udaf(
     accumulator: PyObject,
-    input_type: PyDataType,
-    return_type: PyDataType,
-    state_type: Vec<PyDataType>,
+    input_type: DataType,
+    return_type: DataType,
+    state_type: Vec<DataType>,
     volatility: PyVolatility,
     py: Python,
 ) -> PyResult<PyAggregateUDF> {
@@ -320,7 +317,7 @@ fn udaf(
         function: logical_plan::create_udaf(
             &name,
             input_type,
-            return_type,
+            Arc::new(return_type),
             volatility.volatility,
             udaf::array_udaf(accumulator),
             Arc::new(state_type),
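After this commit the UDF/UDAF factories accept pyarrow types directly instead of the removed PyDataType shim. Defining an aggregate from Python then looks roughly like the following sketch (the to_scalars/update/merge/evaluate protocol is the one PyAccumulator invokes in udaf.rs; the summing logic and names are illustrative):

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import column
    from datafusion import functions as f

    class Accumulator:
        def __init__(self):
            self._sum = pa.scalar(0.0)

        def to_scalars(self):  # state serialized as one scalar per state_type entry
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    my_sum = f.udaf(
        Accumulator,
        pa.float64(),    # input_type
        pa.float64(),    # return_type
        [pa.float64()],  # state_type
        f.Volatility.immutable(),
    )
    df = df.aggregate([], [my_sum(column("a"))])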
From fd21e6faf0ed3f985a09ba1faa556d32372b9c9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?=
Date: Mon, 25 Oct 2021 18:07:13 +0200
Subject: [PATCH 04/21] Resolve post-rebase issues

---
 python/src/context.rs   | 31 +++++++++++++++++++++----------
 python/src/functions.rs |  8 +++-----
 python/src/lib.rs       | 26 ++++++++++++++------------
 3 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/python/src/context.rs b/python/src/context.rs
index ed73a5353c4c..f8ce7ed690fa 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -15,9 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::future::Future;
 use std::path::PathBuf;
 use std::{collections::HashSet, sync::Arc};
 
+use tokio::runtime::Runtime;
 use uuid::Uuid;
 
 use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
@@ -32,7 +34,16 @@ use datafusion::prelude::CsvReadOptions;
 use crate::catalog::PyCatalog;
 use crate::dataframe::PyDataFrame;
 use crate::errors::DataFusionError;
-use crate::functions::{PyVolatility, create_udf};
+use crate::functions::{create_udf, PyVolatility};
+
+fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
+where
+    F: Send,
+    F::Output: Send,
+{
+    let rt = Runtime::new().unwrap();
+    py.allow_threads(|| rt.block_on(f))
+}
 
 /// `PyExecutionContext` is able to plan and execute DataFusion plans.
 /// It has a powerful optimizer, a physical planner for local execution, and a
@@ -53,8 +64,9 @@ impl PyExecutionContext {
     }
 
     /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
-    fn sql(&mut self, query: &str) -> PyResult<PyDataFrame> {
-        let df = self.ctx.sql(query).map_err(DataFusionError::from)?;
+    fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
+        let result = self.ctx.sql(query);
+        let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
         Ok(PyDataFrame::new(df))
     }
 
@@ -94,10 +106,9 @@ impl PyExecutionContext {
         Ok(())
     }
 
-    fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> {
-        self.ctx
-            .register_parquet(name, path)
-            .map_err(DataFusionError::from)?;
+    fn register_parquet(&mut self, name: &str, path: &str, py: Python) -> PyResult<()> {
+        let result = self.ctx.register_parquet(name, path);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
         Ok(())
     }
 
@@ -136,9 +147,9 @@ impl PyExecutionContext {
             .file_extension(file_extension);
         options.schema = schema.as_ref();
 
-        self.ctx
-            .register_csv(name, path, options)
-            .map_err(DataFusionError::from)?;
+        let result = self.ctx.register_csv(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
         Ok(())
     }
 
diff --git a/python/src/functions.rs b/python/src/functions.rs
index e74579ff27fd..edd94324c3f5 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -86,12 +86,9 @@ fn random() -> PyExpr {
 /// Computes a binary hash of the given data. type is the algorithm to use.
 /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
 #[pyfunction(value, method)]
-fn digest(
-    value: PyExpr,
-    method: PyExpr,
-) -> PyExpr {
+fn digest(value: PyExpr, method: PyExpr) -> PyExpr {
     PyExpr {
-        expr: logical_plan::digest(value.expr, method.expr)
+        expr: logical_plan::digest(value.expr, method.expr),
     }
 }
 
@@ -343,6 +340,7 @@ pub fn init(m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(concat))?;
     m.add_wrapped(wrap_pyfunction!(cos))?;
     m.add_wrapped(wrap_pyfunction!(count))?;
+    m.add_wrapped(wrap_pyfunction!(digest))?;
     m.add_wrapped(wrap_pyfunction!(exp))?;
     m.add_wrapped(wrap_pyfunction!(floor))?;
     m.add_wrapped(wrap_pyfunction!(in_list))?;
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 2a725aaaa804..aab1695e2992 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -26,27 +26,29 @@ mod functions;
 mod udaf;
 mod udf;
 
+// TODO(kszucs): remove
 // taken from https://github.com/PyO3/pyo3/issues/471
-fn register_module_package(py: Python, package_name: &str, module: &PyModule) {
-    py.import("sys")
-        .expect("failed to import python sys module")
-        .dict()
-        .get_item("modules")
-        .expect("failed to get python modules dictionary")
-        .downcast::<PyDict>()
-        .expect("failed to turn sys.modules into a PyDict")
-        .set_item(package_name, module)
-        .expect("failed to inject module");
-}
+// fn register_module_package(py: Python, package_name: &str, module: &PyModule) {
+//     py.import("sys")
+//         .expect("failed to import python sys module")
+//         .dict()
+//         .get_item("modules")
+//         .expect("failed to get python modules dictionary")
+//         .downcast::<PyDict>()
+//         .expect("failed to turn sys.modules into a PyDict")
+//         .set_item(package_name, module)
+//         .expect("failed to inject module");
+// }
 
 /// DataFusion.
#[pymodule] fn internals(py: Python, m: &PyModule) -> PyResult<()> { expression::init(m)?; + //register_module_package(py, "datafusion.functions", functions); + let functions = PyModule::new(py, "functions")?; functions::init(functions)?; - register_module_package(py, "datafusion.functions", functions); m.add_submodule(functions)?; m.add_class::()?; From 51a272711d6e0e886de5b55da327f65d3b67e30d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 25 Oct 2021 19:01:03 +0200 Subject: [PATCH 05/21] Resolve post-rebase issues --- python/datafusion/__init__.py | 9 ++++++++ python/datafusion/tests/test_dataframe.py | 18 ++++++++-------- .../datafusion/tests/test_math_functions.py | 10 ++++----- .../datafusion/tests/test_string_functions.py | 15 +++++++------ python/datafusion/tests/test_udaf.py | 6 +++--- python/src/functions.rs | 21 ++----------------- 6 files changed, 35 insertions(+), 44 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 761381cb26b3..7924b9f8d3fa 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,4 +15,13 @@ # specific language governing permissions and limitations # under the License. +import pyarrow as pa + from .internals import * # noqa +from .internals import literal as _literal + + +def literal(value): + if not isinstance(value, pa.Scalar): + value = pa.scalar(value) + return _literal(value) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 236cd7c03ae3..6e1ac7cfb8ee 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest -from datafusion import DataFrame, ExecutionContext +from datafusion import DataFrame, ExecutionContext, literal, column from datafusion import functions as f @@ -37,8 +37,8 @@ def df(): def test_select(df): df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), + column("a") + column("b"), + column("a") - column("b"), ) # execute and collect the first (and only) batch @@ -50,9 +50,9 @@ def test_select(df): def test_filter(df): df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), - ).filter(f.col("a") > f.lit(2)) + column("a") + column("b"), + column("a") - column("b"), + ).filter(column("a") > literal(2)) # execute and collect the first (and only) batch result = df.collect()[0] @@ -62,7 +62,7 @@ def test_filter(df): def test_sort(df): - df = df.sort(f.col("b").sort(ascending=False)) + df = df.sort(column("b").sort(ascending=False)) table = pa.Table.from_batches(df.collect()) expected = {"a": [3, 2, 1], "b": [6, 5, 4]} @@ -89,7 +89,7 @@ def test_udf(df): f.Volatility.immutable(), ) - df = df.select(udf(f.col("a"))) + df = df.select(udf(column("a"))) result = df.collect()[0].column(0) assert result == pa.array([False, False, False]) @@ -111,7 +111,7 @@ def test_join(): df1 = ctx.create_dataframe([[batch]]) df = df.join(df1, on="a", how="inner") - df = df.sort(f.col("a").sort(ascending=True)) + df = df.sort(column("a").sort(ascending=True)) table = pa.Table.from_batches(df.collect()) expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} diff --git a/python/datafusion/tests/test_math_functions.py b/python/datafusion/tests/test_math_functions.py index 4e473c3de16a..5bd1cd223cc3 100644 --- a/python/datafusion/tests/test_math_functions.py +++ b/python/datafusion/tests/test_math_functions.py @@ -19,7 +19,7 @@ import pyarrow as pa import pytest -from datafusion import 
ExecutionContext +from datafusion import ExecutionContext, literal, column from datafusion import functions as f @@ -35,7 +35,7 @@ def df(): def test_math_functions(df): values = np.array([0.1, -0.7, 0.55]) - col_v = f.col("value") + col_v = column("value") df = df.select( f.abs(col_v), f.sin(col_v), @@ -44,9 +44,9 @@ def test_math_functions(df): f.asin(col_v), f.acos(col_v), f.exp(col_v), - f.ln(col_v + f.lit(1)), - f.log2(col_v + f.lit(1)), - f.log10(col_v + f.lit(1)), + f.ln(col_v + literal(pa.scalar(1))), + f.log2(col_v + literal(pa.scalar(1))), + f.log10(col_v + literal(pa.scalar(1))), f.random(), ) result = df.collect() diff --git a/python/datafusion/tests/test_string_functions.py b/python/datafusion/tests/test_string_functions.py index 3d6c380c55a6..4eef222768c4 100644 --- a/python/datafusion/tests/test_string_functions.py +++ b/python/datafusion/tests/test_string_functions.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext +from datafusion import ExecutionContext, column, literal from datafusion import functions as f @@ -36,7 +36,7 @@ def df(): def test_string_functions(df): - df = df.select(f.md5(f.col("a")), f.lower(f.col("a"))) + df = df.select(f.md5(column("a")), f.lower(column("a"))) result = df.collect() assert len(result) == 1 result = result[0] @@ -51,12 +51,11 @@ def test_string_functions(df): def test_hash_functions(df): - df = df.select( - *[ - f.digest(f.col("a"), f.lit(m)) - for m in ("md5", "sha256", "sha512", "blake2s", "blake3") - ] - ) + exprs = [ + f.digest(column("a"), literal(m)) + for m in ("md5", "sha256", "sha512", "blake2s", "blake3") + ] + df = df.select(*exprs) result = df.collect() assert len(result) == 1 result = result[0] diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 0eb93f6c9876..70de702535c4 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -21,7 +21,7 @@ import pyarrow.compute as pc import pytest -from datafusion import ExecutionContext +from datafusion import ExecutionContext, column from datafusion import functions as f @@ -71,7 +71,7 @@ def test_aggregate(df): f.Volatility.immutable(), ) - df = df.aggregate([], [udaf(f.col("a"))]) + df = df.aggregate([], [udaf(column("a"))]) # execute and collect the first (and only) batch result = df.collect()[0] @@ -88,7 +88,7 @@ def test_group_by(df): f.Volatility.immutable(), ) - df = df.aggregate([f.col("b")], [udaf(f.col("a"))]) + df = df.aggregate([column("b")], [udaf(column("a"))]) batches = df.collect() arrays = [batch.column(1) for batch in batches] diff --git a/python/src/functions.rs b/python/src/functions.rs index edd94324c3f5..a79142f852f5 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -32,23 +32,6 @@ use crate::{ udaf, udf, }; -/// PyExpr representing a column on the existing plan. 
-/// TODO(kszucs): remove col and lit -#[pyfunction] -#[pyo3(text_signature = "(name)")] -fn col(name: &str) -> PyExpr { - PyExpr { - expr: logical_plan::col(name), - } -} - -/// PyExpr representing a constant value -#[pyfunction] -#[pyo3(text_signature = "(value)")] -fn lit(value: i32) -> PyExpr { - logical_plan::lit(value).into() -} - #[pyfunction] fn array(value: Vec) -> PyExpr { PyExpr { @@ -323,6 +306,8 @@ fn udaf( } pub fn init(m: &PyModule) -> PyResult<()> { + // TODO(kszucs): implement FromPyObject to PyVolatility + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; m.add_wrapped(wrap_pyfunction!(array))?; @@ -335,7 +320,6 @@ pub fn init(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(ceil))?; m.add_wrapped(wrap_pyfunction!(character_length))?; m.add_wrapped(wrap_pyfunction!(chr))?; - m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; m.add_wrapped(wrap_pyfunction!(cos))?; @@ -346,7 +330,6 @@ pub fn init(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(left))?; - m.add_wrapped(wrap_pyfunction!(lit))?; m.add_wrapped(wrap_pyfunction!(ln))?; m.add_wrapped(wrap_pyfunction!(log10))?; m.add_wrapped(wrap_pyfunction!(log2))?; From 436e63a2445d3051393daf5e860ffa717354c417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 17:10:19 +0200 Subject: [PATCH 06/21] Address review comments --- datafusion/src/execution/context.rs | 1 - python/src/dataframe.rs | 9 +++++++-- python/src/errors.rs | 1 - python/src/functions.rs | 2 ++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 23ebba28d92e..81fd78518f0f 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -573,7 +573,6 @@ impl ExecutionContext { /// register_table function. /// /// Returns an error if no table has been registered with the provided reference. - /// NOTE(kszucs): perhaps it should be called dataframe() instead? pub fn table<'a>( &self, table_ref: impl Into>, diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 236abfdf9909..7372dec49564 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -102,7 +102,12 @@ impl PyDataFrame { Ok(pretty::print_batches(&batches)?) } - fn join(&self, right: PyDataFrame, on: Vec<&str>, how: &str) -> PyResult { + fn join( + &self, + right: PyDataFrame, + join_keys: (Vec<&str>, Vec<&str>), + how: &str, + ) -> PyResult { let join_type = match how { "inner" => JoinType::Inner, "left" => JoinType::Left, @@ -119,7 +124,7 @@ impl PyDataFrame { } }; - let df = self.df.join(right.df, join_type, &on, &on)?; + let df = self.df.join(right.df, join_type, &join_keys.0, &join_keys.1)?; Ok(Self::new(df)) } } diff --git a/python/src/errors.rs b/python/src/errors.rs index 85b4a77cf4db..655ed8441cb4 100644 --- a/python/src/errors.rs +++ b/python/src/errors.rs @@ -16,7 +16,6 @@ // under the License. 
use core::fmt; -//use std::result::Result; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; diff --git a/python/src/functions.rs b/python/src/functions.rs index a79142f852f5..9529094a546a 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -219,6 +219,7 @@ aggregate_function!(count, Count); aggregate_function!(max, Max); aggregate_function!(min, Min); aggregate_function!(sum, Sum); +aggregate_function!(approx_distinct, ApproxDistinct); #[pyclass(name = "Volatility", module = "datafusion.functions")] #[derive(Clone)] @@ -310,6 +311,7 @@ pub fn init(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; + m.add_wrapped(wrap_pyfunction!(approx_distinct))?; m.add_wrapped(wrap_pyfunction!(array))?; m.add_wrapped(wrap_pyfunction!(ascii))?; m.add_wrapped(wrap_pyfunction!(asin))?; From 48a47791066843b8b65b519e56ba937f337734b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 17:18:10 +0200 Subject: [PATCH 07/21] Reuse wait_for_future from dataframe.rs --- python/src/catalog.rs | 1 - python/src/context.rs | 12 +----------- python/src/dataframe.rs | 12 ++++++------ python/src/lib.rs | 1 + python/src/utils.rs | 31 +++++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 18 deletions(-) create mode 100644 python/src/utils.rs diff --git a/python/src/catalog.rs b/python/src/catalog.rs index a2d382fbe7b7..826ac7827ca6 100644 --- a/python/src/catalog.rs +++ b/python/src/catalog.rs @@ -20,7 +20,6 @@ use std::sync::Arc; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; -// use pyo3::{PyIterProtocol, PyMappingProtocol}; use datafusion::{ arrow::pyarrow::PyArrowConvert, diff --git a/python/src/context.rs b/python/src/context.rs index f8ce7ed690fa..8e6ab4a66637 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::future::Future; use std::path::PathBuf; use std::{collections::HashSet, sync::Arc}; -use tokio::runtime::Runtime; use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; @@ -35,15 +33,7 @@ use crate::catalog::PyCatalog; use crate::dataframe::PyDataFrame; use crate::errors::DataFusionError; use crate::functions::{create_udf, PyVolatility}; - -fn wait_for_future(py: Python, f: F) -> F::Output -where - F: Send, - F::Output: Send, -{ - let rt = Runtime::new().unwrap(); - py.allow_threads(|| rt.block_on(f)) -} +use crate::utils::wait_for_future; /// `PyExecutionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 7372dec49564..8b1c8f0e3e45 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use pyo3::prelude::*; -use tokio::runtime::Runtime; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::PyArrowConvert; @@ -26,6 +25,7 @@ use datafusion::arrow::util::pretty; use datafusion::dataframe::DataFrame; use datafusion::logical_plan::JoinType; +use crate::utils::wait_for_future; use crate::{errors::DataFusionError, expression::PyExpr}; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. 
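+///
+/// Python-side sketch of that pattern (illustration only, adapted from
+/// test_dataframe.py in this series; assumes `column` and `literal` are
+/// imported from `datafusion`):
+///
+/// ```python
+/// df = ctx.create_dataframe([[batch]])  # batch: pyarrow.RecordBatch
+/// df = df.select(column("a") + column("b")).filter(column("a") > literal(2))
+/// batches = df.collect()
+/// ```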
@@ -86,8 +86,7 @@ impl PyDataFrame {
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
-        let rt = Runtime::new().unwrap();
-        let batches = py.allow_threads(|| rt.block_on(self.df.collect()))?;
+        let batches = wait_for_future(py, self.df.collect())?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
         batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
@@ -96,9 +95,8 @@ impl PyDataFrame {
     /// Print the result, 20 lines by default
     #[args(num = "20")]
     fn show(&self, py: Python, num: usize) -> PyResult<()> {
-        let rt = Runtime::new().unwrap();
         let df = self.df.limit(num)?;
-        let batches = py.allow_threads(|| rt.block_on(df.collect()))?;
+        let batches = wait_for_future(py, df.collect())?;
         Ok(pretty::print_batches(&batches)?)
     }
 
@@ -124,7 +122,9 @@
             }
         };
 
-        let df = self.df.join(right.df, join_type, &join_keys.0, &join_keys.1)?;
+        let df = self
+            .df
+            .join(right.df, join_type, &join_keys.0, &join_keys.1)?;
         Ok(Self::new(df))
     }
 }
diff --git a/python/src/lib.rs b/python/src/lib.rs
index aab1695e2992..2bfeec06cf5d 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -25,6 +25,7 @@ mod expression;
 mod functions;
 mod udaf;
 mod udf;
+mod utils;
 
 // TODO(kszucs): remvoe
 // taken from https://github.com/PyO3/pyo3/issues/471
diff --git a/python/src/utils.rs b/python/src/utils.rs
new file mode 100644
index 000000000000..c2d924adfcea
--- /dev/null
+++ b/python/src/utils.rs
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use std::future::Future; + +use pyo3::prelude::*; +use tokio::runtime::Runtime; + +/// Utility to collect rust futures with GIL released +pub(crate) fn wait_for_future(py: Python, f: F) -> F::Output +where + F: Send, + F::Output: Send, +{ + let rt = Runtime::new().unwrap(); + py.allow_threads(|| rt.block_on(f)) +} From b96e87157a97b047398cf1da5137ef2582081400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 17:37:11 +0200 Subject: [PATCH 08/21] Improve module organization --- python/src/expression.rs | 20 +------------------- python/src/lib.rs | 32 ++++++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/python/src/expression.rs b/python/src/expression.rs index 63e9d7b1d665..a6405de66d40 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -20,9 +20,8 @@ use std::convert::{From, Into}; use std::vec::Vec; use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{col, lit, Expr}; +use datafusion::logical_plan::Expr; use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; -use datafusion::scalar::ScalarValue; /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expr")] @@ -156,20 +155,3 @@ impl PyAggregateUDF { Ok(self.function.call(args).into()) } } - -#[pyfunction] -fn literal(value: ScalarValue) -> PyExpr { - lit(value).into() -} - -#[pyfunction] -fn column(value: &str) -> PyExpr { - col(value).into() -} - -pub fn init(m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_wrapped(wrap_pyfunction!(literal))?; - m.add_wrapped(wrap_pyfunction!(column))?; - Ok(()) -} diff --git a/python/src/lib.rs b/python/src/lib.rs index 2bfeec06cf5d..c27b682f4a3b 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -17,6 +17,13 @@ use pyo3::prelude::*; +use datafusion::logical_plan; +use datafusion::scalar::ScalarValue; + +use context::PyExecutionContext; +use dataframe::PyDataFrame; +use expression::PyExpr; + mod catalog; mod context; mod dataframe; @@ -27,6 +34,21 @@ mod udaf; mod udf; mod utils; +// wrap_pyfunction!() doesn't work from other modules, so +// define the simple API functions here. See organizing +// modules section of pyo3: +// https://pyo3.rs/v0.14.5/module.html#organizing-your-module-registration-code + +#[pyfunction] +fn literal(value: ScalarValue) -> PyExpr { + logical_plan::lit(value).into() +} + +#[pyfunction] +fn column(value: &str) -> PyExpr { + logical_plan::col(value).into() +} + // TODO(kszucs): remvoe // taken from https://github.com/PyO3/pyo3/issues/471 // fn register_module_package(py: Python, package_name: &str, module: &PyModule) { @@ -44,8 +66,6 @@ mod utils; /// DataFusion. 
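+///
+/// With `literal` and `column` now registered at the top level, the updated
+/// tests read as follows (editor's sketch):
+///
+/// ```python
+/// from datafusion import column, literal
+///
+/// df = df.filter(column("a") > literal(2))
+/// ```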
#[pymodule] fn internals(py: Python, m: &PyModule) -> PyResult<()> { - expression::init(m)?; - //register_module_package(py, "datafusion.functions", functions); let functions = PyModule::new(py, "functions")?; @@ -55,8 +75,12 @@ fn internals(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + m.add_wrapped(wrap_pyfunction!(literal))?; + m.add_wrapped(wrap_pyfunction!(column))?; Ok(()) } From e512147eaccba5e26ce9373277b42f7a137fb927 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 17:48:30 +0200 Subject: [PATCH 09/21] Reorganize tests --- python/datafusion/functions.py | 0 .../tests/test_aggregation.py | 6 +- .../tests/test_context.py} | 0 python/datafusion/tests/test_dataframe.py | 2 +- ..._string_functions.py => test_functions.py} | 85 ++++++++++++++++++- .../datafusion/tests/test_math_functions.py | 71 ---------------- python/tests/test_functions.py | 63 -------------- 7 files changed, 86 insertions(+), 141 deletions(-) create mode 100644 python/datafusion/functions.py rename python/{ => datafusion}/tests/test_aggregation.py (94%) rename python/{tests/test_df_sql.py => datafusion/tests/test_context.py} (100%) rename python/datafusion/tests/{test_string_functions.py => test_functions.py} (59%) delete mode 100644 python/datafusion/tests/test_math_functions.py delete mode 100644 python/tests/test_functions.py diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py similarity index 94% rename from python/tests/test_aggregation.py rename to python/datafusion/tests/test_aggregation.py index f0996f9e06d9..8d3ec9a3854e 100644 --- a/python/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -17,7 +17,7 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext +from datafusion import ExecutionContext, column from datafusion import functions as f @@ -34,8 +34,8 @@ def df(): def test_built_in_aggregation(df): - col_a = f.col("a") - col_b = f.col("b") + col_a = column("a") + col_b = column("b") df = df.aggregate( [], [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], diff --git a/python/tests/test_df_sql.py b/python/datafusion/tests/test_context.py similarity index 100% rename from python/tests/test_df_sql.py rename to python/datafusion/tests/test_context.py diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 6e1ac7cfb8ee..af2b2e8db309 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -110,7 +110,7 @@ def test_join(): ) df1 = ctx.create_dataframe([[batch]]) - df = df.join(df1, on="a", how="inner") + df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") df = df.sort(column("a").sort(ascending=True)) table = pa.Table.from_batches(df.collect()) diff --git a/python/datafusion/tests/test_string_functions.py b/python/datafusion/tests/test_functions.py similarity index 59% rename from python/datafusion/tests/test_string_functions.py rename to python/datafusion/tests/test_functions.py index 4eef222768c4..345d46fc5b83 100644 --- a/python/datafusion/tests/test_string_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -15,26 +15,105 @@ # specific language governing 
permissions and limitations # under the License. +import numpy as np import pyarrow as pa import pytest -from datafusion import ExecutionContext, column, literal +from datafusion import ExecutionContext, literal, column from datafusion import functions as f @pytest.fixture def df(): ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], names=["a", "b"], ) - return ctx.create_dataframe([[batch]]) +def test_literal(df): + df = df.select( + literal(1), + literal("1"), + literal("OK"), + literal(3.14), + literal(True), + literal(b"hello world"), + ) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([1] * 3) + assert result.column(1) == pa.array(["1"] * 3) + assert result.column(2) == pa.array(["OK"] * 3) + assert result.column(3) == pa.array([3.14] * 3) + assert result.column(4) == pa.array([True] * 3) + assert result.column(5) == pa.array([b"hello world"] * 3) + + +def test_lit_arith(df): + """ + Test literals with arithmetic operations + """ + df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!"))) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([5, 6, 7]) + assert result.column(1) == pa.array(["Hello!", "World!", "!!"]) + + +def test_math_functions(): + ctx = ExecutionContext() + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([0.1, -0.7, 0.55])], names=["value"] + ) + df = ctx.create_dataframe([[batch]]) + + values = np.array([0.1, -0.7, 0.55]) + col_v = column("value") + df = df.select( + f.abs(col_v), + f.sin(col_v), + f.cos(col_v), + f.tan(col_v), + f.asin(col_v), + f.acos(col_v), + f.exp(col_v), + f.ln(col_v + literal(pa.scalar(1))), + f.log2(col_v + literal(pa.scalar(1))), + f.log10(col_v + literal(pa.scalar(1))), + f.random(), + ) + batches = df.collect() + assert len(batches) == 1 + result = batches[0] + + np.testing.assert_array_almost_equal(result.column(0), np.abs(values)) + np.testing.assert_array_almost_equal(result.column(1), np.sin(values)) + np.testing.assert_array_almost_equal(result.column(2), np.cos(values)) + np.testing.assert_array_almost_equal(result.column(3), np.tan(values)) + np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values)) + np.testing.assert_array_almost_equal(result.column(5), np.arccos(values)) + np.testing.assert_array_almost_equal(result.column(6), np.exp(values)) + np.testing.assert_array_almost_equal( + result.column(7), np.log(values + 1.0) + ) + np.testing.assert_array_almost_equal( + result.column(8), np.log2(values + 1.0) + ) + np.testing.assert_array_almost_equal( + result.column(9), np.log10(values + 1.0) + ) + np.testing.assert_array_less(result.column(10), np.ones_like(values)) + + + + def test_string_functions(df): df = df.select(f.md5(column("a")), f.lower(column("a"))) result = df.collect() diff --git a/python/datafusion/tests/test_math_functions.py b/python/datafusion/tests/test_math_functions.py deleted file mode 100644 index 5bd1cd223cc3..000000000000 --- a/python/datafusion/tests/test_math_functions.py +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import numpy as np -import pyarrow as pa -import pytest - -from datafusion import ExecutionContext, literal, column -from datafusion import functions as f - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array([0.1, -0.7, 0.55])], names=["value"] - ) - return ctx.create_dataframe([[batch]]) - - -def test_math_functions(df): - values = np.array([0.1, -0.7, 0.55]) - col_v = column("value") - df = df.select( - f.abs(col_v), - f.sin(col_v), - f.cos(col_v), - f.tan(col_v), - f.asin(col_v), - f.acos(col_v), - f.exp(col_v), - f.ln(col_v + literal(pa.scalar(1))), - f.log2(col_v + literal(pa.scalar(1))), - f.log10(col_v + literal(pa.scalar(1))), - f.random(), - ) - result = df.collect() - assert len(result) == 1 - result = result[0] - np.testing.assert_array_almost_equal(result.column(0), np.abs(values)) - np.testing.assert_array_almost_equal(result.column(1), np.sin(values)) - np.testing.assert_array_almost_equal(result.column(2), np.cos(values)) - np.testing.assert_array_almost_equal(result.column(3), np.tan(values)) - np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values)) - np.testing.assert_array_almost_equal(result.column(5), np.arccos(values)) - np.testing.assert_array_almost_equal(result.column(6), np.exp(values)) - np.testing.assert_array_almost_equal( - result.column(7), np.log(values + 1.0) - ) - np.testing.assert_array_almost_equal( - result.column(8), np.log2(values + 1.0) - ) - np.testing.assert_array_almost_equal( - result.column(9), np.log10(values + 1.0) - ) - np.testing.assert_array_less(result.column(10), np.ones_like(values)) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py deleted file mode 100644 index 67cf502c445e..000000000000 --- a/python/tests/test_functions.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import pyarrow as pa -import pytest -from datafusion import ExecutionContext -from datafusion import functions as f - - -@pytest.fixture -def df(): - ctx = ExecutionContext() - # create a RecordBatch and a new DataFrame from it - batch = pa.RecordBatch.from_arrays( - [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], - names=["a", "b"], - ) - return ctx.create_dataframe([[batch]]) - - -def test_lit(df): - """test lit function""" - df = df.select( - f.lit(1), - f.lit("1"), - f.lit("OK"), - f.lit(3.14), - f.lit(True), - f.lit(b"hello world"), - ) - result = df.collect() - assert len(result) == 1 - result = result[0] - assert result.column(0) == pa.array([1] * 3) - assert result.column(1) == pa.array(["1"] * 3) - assert result.column(2) == pa.array(["OK"] * 3) - assert result.column(3) == pa.array([3.14] * 3) - assert result.column(4) == pa.array([True] * 3) - assert result.column(5) == pa.array([b"hello world"] * 3) - - -def test_lit_arith(df): - """test lit function within arithmatics""" - df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!"))) - result = df.collect() - assert len(result) == 1 - result = result[0] - assert result.column(0) == pa.array([5, 6, 7]) - assert result.column(1) == pa.array(["Hello!", "World!", "!!"]) From ad184e8505b7231aba37a30779849eb73dfd86d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:27:32 +0200 Subject: [PATCH 10/21] Support Binary and LargeBinary arrays for ScalarValue::try_from_array --- datafusion/src/scalar.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 00586bf5549e..d5656f282121 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -1093,6 +1093,8 @@ impl ScalarValue { DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::LargeBinary => typed_cast!(array, index, LargeBinaryArray, LargeBinary), DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(nested_type) => { From ae52d17f002b654fcd4909007f9cf91f4fcaf2f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:28:56 +0200 Subject: [PATCH 11/21] Cargo format --- datafusion/src/scalar.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index d5656f282121..33bc9dd10486 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -1094,7 +1094,9 @@ impl ScalarValue { DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), - DataType::LargeBinary => typed_cast!(array, index, LargeBinaryArray, LargeBinary), + DataType::LargeBinary => { + typed_cast!(array, index, LargeBinaryArray, LargeBinary) + } DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(nested_type) => { From 84e7fc03876c60a2640b4d87cd6c0316cc7aaa9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:30:30 +0200 Subject: [PATCH 12/21] Define column and literal 
functions in python --- python/datafusion/__init__.py | 13 +++++++--- python/datafusion/functions.py | 5 ++++ python/datafusion/tests/test_functions.py | 6 ++--- python/src/expression.rs | 18 ++++++++++++- python/src/lib.rs | 31 +++-------------------- 5 files changed, 38 insertions(+), 35 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 7924b9f8d3fa..534cad927a2b 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -17,11 +17,18 @@ import pyarrow as pa -from .internals import * # noqa -from .internals import literal as _literal +from .internals import ( + DataFrame, + ExecutionContext, + Expression, +) + + +def column(value): + return Expression.column(value) def literal(value): if not isinstance(value, pa.Scalar): value = pa.scalar(value) - return _literal(value) + return Expression.literal(value) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index e69de29bb2d1..acc20814a825 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -0,0 +1,5 @@ +from .internals import functions + + +def __getattr__(name): + return getattr(functions, name) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 345d46fc5b83..17fcc625edea 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -58,7 +58,9 @@ def test_lit_arith(df): """ Test literals with arithmetic operations """ - df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!"))) + df = df.select( + literal(1) + column("b"), f.concat(column("a"), literal("!")) + ) result = df.collect() assert len(result) == 1 result = result[0] @@ -112,8 +114,6 @@ def test_math_functions(): np.testing.assert_array_less(result.column(10), np.ones_like(values)) - - def test_string_functions(df): df = df.select(f.md5(column("a")), f.lower(column("a"))) result = df.collect() diff --git a/python/src/expression.rs b/python/src/expression.rs index a6405de66d40..629c3e38155c 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -23,8 +23,13 @@ use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan::Expr; use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; + +use datafusion::logical_plan::{lit, col}; +use datafusion::scalar::ScalarValue; + + /// An PyExpr that can be used on a DataFrame -#[pyclass(name = "Expr")] +#[pyclass(name = "Expression")] #[derive(Debug, Clone)] pub(crate) struct PyExpr { pub(crate) expr: Expr, @@ -94,6 +99,17 @@ impl PyObjectProtocol for PyExpr { #[pymethods] impl PyExpr { + + #[staticmethod] + pub fn literal(value: ScalarValue) -> PyExpr { + lit(value).into() + } + + #[staticmethod] + pub fn column(value: &str) -> PyExpr { + col(value).into() + } + /// assign a name to the PyExpr pub fn alias(&self, name: &str) -> PyExpr { self.expr.clone().alias(name).into() diff --git a/python/src/lib.rs b/python/src/lib.rs index c27b682f4a3b..3ad88df461a2 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -17,13 +17,6 @@ use pyo3::prelude::*; -use datafusion::logical_plan; -use datafusion::scalar::ScalarValue; - -use context::PyExecutionContext; -use dataframe::PyDataFrame; -use expression::PyExpr; - mod catalog; mod context; mod dataframe; @@ -34,21 +27,6 @@ mod udaf; mod udf; mod utils; -// wrap_pyfunction!() doesn't work from other modules, so -// define the simple API functions here. 
See organizing -// modules section of pyo3: -// https://pyo3.rs/v0.14.5/module.html#organizing-your-module-registration-code - -#[pyfunction] -fn literal(value: ScalarValue) -> PyExpr { - logical_plan::lit(value).into() -} - -#[pyfunction] -fn column(value: &str) -> PyExpr { - logical_plan::col(value).into() -} - // TODO(kszucs): remvoe // taken from https://github.com/PyO3/pyo3/issues/471 // fn register_module_package(py: Python, package_name: &str, module: &PyModule) { @@ -75,12 +53,9 @@ fn internals(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - - m.add_wrapped(wrap_pyfunction!(literal))?; - m.add_wrapped(wrap_pyfunction!(column))?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } From 583153682d0990ca07c1966616284343842b10e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:31:14 +0200 Subject: [PATCH 13/21] Apply isort --- python/datafusion/__init__.py | 6 +----- python/datafusion/tests/test_aggregation.py | 1 + python/datafusion/tests/test_context.py | 1 + python/datafusion/tests/test_dataframe.py | 3 ++- python/datafusion/tests/test_functions.py | 3 ++- python/datafusion/tests/test_sql.py | 1 + 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 534cad927a2b..52c3dfaa0bb7 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -17,11 +17,7 @@ import pyarrow as pa -from .internals import ( - DataFrame, - ExecutionContext, - Expression, -) +from .internals import DataFrame, ExecutionContext, Expression def column(value): diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index 8d3ec9a3854e..d539c44585a6 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -17,6 +17,7 @@ import pyarrow as pa import pytest + from datafusion import ExecutionContext, column from datafusion import functions as f diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index c6eac6bb2ffc..60beea4a01be 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -17,6 +17,7 @@ import pyarrow as pa import pytest + from datafusion import ExecutionContext diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index af2b2e8db309..db246a61e3c2 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -18,8 +18,9 @@ import pyarrow as pa import pytest -from datafusion import DataFrame, ExecutionContext, literal, column +from datafusion import DataFrame, ExecutionContext, column from datafusion import functions as f +from datafusion import literal @pytest.fixture diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 17fcc625edea..08ddbaec345a 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -19,8 +19,9 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext, literal, column +from datafusion import ExecutionContext, column from datafusion import functions as f +from datafusion import literal @pytest.fixture diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index e9fb49c0e33d..89ebbb0bc3d9 100644 --- 
a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -21,6 +21,7 @@ from datafusion import ExecutionContext from datafusion import functions as f + from . import generic as helpers From 69e257d46cdd1348013308ec1637e4dbc9c4a891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:53:15 +0200 Subject: [PATCH 14/21] Add test for importing from datafusion.functions --- python/datafusion/__init__.py | 2 + python/datafusion/tests/test_functions.py | 46 +++++++++++++++++------ python/src/expression.rs | 5 +-- python/src/functions.rs | 2 +- python/src/lib.rs | 31 +++++---------- 5 files changed, 48 insertions(+), 38 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 52c3dfaa0bb7..724898da9d75 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -19,6 +19,8 @@ from .internals import DataFrame, ExecutionContext, Expression +__all__ = ["DataFrame", "ExecutionContext", "Expression", "column", "literal"] + def column(value): return Expression.column(value) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 08ddbaec345a..6b63e9bb1645 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -35,6 +35,10 @@ def df(): return ctx.create_dataframe([[batch]]) +def test_import_from_submodule(): + from datafusion.functions import abs, sin # noqa + + def test_literal(df): df = df.select( literal(1), @@ -150,52 +154,70 @@ def test_hash_functions(df): assert result.column(1) == pa.array( [ b( - "185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969" + "185F8DB32271FE25F561A6FC938B2E26" + "4306EC304EDA518007D1764826381969" ), b( - "78AE647DC5544D227130A0682A51E30BC7777FBB6D8A8F17007463A3ECD1D524" + "78AE647DC5544D227130A0682A51E30B" + "C7777FBB6D8A8F17007463A3ECD1D524" ), b( - "BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62" + "BB7208BC9B5D7C04F1236A82A0093A5E" + "33F40423D5BA8D4266F7092C3BA43B62" ), ] ) assert result.column(2) == pa.array( [ b( - "3615F80C9D293ED7402687F94B22D58E529B8CC7916F8FAC7FDDF7FBD5AF4CF777D3D795A7A00A16BF7E7F3FB9561EE9BAAE480DA9FE7A18769E71886B03F315" + "3615F80C9D293ED7402687F94B22D58E" + "529B8CC7916F8FAC7FDDF7FBD5AF4CF7" + "77D3D795A7A00A16BF7E7F3FB9561EE9" + "BAAE480DA9FE7A18769E71886B03F315" ), b( - "8EA77393A42AB8FA92500FB077A9509CC32BC95E72712EFA116EDAF2EDFAE34FBB682EFDD6C5DD13C117E08BD4AAEF71291D8AACE2F890273081D0677C16DF0F" + "8EA77393A42AB8FA92500FB077A9509C" + "C32BC95E72712EFA116EDAF2EDFAE34F" + "BB682EFDD6C5DD13C117E08BD4AAEF71" + "291D8AACE2F890273081D0677C16DF0F" ), b( - "3831A6A6155E509DEE59A7F451EB35324D8F8F2DF6E3708894740F98FDEE23889F4DE5ADB0C5010DFB555CDA77C8AB5DC902094C52DE3278F35A75EBC25F093A" + "3831A6A6155E509DEE59A7F451EB3532" + "4D8F8F2DF6E3708894740F98FDEE2388" + "9F4DE5ADB0C5010DFB555CDA77C8AB5D" + "C902094C52DE3278F35A75EBC25F093A" ), ] ) assert result.column(3) == pa.array( [ b( - "F73A5FBF881F89B814871F46E26AD3FA37CB2921C5E8561618639015B3CCBB71" + "F73A5FBF881F89B814871F46E26AD3FA" + "37CB2921C5E8561618639015B3CCBB71" ), b( - "B792A0383FB9E7A189EC150686579532854E44B71AC394831DAED169BA85CCC5" + "B792A0383FB9E7A189EC150686579532" + "854E44B71AC394831DAED169BA85CCC5" ), b( - "27988A0E51812297C77A433F635233346AEE29A829DCF4F46E0F58F402C6CFCB" + "27988A0E51812297C77A433F63523334" + "6AEE29A829DCF4F46E0F58F402C6CFCB" ), ] ) assert result.column(4) == pa.array( [ b( - 
"FBC2B0516EE8744D293B980779178A3508850FDCFE965985782C39601B65794F" + "FBC2B0516EE8744D293B980779178A35" + "08850FDCFE965985782C39601B65794F" ), b( - "BF73D18575A736E4037D45F9E316085B86C19BE6363DE6AA789E13DEAACC1C4E" + "BF73D18575A736E4037D45F9E316085B" + "86C19BE6363DE6AA789E13DEAACC1C4E" ), b( - "C8D11B9F7237E4034ADBCD2005735F9BC4C597C75AD89F4492BEC8F77D15F7EB" + "C8D11B9F7237E4034ADBCD2005735F9B" + "C4C597C75AD89F4492BEC8F77D15F7EB" ), ] ) diff --git a/python/src/expression.rs b/python/src/expression.rs index 629c3e38155c..fd4b407db938 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -23,11 +23,9 @@ use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan::Expr; use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; - -use datafusion::logical_plan::{lit, col}; +use datafusion::logical_plan::{col, lit}; use datafusion::scalar::ScalarValue; - /// An PyExpr that can be used on a DataFrame #[pyclass(name = "Expression")] #[derive(Debug, Clone)] @@ -99,7 +97,6 @@ impl PyObjectProtocol for PyExpr { #[pymethods] impl PyExpr { - #[staticmethod] pub fn literal(value: ScalarValue) -> PyExpr { lit(value).into() diff --git a/python/src/functions.rs b/python/src/functions.rs index 9529094a546a..b075e34a4073 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -306,7 +306,7 @@ fn udaf( }) } -pub fn init(m: &PyModule) -> PyResult<()> { +pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { // TODO(kszucs): implement FromPyObject to PyVolatility m.add_class::()?; m.add_wrapped(wrap_pyfunction!(abs))?; diff --git a/python/src/lib.rs b/python/src/lib.rs index 3ad88df461a2..ff33d94f0851 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -27,29 +27,13 @@ mod udaf; mod udf; mod utils; -// TODO(kszucs): remvoe -// taken from https://github.com/PyO3/pyo3/issues/471 -// fn register_module_package(py: Python, package_name: &str, module: &PyModule) { -// py.import("sys") -// .expect("failed to import python sys module") -// .dict() -// .get_item("modules") -// .expect("failed to get python modules dictionary") -// .downcast::() -// .expect("failed to turn sys.modules into a PyDict") -// .set_item(package_name, module) -// .expect("failed to inject module"); -// } - -/// DataFusion. +/// Low-level DataFusion internal package. +/// +/// The higher-level public API is defined in pure python files under the +/// datafusion directory. 
#[pymodule] fn internals(py: Python, m: &PyModule) -> PyResult<()> { - //register_module_package(py, "datafusion.functions", functions); - - let functions = PyModule::new(py, "functions")?; - functions::init(functions)?; - m.add_submodule(functions)?; - + // Register the python classes m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -57,5 +41,10 @@ fn internals(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + // Register the functions as a submodule + let funcs = PyModule::new(py, "functions")?; + functions::init_module(funcs)?; + m.add_submodule(funcs)?; + Ok(()) } From 0e1901de619ff8e4333a9c8c16227de8fcd472e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 18:56:53 +0200 Subject: [PATCH 15/21] Rename internals to _internal --- python/Cargo.toml | 4 ++-- python/datafusion/__init__.py | 2 +- python/datafusion/functions.py | 20 +++++++++++++++++++- python/src/lib.rs | 2 +- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/Cargo.toml b/python/Cargo.toml index d3f13b498916..3d3ebfa34540 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -35,11 +35,11 @@ datafusion = { path = "../datafusion", version = "5.1.0", features = ["pyarrow"] uuid = { version = "0.8", features = ["v4"] } [lib] -name = "internals" +name = "_internal" crate-type = ["cdylib"] [package.metadata.maturin] -name = "datafusion.internals" +name = "datafusion._internal" [profile.release] lto = true diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 724898da9d75..e6421b95addc 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -17,7 +17,7 @@ import pyarrow as pa -from .internals import DataFrame, ExecutionContext, Expression +from ._internal import DataFrame, ExecutionContext, Expression __all__ = ["DataFrame", "ExecutionContext", "Expression", "column", "literal"] diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index acc20814a825..782ecba22191 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1,4 +1,22 @@ -from .internals import functions +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from ._internal import functions def __getattr__(name): diff --git a/python/src/lib.rs b/python/src/lib.rs index ff33d94f0851..d9d5993d4b4d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -32,7 +32,7 @@ mod utils; /// The higher-level public API is defined in pure python files under the /// datafusion directory. 
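+///
+/// After the rename, the compiled extension is reachable as (sketch):
+///
+/// ```python
+/// from datafusion._internal import ExecutionContext
+/// ```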
#[pymodule] -fn internals(py: Python, m: &PyModule) -> PyResult<()> { +fn _internal(py: Python, m: &PyModule) -> PyResult<()> { // Register the python classes m.add_class::()?; m.add_class::()?; From 6bedd79bdf91976ed7dcdae1d3d10125260617ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 19:13:01 +0200 Subject: [PATCH 16/21] Make classes inheritable; add tests for imports; set module --- python/datafusion/tests/test_functions.py | 4 -- python/datafusion/tests/test_imports.py | 52 +++++++++++++++++++++++ python/src/context.rs | 2 +- python/src/dataframe.rs | 2 +- python/src/expression.rs | 2 +- 5 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 python/datafusion/tests/test_imports.py diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 6b63e9bb1645..84718eaf0ce6 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -35,10 +35,6 @@ def df(): return ctx.create_dataframe([[batch]]) -def test_import_from_submodule(): - from datafusion.functions import abs, sin # noqa - - def test_literal(df): df = df.select( literal(1), diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py new file mode 100644 index 000000000000..0accb392d673 --- /dev/null +++ b/python/datafusion/tests/test_imports.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import datafusion +from datafusion import DataFrame, ExecutionContext, Expression, functions + + +def test_import_datafusion(): + assert datafusion.__name__ == "datafusion" + + +def test_class_module_is_datafusion(): + for klass in [ExecutionContext, Expression, DataFrame]: + assert klass.__module__ == "datafusion" + + +def test_import_from_functions_submodule(): + from datafusion.functions import abs, sin # noqa + + assert functions.abs is abs + assert functions.sin is sin + + msg = "cannot import name 'foobar' from 'datafusion.functions'" + with pytest.raises(ImportError, match=msg): + from datafusion.functions import foobar # noqa + + +def test_classes_are_inheritable(): + class MyExecContext(ExecutionContext): + pass + + class MyExpression(Expression): + pass + + class MyDataFrame(DataFrame): + pass diff --git a/python/src/context.rs b/python/src/context.rs index 8e6ab4a66637..a8e963cfffef 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -38,7 +38,7 @@ use crate::utils::wait_for_future; /// `PyExecutionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a /// multi-threaded execution engine to perform the execution. 
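+///
+/// Because the classes are now `subclass`-able, the pattern exercised by
+/// test_imports.py works (sketch):
+///
+/// ```python
+/// class MyExecContext(ExecutionContext):
+///     pass
+/// ```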
-#[pyclass(name = "ExecutionContext", unsendable)] +#[pyclass(name = "ExecutionContext", module = "datafusion", subclass, unsendable)] pub(crate) struct PyExecutionContext { ctx: ExecutionContext, } diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 8b1c8f0e3e45..9050df92ed26 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -31,7 +31,7 @@ use crate::{errors::DataFusionError, expression::PyExpr}; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. -#[pyclass(name = "DataFrame")] +#[pyclass(name = "DataFrame", module = "datafusion", subclass)] #[derive(Clone)] pub(crate) struct PyDataFrame { df: Arc, diff --git a/python/src/expression.rs b/python/src/expression.rs index fd4b407db938..cb1e3bd783b7 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -27,7 +27,7 @@ use datafusion::logical_plan::{col, lit}; use datafusion::scalar::ScalarValue; /// An PyExpr that can be used on a DataFrame -#[pyclass(name = "Expression")] +#[pyclass(name = "Expression", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub(crate) struct PyExpr { pub(crate) expr: Expr, From a2ef207aea6f445bc4d4415c094ada727216b6d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 20:41:24 +0200 Subject: [PATCH 17/21] Remove PyVolatility --- python/datafusion/tests/test_dataframe.py | 2 +- python/datafusion/tests/test_sql.py | 3 +- python/datafusion/tests/test_udaf.py | 4 +- python/src/context.rs | 4 +- python/src/expression.rs | 2 + python/src/functions.rs | 63 ++++++++++------------- 6 files changed, 36 insertions(+), 42 deletions(-) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index db246a61e3c2..4ed221ec81f9 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -87,7 +87,7 @@ def test_udf(df): lambda x: x.is_null(), [pa.int64()], pa.bool_(), - f.Volatility.immutable(), + volatility="immutable", ) df = df.select(udf(column("a"))) diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index 89ebbb0bc3d9..d845f7565d91 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -20,7 +20,6 @@ import pytest from datafusion import ExecutionContext -from datafusion import functions as f from . 
import generic as helpers @@ -199,7 +198,7 @@ def test_udf( ) ctx.register_parquet("t", path) ctx.register_udf( - "udf", fn, input_types, output_type, f.Volatility.immutable() + "udf", fn, input_types, output_type, volatility="immutable" ) batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect() diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 70de702535c4..2fbc7bde14b2 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -68,7 +68,7 @@ def test_aggregate(df): pa.float64(), pa.float64(), [pa.float64()], - f.Volatility.immutable(), + volatility="immutable", ) df = df.aggregate([], [udaf(column("a"))]) @@ -85,7 +85,7 @@ def test_group_by(df): pa.float64(), pa.float64(), [pa.float64()], - f.Volatility.immutable(), + volatility="immutable", ) df = df.aggregate([column("b")], [udaf(column("a"))]) diff --git a/python/src/context.rs b/python/src/context.rs index a8e963cfffef..78e9c1e920c3 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -32,7 +32,7 @@ use datafusion::prelude::CsvReadOptions; use crate::catalog::PyCatalog; use crate::dataframe::PyDataFrame; use crate::errors::DataFusionError; -use crate::functions::{create_udf, PyVolatility}; +use crate::functions::create_udf; use crate::utils::wait_for_future; /// `PyExecutionContext` is able to plan and execute DataFusion plans. @@ -149,7 +149,7 @@ impl PyExecutionContext { func: PyObject, args_types: Vec, return_type: DataType, - volatility: PyVolatility, + volatility: &str, ) -> PyResult<()> { let function = create_udf(func, args_types, return_type, volatility, name)?; self.ctx.register_udf(function.function); diff --git a/python/src/expression.rs b/python/src/expression.rs index cb1e3bd783b7..2c8b2735d09f 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -142,6 +142,8 @@ pub struct PyScalarUDF { #[pymethods] impl PyScalarUDF { + // ADD NEW() + /// creates a new PyExpr with the call of the udf #[call] #[args(args = "*")] diff --git a/python/src/functions.rs b/python/src/functions.rs index b075e34a4073..1537a9f7858e 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -28,6 +28,7 @@ use datafusion::physical_plan::{ }; use crate::{ + errors::DataFusionError, expression::{PyAggregateUDF, PyExpr, PyScalarUDF}, udaf, udf, }; @@ -221,47 +222,30 @@ aggregate_function!(min, Min); aggregate_function!(sum, Sum); aggregate_function!(approx_distinct, ApproxDistinct); -#[pyclass(name = "Volatility", module = "datafusion.functions")] -#[derive(Clone)] -pub struct PyVolatility { - pub(crate) volatility: Volatility, -} - -#[pymethods] -impl PyVolatility { - #[staticmethod] - fn immutable() -> Self { - Self { - volatility: Volatility::Immutable, - } - } - #[staticmethod] - fn stable() -> Self { - Self { - volatility: Volatility::Stable, - } - } - #[staticmethod] - fn volatile() -> Self { - Self { - volatility: Volatility::Volatile, - } - } -} - pub(crate) fn create_udf( fun: PyObject, input_types: Vec, return_type: DataType, - volatility: PyVolatility, + volatility: &str, name: &str, ) -> PyResult { + let volatility = match volatility { + "immutable" => Volatility::Immutable, + "stable" => Volatility::Stable, + "volatile" => Volatility::Volatile, + value => { + return Err(DataFusionError::Common(format!( + "Unsupportad volatility type: `{}`, supported values are: immutable, stable and volatile.", + value + )).into()) + } + }; Ok(PyScalarUDF { function: logical_plan::create_udf( name, input_types, 
Arc::new(return_type), - volatility.volatility, + volatility, udf::array_udf(fun), ), }) @@ -273,7 +257,7 @@ fn udf( fun: PyObject, input_types: Vec, return_type: DataType, - volatility: PyVolatility, + volatility: &str, py: Python, ) -> PyResult { let name = fun.getattr(py, "__qualname__")?.extract::(py)?; @@ -287,19 +271,30 @@ fn udaf( input_type: DataType, return_type: DataType, state_type: Vec, - volatility: PyVolatility, + volatility: &str, py: Python, ) -> PyResult { let name = accumulator .getattr(py, "__qualname__")? .extract::(py)?; + let volatility = match volatility { + "immutable" => Volatility::Immutable, + "stable" => Volatility::Stable, + "volatile" => Volatility::Volatile, + value => { + return Err(DataFusionError::Common( + format!("Unsupportad volatility type: `{}`, supported values are: immutable, stable and volatile.", value) + ).into()) + } + }; + Ok(PyAggregateUDF { function: logical_plan::create_udaf( &name, input_type, Arc::new(return_type), - volatility.volatility, + volatility, udaf::array_udaf(accumulator), Arc::new(state_type), ), @@ -307,8 +302,6 @@ fn udaf( } pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { - // TODO(kszucs): implement FromPyObject to PyVolatility - m.add_class::()?; m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; m.add_wrapped(wrap_pyfunction!(approx_distinct))?; From c7f384a3f928536a2174af771ef97d6bc5ac936c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 21:08:34 +0200 Subject: [PATCH 18/21] Move ScalarUdf to udf.rs --- python/datafusion/__init__.py | 6 ++++ python/src/context.rs | 16 +++------- python/src/expression.rs | 26 ++------------- python/src/functions.rs | 51 +++-------------------------- python/src/udf.rs | 60 ++++++++++++++++++++++++++++++++++- 5 files changed, 75 insertions(+), 84 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index e6421b95addc..09f64e8ce4e9 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -30,3 +30,9 @@ def literal(value): if not isinstance(value, pa.Scalar): value = pa.scalar(value) return Expression.literal(value) + + +# def udf(): +# """Create a new User Defined Function""" +# let name = fun.getattr(py, "__qualname__")?.extract::(py)?; +# create_udf(fun, input_types, return_type, volatility, &name) diff --git a/python/src/context.rs b/python/src/context.rs index 78e9c1e920c3..7f386bac398d 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -23,7 +23,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use datafusion::arrow::datatypes::{DataType, Schema}; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; @@ -32,7 +32,7 @@ use datafusion::prelude::CsvReadOptions; use crate::catalog::PyCatalog; use crate::dataframe::PyDataFrame; use crate::errors::DataFusionError; -use crate::functions::create_udf; +use crate::udf::PyScalarUDF; use crate::utils::wait_for_future; /// `PyExecutionContext` is able to plan and execute DataFusion plans. 
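+///
+/// UDF registration sketch after this change (illustration only; the
+/// pure-Python `udf()` wrapper above is still a TODO, so `udf` here stands
+/// for a ScalarUDF handle constructed on the Rust side):
+///
+/// ```python
+/// ctx.register_udf(udf)
+/// ```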
@@ -143,16 +143,8 @@ impl PyExecutionContext { Ok(()) } - fn register_udf( - &mut self, - name: &str, - func: PyObject, - args_types: Vec, - return_type: DataType, - volatility: &str, - ) -> PyResult<()> { - let function = create_udf(func, args_types, return_type, volatility, name)?; - self.ctx.register_udf(function.function); + fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> { + self.ctx.register_udf(udf.function); Ok(()) } diff --git a/python/src/expression.rs b/python/src/expression.rs index 2c8b2735d09f..e1ef4d665cab 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -20,10 +20,8 @@ use std::convert::{From, Into}; use std::vec::Vec; use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::Expr; -use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF}; - -use datafusion::logical_plan::{col, lit}; +use datafusion::logical_plan::{col, lit, Expr}; +use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::scalar::ScalarValue; /// An PyExpr that can be used on a DataFrame @@ -133,26 +131,6 @@ impl PyExpr { } } -/// Represents a PyScalarUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct PyScalarUDF { - pub(crate) function: ScalarUDF, -} - -#[pymethods] -impl PyScalarUDF { - // ADD NEW() - - /// creates a new PyExpr with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: Vec) -> PyResult { - let args = args.iter().map(|e| e.expr.clone()).collect(); - Ok(self.function.call(args).into()) - } -} - /// Represents a AggregateUDF #[pyclass] #[derive(Debug, Clone)] diff --git a/python/src/functions.rs b/python/src/functions.rs index 1537a9f7858e..d6056a2e3011 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -21,7 +21,7 @@ use pyo3::{prelude::*, wrap_pyfunction, Python}; use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan; -//use datafusion::logical_plan::Expr; + use datafusion::physical_plan::functions::Volatility; use datafusion::physical_plan::{ aggregates::AggregateFunction, functions::BuiltinScalarFunction, @@ -29,8 +29,8 @@ use datafusion::physical_plan::{ use crate::{ errors::DataFusionError, - expression::{PyAggregateUDF, PyExpr, PyScalarUDF}, - udaf, udf, + expression::{PyAggregateUDF, PyExpr}, + udaf, }; #[pyfunction] @@ -222,49 +222,7 @@ aggregate_function!(min, Min); aggregate_function!(sum, Sum); aggregate_function!(approx_distinct, ApproxDistinct); -pub(crate) fn create_udf( - fun: PyObject, - input_types: Vec, - return_type: DataType, - volatility: &str, - name: &str, -) -> PyResult { - let volatility = match volatility { - "immutable" => Volatility::Immutable, - "stable" => Volatility::Stable, - "volatile" => Volatility::Volatile, - value => { - return Err(DataFusionError::Common(format!( - "Unsupportad volatility type: `{}`, supported values are: immutable, stable and volatile.", - value - )).into()) - } - }; - Ok(PyScalarUDF { - function: logical_plan::create_udf( - name, - input_types, - Arc::new(return_type), - volatility, - udf::array_udf(fun), - ), - }) -} - -/// Creates a new UDF (User Defined Function). -#[pyfunction] -fn udf( - fun: PyObject, - input_types: Vec, - return_type: DataType, - volatility: &str, - py: Python, -) -> PyResult { - let name = fun.getattr(py, "__qualname__")?.extract::(py)?; - create_udf(fun, input_types, return_type, volatility, &name) -} - -/// Creates a new UDAF (User Defined Aggregate Function). +/// Creates a new udf. 
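+/// More precisely, this constructs a UDAF (user-defined aggregate
+/// function) from a Python accumulator class; the scalar-UDF
+/// constructor now lives in udf.rs as PyScalarUDF.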
#[pyfunction] fn udaf( accumulator: PyObject, @@ -365,7 +323,6 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; m.add_wrapped(wrap_pyfunction!(udaf))?; - m.add_wrapped(wrap_pyfunction!(udf))?; m.add_wrapped(wrap_pyfunction!(upper))?; Ok(()) } diff --git a/python/src/udf.rs b/python/src/udf.rs index fa77e4ab3257..2178f793a89d 100644 --- a/python/src/udf.rs +++ b/python/src/udf.rs @@ -15,12 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use pyo3::{prelude::*, types::PyTuple}; + use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::DataFusionError; +use datafusion::logical_plan; use datafusion::physical_plan::functions::ScalarFunctionImplementation; +use datafusion::physical_plan::{ + functions::Volatility, udaf::AggregateUDF, udf::ScalarUDF, +}; use datafusion::{arrow::array, physical_plan::functions::make_scalar_function}; -use pyo3::{prelude::*, types::PyTuple}; + +use crate::errors; +use crate::expression::PyExpr; /// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays /// This is more efficient as it performs a zero-copy of the contents. @@ -52,3 +63,50 @@ pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation { }, ) } + +/// Represents a PyScalarUDF +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyScalarUDF { + pub(crate) function: ScalarUDF, +} + +#[pymethods] +impl PyScalarUDF { + //#args(args) + fn new( + fun: PyObject, + input_types: Vec, + return_type: DataType, + volatility: &str, + name: &str, + ) -> PyResult { + let volatility = match volatility { + "immutable" => Volatility::Immutable, + "stable" => Volatility::Stable, + "volatile" => Volatility::Volatile, + value => { + return Err(errors::DataFusionError::Common(format!( + "Unsupportad volatility type: `{}`, supported values are: immutable, stable and volatile.", + value + )).into()) + } + }; + let function = logical_plan::create_udf( + name, + input_types, + Arc::new(return_type), + volatility, + array_udf(fun), + ); + Ok(PyScalarUDF { function }) + } + + /// creates a new PyExpr with the call of the udf + #[call] + #[args(args = "*")] + fn __call__(&self, args: Vec) -> PyResult { + let args = args.iter().map(|e| e.expr.clone()).collect(); + Ok(self.function.call(args).into()) + } +} From 837c1ba3693ed41d8d9babae208ffec680a5b5f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 29 Oct 2021 22:02:09 +0200 Subject: [PATCH 19/21] Factor out PyScalarUDF and PyAggregateUDF --- python/datafusion/__init__.py | 6 +++ python/src/expression.rs | 18 --------- python/src/functions.rs | 45 +--------------------- python/src/udaf.rs | 71 +++++++++++++++++++++++++++-------- python/src/udf.rs | 44 ++++++++-------------- python/src/utils.rs | 19 ++++++++++ 6 files changed, 98 insertions(+), 105 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 09f64e8ce4e9..20f058d7af84 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -36,3 +36,9 @@ def literal(value): # """Create a new User Defined Function""" # let name = fun.getattr(py, "__qualname__")?.extract::(py)?; # create_udf(fun, input_types, return_type, volatility, &name) + + +# udaf(): +# // let name = accumulator +# // .getattr(py, "__qualname__")? 
+# // .extract::(py)?; diff --git a/python/src/expression.rs b/python/src/expression.rs index e1ef4d665cab..28a19c67c0db 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -130,21 +130,3 @@ impl PyExpr { expr.into() } } - -/// Represents a AggregateUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct PyAggregateUDF { - pub(crate) function: AggregateUDF, -} - -#[pymethods] -impl PyAggregateUDF { - /// creates a new PyExpr with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: Vec) -> PyResult { - let args = args.iter().map(|e| e.expr.clone()).collect(); - Ok(self.function.call(args).into()) - } -} diff --git a/python/src/functions.rs b/python/src/functions.rs index d6056a2e3011..3cd971fb4cde 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -27,11 +27,8 @@ use datafusion::physical_plan::{ aggregates::AggregateFunction, functions::BuiltinScalarFunction, }; -use crate::{ - errors::DataFusionError, - expression::{PyAggregateUDF, PyExpr}, - udaf, -}; +use crate::errors::DataFusionError; +use crate::expression::PyExpr; #[pyfunction] fn array(value: Vec) -> PyExpr { @@ -222,43 +219,6 @@ aggregate_function!(min, Min); aggregate_function!(sum, Sum); aggregate_function!(approx_distinct, ApproxDistinct); -/// Creates a new udf. -#[pyfunction] -fn udaf( - accumulator: PyObject, - input_type: DataType, - return_type: DataType, - state_type: Vec, - volatility: &str, - py: Python, -) -> PyResult { - let name = accumulator - .getattr(py, "__qualname__")? - .extract::(py)?; - - let volatility = match volatility { - "immutable" => Volatility::Immutable, - "stable" => Volatility::Stable, - "volatile" => Volatility::Volatile, - value => { - return Err(DataFusionError::Common( - format!("Unsupportad volatility type: `{}`, supported values are: immutable, stable and volatile.", value) - ).into()) - } - }; - - Ok(PyAggregateUDF { - function: logical_plan::create_udaf( - &name, - input_type, - Arc::new(return_type), - volatility, - udaf::array_udaf(accumulator), - Arc::new(state_type), - ), - }) -} - pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(abs))?; m.add_wrapped(wrap_pyfunction!(acos))?; @@ -322,7 +282,6 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(translate))?; m.add_wrapped(wrap_pyfunction!(trim))?; m.add_wrapped(wrap_pyfunction!(trunc))?; - m.add_wrapped(wrap_pyfunction!(udaf))?; m.add_wrapped(wrap_pyfunction!(upper))?; Ok(()) } diff --git a/python/src/udaf.rs b/python/src/udaf.rs index 756afe68c31e..a6b9edf312dd 100644 --- a/python/src/udaf.rs +++ b/python/src/udaf.rs @@ -20,13 +20,16 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::PyArrowConvert; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_plan; +use datafusion::physical_plan::udaf::AggregateUDF; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; -use datafusion::error::Result; -use datafusion::{ - error::DataFusionError as InnerDataFusionError, physical_plan::Accumulator, - scalar::ScalarValue, -}; +use crate::expression::PyExpr; +use crate::utils::parse_volatility; #[derive(Debug)] struct PyAccumulator { @@ -42,7 +45,7 @@ impl PyAccumulator { impl Accumulator for PyAccumulator { fn state(&self) -> Result> { Python::with_gil(|py| self.accum.as_ref(py).call_method0("to_scalars")?.extract()) - 
.map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) + .map_err(|e| DataFusionError::Execution(format!("{}", e))) } fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { @@ -57,7 +60,7 @@ impl Accumulator for PyAccumulator { fn evaluate(&self) -> Result { Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) + .map_err(|e| DataFusionError::Execution(format!("{}", e))) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -76,7 +79,7 @@ impl Accumulator for PyAccumulator { self.accum .as_ref(py) .call_method1("update", py_args) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; Ok(()) }) @@ -84,34 +87,72 @@ impl Accumulator for PyAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { - // 1. cast states to Pyarrow array - // 2. merge let state = &states[0]; + // 1. cast states to Pyarrow array let state = state .to_pyarrow(py) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - // 2. + // 2. merge self.accum .as_ref(py) .call_method1("merge", (state,)) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; Ok(()) }) } } -pub fn array_udaf( +pub fn to_rust_accumulator( accumulator: PyObject, ) -> Arc Result> + Send + Sync> { Arc::new(move || -> Result> { let accumulator = Python::with_gil(|py| { accumulator .call0(py) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) + .map_err(|e| DataFusionError::Execution(format!("{}", e))) })?; Ok(Box::new(PyAccumulator::new(accumulator))) }) } + +/// Represents a AggregateUDF +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyAggregateUDF { + pub(crate) function: AggregateUDF, +} + +#[pymethods] +impl PyAggregateUDF { + #[new] + fn new( + name: &str, + accumulator: PyObject, + input_type: DataType, + return_type: DataType, + state_type: Vec, + volatility: &str, + ) -> PyResult { + Ok(Self { + function: logical_plan::create_udaf( + &name, + input_type, + Arc::new(return_type), + parse_volatility(volatility)?, + to_rust_accumulator(accumulator), + Arc::new(state_type), + ), + }) + } + + /// creates a new PyExpr with the call of the udf + #[call] + #[args(args = "*")] + fn __call__(&self, args: Vec) -> PyResult { + let args = args.iter().map(|e| e.expr.clone()).collect(); + Ok(self.function.call(args).into()) + } +} diff --git a/python/src/udf.rs b/python/src/udf.rs index 2178f793a89d..df2888c086a7 100644 --- a/python/src/udf.rs +++ b/python/src/udf.rs @@ -24,39 +24,36 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::DataFusionError; use datafusion::logical_plan; -use datafusion::physical_plan::functions::ScalarFunctionImplementation; -use datafusion::physical_plan::{ - functions::Volatility, udaf::AggregateUDF, udf::ScalarUDF, +use datafusion::physical_plan::functions::{ + make_scalar_function, ScalarFunctionImplementation, Volatility, }; -use datafusion::{arrow::array, physical_plan::functions::make_scalar_function}; +use datafusion::physical_plan::udf::ScalarUDF; -use crate::errors; use crate::expression::PyExpr; +use crate::utils::parse_volatility; -/// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays -/// This is more 
efficient as it performs a zero-copy of the contents.
-pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
+/// Create a DataFusion UDF implementation from a Python function
+/// that expects pyarrow arrays. This is more efficient as it performs
+/// a zero-copy of the contents.
+fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
     make_scalar_function(
-        move |args: &[array::ArrayRef]| -> Result<ArrayRef, DataFusionError> {
+        move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
             Python::with_gil(|py| {
                 // 1. cast args to Pyarrow arrays
-                // 2. call function
-                // 3. cast to arrow::array::Array
-
-                // 1.
                 let py_args = args
                     .iter()
                     .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap())
                     .collect::<Vec<_>>();
                 let py_args = PyTuple::new(py, py_args);
 
-                // 2.
+                // 2. call function
                 let value = func.as_ref(py).call(py_args, None);
                 let value = match value {
                     Ok(n) => Ok(n),
                     Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
                 }?;
 
+                // 3. cast to arrow::array::Array
                 let array = ArrayRef::from_pyarrow(value).unwrap();
                 Ok(array)
             })
@@ -73,31 +70,20 @@ pub struct PyScalarUDF {
 
 #[pymethods]
 impl PyScalarUDF {
-    //#args(args)
+    #[new]
     fn new(
+        name: &str,
         fun: PyObject,
         input_types: Vec<DataType>,
         return_type: DataType,
         volatility: &str,
-        name: &str,
     ) -> PyResult<Self> {
-        let volatility = match volatility {
-            "immutable" => Volatility::Immutable,
-            "stable" => Volatility::Stable,
-            "volatile" => Volatility::Volatile,
-            value => {
-                return Err(errors::DataFusionError::Common(format!(
-                    "Unsupported volatility type: `{}`, supported values are: immutable, stable and volatile.",
-                    value
-                )).into())
-            }
-        };
         let function = logical_plan::create_udf(
             name,
             input_types,
             Arc::new(return_type),
-            volatility,
-            array_udf(fun),
+            parse_volatility(volatility)?,
+            to_rust_function(fun),
         );
         Ok(PyScalarUDF { function })
     }
diff --git a/python/src/utils.rs b/python/src/utils.rs
index c2d924adfcea..c8e1c63b1d0f 100644
--- a/python/src/utils.rs
+++ b/python/src/utils.rs
@@ -20,6 +20,10 @@ use std::future::Future;
 use pyo3::prelude::*;
 use tokio::runtime::Runtime;
 
+use datafusion::physical_plan::functions::Volatility;
+
+use crate::errors::DataFusionError;
+
 /// Utility to collect rust futures with GIL released
 pub(crate) fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
 where
@@ -29,3 +33,18 @@ where
     let rt = Runtime::new().unwrap();
     py.allow_threads(|| rt.block_on(f))
 }
+
+pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
+    Ok(match value {
+        "immutable" => Volatility::Immutable,
+        "stable" => Volatility::Stable,
+        "volatile" => Volatility::Volatile,
+        value => {
+            return Err(DataFusionError::Common(format!(
+                "Unsupported volatility type: `{}`, supported \
+                values are: immutable, stable and volatile.",
+                value
+            )))
+        }
+    })
+}
From b8047723c8db948926666ffbd3bc1377b504ac7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?=
Date: Fri, 29 Oct 2021 23:23:44 +0200
Subject: [PATCH 20/21] Refactor UDF and UDAF construction

---
 python/datafusion/__init__.py             | 80 ++++++++++++++++++++---
 python/datafusion/tests/test_dataframe.py |  8 +--
 python/datafusion/tests/test_imports.py   | 17 ++++-
 python/datafusion/tests/test_sql.py       | 10 +--
 python/datafusion/tests/test_udaf.py      | 59 ++++++++++++++---
 python/src/expression.rs                  |  3 +-
 python/src/functions.rs                   |  7 +-
 python/src/lib.rs                         |  2 +
 python/src/udaf.rs                        | 49 +++++++-------
 python/src/udf.rs                         | 12 ++--
 10 files changed, 176 insertions(+), 71 deletions(-)

diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 20f058d7af84..4f9082e7e402 100644
---
a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,11 +15,46 @@ # specific language governing permissions and limitations # under the License. +from abc import ABCMeta, abstractmethod +from typing import List + import pyarrow as pa -from ._internal import DataFrame, ExecutionContext, Expression +from ._internal import ( + AggregateUDF, + DataFrame, + ExecutionContext, + Expression, + ScalarUDF, +) + +__all__ = [ + "DataFrame", + "ExecutionContext", + "Expression", + "AggregateUDF", + "ScalarUDF", + "column", + "literal", +] + + +class Accumulator(metaclass=ABCMeta): + @abstractmethod + def state(self) -> List[pa.Scalar]: + pass + + @abstractmethod + def update(self, values: pa.Array) -> None: + pass + + @abstractmethod + def merge(self, states: pa.Array) -> None: + pass -__all__ = ["DataFrame", "ExecutionContext", "Expression", "column", "literal"] + @abstractmethod + def evaluate(self) -> pa.Scalar: + pass def column(value): @@ -32,13 +67,38 @@ def literal(value): return Expression.literal(value) -# def udf(): -# """Create a new User Defined Function""" -# let name = fun.getattr(py, "__qualname__")?.extract::(py)?; -# create_udf(fun, input_types, return_type, volatility, &name) +def udf(func, input_types, return_type, volatility, name=None): + """ + Create a new User Defined Function + """ + if not callable(func): + raise TypeError("`func` argument must be callable") + if name is None: + name = func.__qualname__ + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) -# udaf(): -# // let name = accumulator -# // .getattr(py, "__qualname__")? -# // .extract::(py)?; +def udaf(accum, input_type, return_type, state_type, volatility, name=None): + """ + Create a new User Defined Aggregate Function + """ + if not issubclass(accum, Accumulator): + raise TypeError( + "`accum` must implement the abstract base class Accumulator" + ) + if name is None: + name = accum.__qualname__ + return AggregateUDF( + name=name, + accumulator=accum, + input_type=input_type, + return_type=return_type, + state_type=state_type, + volatility=volatility, + ) diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 4ed221ec81f9..0eb970a69e83 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -18,9 +18,7 @@ import pyarrow as pa import pytest -from datafusion import DataFrame, ExecutionContext, column -from datafusion import functions as f -from datafusion import literal +from datafusion import DataFrame, ExecutionContext, column, literal, udf @pytest.fixture @@ -83,14 +81,14 @@ def test_limit(df): def test_udf(df): # is_null is a pa function over arrays - udf = f.udf( + is_null = udf( lambda x: x.is_null(), [pa.int64()], pa.bool_(), volatility="immutable", ) - df = df.select(udf(column("a"))) + df = df.select(is_null(column("a"))) result = df.collect()[0].column(0) assert result == pa.array([False, False, False]) diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py index 0accb392d673..423800248a5c 100644 --- a/python/datafusion/tests/test_imports.py +++ b/python/datafusion/tests/test_imports.py @@ -18,7 +18,14 @@ import pytest import datafusion -from datafusion import DataFrame, ExecutionContext, Expression, functions +from datafusion import ( + AggregateUDF, + DataFrame, + ExecutionContext, + Expression, + ScalarUDF, + functions, +) def test_import_datafusion(): @@ -26,7 +33,13 @@ 
def test_import_datafusion(): def test_class_module_is_datafusion(): - for klass in [ExecutionContext, Expression, DataFrame]: + for klass in [ + ExecutionContext, + Expression, + DataFrame, + ScalarUDF, + AggregateUDF, + ]: assert klass.__module__ == "datafusion" diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index d845f7565d91..23f20079f0da 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -19,7 +19,7 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext +from datafusion import ExecutionContext, udf from . import generic as helpers @@ -197,11 +197,13 @@ def test_udf( tmp_path / "a.parquet", pa.array(input_values) ) ctx.register_parquet("t", path) - ctx.register_udf( - "udf", fn, input_types, output_type, volatility="immutable" + + func = udf( + fn, input_types, output_type, name="func", volatility="immutable" ) + ctx.register_udf(func) - batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect() + batches = ctx.sql("SELECT func(a) AS tt FROM t").collect() result = batches[0].column(0) assert result == pa.array(expected_values) diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 2fbc7bde14b2..2f286ba105dd 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -21,11 +21,10 @@ import pyarrow.compute as pc import pytest -from datafusion import ExecutionContext, column -from datafusion import functions as f +from datafusion import Accumulator, ExecutionContext, column, udaf -class Accumulator: +class Summarize(Accumulator): """ Interface of a user-defined accumulation. """ @@ -33,7 +32,7 @@ class Accumulator: def __init__(self): self._sum = pa.scalar(0.0) - def to_scalars(self) -> List[pa.Scalar]: + def state(self) -> List[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: @@ -50,6 +49,18 @@ def evaluate(self) -> pa.Scalar: return self._sum +class NotSubclassOfAccumulator: + pass + + +class MissingMethods(Accumulator): + def __init__(self): + self._sum = pa.scalar(0) + + def state(self) -> List[pa.Scalar]: + return [self._sum] + + @pytest.fixture def df(): ctx = ExecutionContext() @@ -62,16 +73,43 @@ def df(): return ctx.create_dataframe([[batch]]) +def test_errors(df): + with pytest.raises(TypeError): + udaf( + NotSubclassOfAccumulator, + pa.float64(), + pa.float64(), + [pa.float64()], + volatility="immutable", + ) + + accum = udaf( + MissingMethods, + pa.int64(), + pa.int64(), + [pa.int64()], + volatility="immutable", + ) + df = df.aggregate([], [accum(column("a"))]) + + msg = ( + "Can't instantiate abstract class MissingMethods with abstract " + "methods evaluate, merge, update" + ) + with pytest.raises(Exception, match=msg): + df.collect() + + def test_aggregate(df): - udaf = f.udaf( - Accumulator, + summarize = udaf( + Summarize, pa.float64(), pa.float64(), [pa.float64()], volatility="immutable", ) - df = df.aggregate([], [udaf(column("a"))]) + df = df.aggregate([], [summarize(column("a"))]) # execute and collect the first (and only) batch result = df.collect()[0] @@ -80,17 +118,18 @@ def test_aggregate(df): def test_group_by(df): - udaf = f.udaf( - Accumulator, + summarize = udaf( + Summarize, pa.float64(), pa.float64(), [pa.float64()], volatility="immutable", ) - df = df.aggregate([column("b")], [udaf(column("a"))]) + df = df.aggregate([column("b")], [summarize(column("a"))]) batches = df.collect() + arrays = [batch.column(1) for batch in batches] joined = 
pa.concat_arrays(arrays) assert joined == pa.array([1.0 + 2.0, 3.0]) diff --git a/python/src/expression.rs b/python/src/expression.rs index 28a19c67c0db..21cecaa1ccce 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -17,11 +17,10 @@ use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol}; use std::convert::{From, Into}; -use std::vec::Vec; use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan::{col, lit, Expr}; -use datafusion::physical_plan::udaf::AggregateUDF; + use datafusion::scalar::ScalarValue; /// An PyExpr that can be used on a DataFrame diff --git a/python/src/functions.rs b/python/src/functions.rs index 3cd971fb4cde..a2862202602f 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -15,19 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use pyo3::{prelude::*, wrap_pyfunction}; -use pyo3::{prelude::*, wrap_pyfunction, Python}; - -use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan; -use datafusion::physical_plan::functions::Volatility; use datafusion::physical_plan::{ aggregates::AggregateFunction, functions::BuiltinScalarFunction, }; -use crate::errors::DataFusionError; use crate::expression::PyExpr; #[pyfunction] diff --git a/python/src/lib.rs b/python/src/lib.rs index d9d5993d4b4d..d40bae251c86 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -40,6 +40,8 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; // Register the functions as a submodule let funcs = PyModule::new(py, "functions")?; diff --git a/python/src/udaf.rs b/python/src/udaf.rs index a6b9edf312dd..294a32d7f062 100644 --- a/python/src/udaf.rs +++ b/python/src/udaf.rs @@ -26,25 +26,26 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan; use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::Accumulator; +use datafusion::physical_plan::aggregates::AccumulatorFunctionImplementation; use datafusion::scalar::ScalarValue; use crate::expression::PyExpr; use crate::utils::parse_volatility; #[derive(Debug)] -struct PyAccumulator { +struct RustAccumulator { accum: PyObject, } -impl PyAccumulator { +impl RustAccumulator { fn new(accum: PyObject) -> Self { Self { accum } } } -impl Accumulator for PyAccumulator { +impl Accumulator for RustAccumulator { fn state(&self) -> Result> { - Python::with_gil(|py| self.accum.as_ref(py).call_method0("to_scalars")?.extract()) + Python::with_gil(|py| self.accum.as_ref(py).call_method0("state")?.extract()) .map_err(|e| DataFusionError::Execution(format!("{}", e))) } @@ -66,16 +67,13 @@ impl Accumulator for PyAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { // 1. cast args to Pyarrow array - // 2. call function - - // 1. let py_args = values .iter() .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args); - // update accumulator + // 2. call function self.accum .as_ref(py) .call_method1("update", py_args) @@ -94,7 +92,7 @@ impl Accumulator for PyAccumulator { .to_pyarrow(py) .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - // 2. merge + // 2. 
call merge
             self.accum
                 .as_ref(py)
                 .call_method1("merge", (state,))
@@ -106,20 +104,20 @@
 }
 
 pub fn to_rust_accumulator(
-    accumulator: PyObject,
-) -> Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync> {
+    accum: PyObject,
+) -> AccumulatorFunctionImplementation {
     Arc::new(move || -> Result<Box<dyn Accumulator>> {
-        let accumulator = Python::with_gil(|py| {
-            accumulator
+        let accum = Python::with_gil(|py| {
+            accum
                 .call0(py)
                 .map_err(|e| DataFusionError::Execution(format!("{}", e)))
         })?;
-        Ok(Box::new(PyAccumulator::new(accumulator)))
+        Ok(Box::new(RustAccumulator::new(accum)))
     })
 }
 
 /// Represents a AggregateUDF
-#[pyclass]
+#[pyclass(name = "AggregateUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
 pub struct PyAggregateUDF {
     pub(crate) function: AggregateUDF,
@@ -127,7 +125,7 @@ pub struct PyAggregateUDF {
 
 #[pymethods]
 impl PyAggregateUDF {
     #[new]
     fn new(
         name: &str,
         accumulator: PyObject,
@@ -136,16 +134,15 @@ impl PyAggregateUDF {
         state_type: Vec<DataType>,
         volatility: &str,
     ) -> PyResult<Self> {
-        Ok(Self {
-            function: logical_plan::create_udaf(
-                &name,
-                input_type,
-                Arc::new(return_type),
-                parse_volatility(volatility)?,
-                to_rust_accumulator(accumulator),
-                Arc::new(state_type),
-            ),
-        })
+        let function = logical_plan::create_udaf(
+            &name,
+            input_type,
+            Arc::new(return_type),
+            parse_volatility(volatility)?,
+            to_rust_accumulator(accumulator),
+            Arc::new(state_type),
+        );
+        Ok(Self { function })
     }
 
     /// creates a new PyExpr with the call of the udf
diff --git a/python/src/udf.rs b/python/src/udf.rs
index df2888c086a7..379c449870b2 100644
--- a/python/src/udf.rs
+++ b/python/src/udf.rs
@@ -25,7 +25,7 @@ use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::error::DataFusionError;
 use datafusion::logical_plan;
 use datafusion::physical_plan::functions::{
-    make_scalar_function, ScalarFunctionImplementation, Volatility,
+    make_scalar_function, ScalarFunctionImplementation,
 };
 use datafusion::physical_plan::udf::ScalarUDF;
 
@@ -62,7 +62,7 @@ fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
 }
 
 /// Represents a PyScalarUDF
-#[pyclass]
+#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
 pub struct PyScalarUDF {
     pub(crate) function: ScalarUDF,
@@ -70,10 +70,10 @@ pub struct PyScalarUDF {
 
 #[pymethods]
 impl PyScalarUDF {
     #[new]
     fn new(
         name: &str,
-        fun: PyObject,
+        func: PyObject,
         input_types: Vec<DataType>,
         return_type: DataType,
         volatility: &str,
@@ -83,9 +83,9 @@ impl PyScalarUDF {
             input_types,
             Arc::new(return_type),
             parse_volatility(volatility)?,
-            to_rust_function(fun),
+            to_rust_function(func),
         );
-        Ok(PyScalarUDF { function })
+        Ok(Self { function })
     }
 
     /// creates a new PyExpr with the call of the udf
From 35ad333080e3e5aca4370dfb874beaa0840756f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?=
Date: Fri, 29 Oct 2021 23:32:53 +0200
Subject: [PATCH 21/21] Set public as the default database for the catalog

---
 python/datafusion/tests/test_catalog.py | 4 ++--
 python/src/catalog.rs                   | 7 ++++---
 python/src/udaf.rs                      | 6 ++----
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/python/datafusion/tests/test_catalog.py b/python/datafusion/tests/test_catalog.py
index 5ae81d5521e1..2e64a810a718 100644
--- a/python/datafusion/tests/test_catalog.py
+++ b/python/datafusion/tests/test_catalog.py
@@ -58,8 +58,8 @@ def test_basic(ctx, database):
     default = ctx.catalog()
     assert
default.names() == ["public"] - database = default.database("public") - assert database.names() == {"csv1", "csv", "csv2"} + for database in [default.database("public"), default.database()]: + assert database.names() == {"csv1", "csv", "csv2"} table = database.table("csv") assert table.kind == "physical" diff --git a/python/src/catalog.rs b/python/src/catalog.rs index 826ac7827ca6..f93c795ec34c 100644 --- a/python/src/catalog.rs +++ b/python/src/catalog.rs @@ -27,17 +27,17 @@ use datafusion::{ datasource::{TableProvider, TableType}, }; -#[pyclass(name = "Catalog", subclass)] +#[pyclass(name = "Catalog", module = "datafusion", subclass)] pub(crate) struct PyCatalog { catalog: Arc, } -#[pyclass(name = "Database", subclass)] +#[pyclass(name = "Database", module = "datafusion", subclass)] pub(crate) struct PyDatabase { database: Arc, } -#[pyclass(name = "Table", subclass)] +#[pyclass(name = "Table", module = "datafusion", subclass)] pub(crate) struct PyTable { table: Arc, } @@ -66,6 +66,7 @@ impl PyCatalog { self.catalog.schema_names() } + #[args(name = "\"public\"")] fn database(&self, name: &str) -> PyResult { match self.catalog.schema(name) { Some(database) => Ok(PyDatabase::new(database)), diff --git a/python/src/udaf.rs b/python/src/udaf.rs index 294a32d7f062..1de6e63205ed 100644 --- a/python/src/udaf.rs +++ b/python/src/udaf.rs @@ -24,9 +24,9 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::PyArrowConvert; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan; +use datafusion::physical_plan::aggregates::AccumulatorFunctionImplementation; use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::Accumulator; -use datafusion::physical_plan::aggregates::AccumulatorFunctionImplementation; use datafusion::scalar::ScalarValue; use crate::expression::PyExpr; @@ -103,9 +103,7 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator( - accum: PyObject, -) -> AccumulatorFunctionImplementation { +pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFunctionImplementation { Arc::new(move || -> Result> { let accum = Python::with_gil(|py| { accum
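
Taken together, the public Python API that this series converges on is
exercised as in the sketch below; Summarize is the accumulator class
defined in test_udaf.py above, and every call shown appears in the test
suite:

    import pyarrow as pa
    from datafusion import ExecutionContext, column, udaf, udf

    ctx = ExecutionContext()
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1.0, 2.0, 3.0]), pa.array([4.0, 4.0, 6.0])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]])

    # scalar UDF: the wrapped callable receives and returns pyarrow arrays
    is_null = udf(
        lambda arr: arr.is_null(),
        [pa.float64()],
        pa.bool_(),
        volatility="immutable",
    )

    # aggregate UDF: Summarize implements state/update/merge/evaluate
    summarize = udaf(
        Summarize,
        pa.float64(),
        pa.float64(),
        [pa.float64()],
        volatility="immutable",
    )

    df.select(is_null(column("a"))).collect()
    df.aggregate([], [summarize(column("a"))]).collect()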