From 8fbaa3aba0437128a03826997db515f89be43ce1 Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Fri, 3 Sep 2021 22:03:47 -0700 Subject: [PATCH] update datafusion to 5.1.0 for python binding --- python/Cargo.toml | 5 +++-- python/src/dataframe.rs | 10 +++++++--- python/tests/test_df.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/Cargo.toml b/python/Cargo.toml index 8dba538ae0c7..c2a2957be94b 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "datafusion" +name = "datafusion-python" version = "0.3.0" homepage = "https://github.com/apache/arrow" repository = "https://github.com/apache/arrow" @@ -31,7 +31,8 @@ libc = "0.2" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" } +datafusion = { path = "../datafusion", version = "5.1.0" } +proc-macro2 = { version = "=1.0.28" } [lib] name = "datafusion" diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs index 8e5657ba2f8a..0885ae367a8e 100644 --- a/python/src/dataframe.rs +++ b/python/src/dataframe.rs @@ -161,9 +161,13 @@ impl DataFrame { Ok(pretty::print_batches(&batches).unwrap()) } - /// Returns the join of two DataFrames `on`. 
- fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> { + fn join( + &self, + right: &DataFrame, + join_keys: (Vec<&str>, Vec<&str>), + how: &str, + ) -> PyResult<Self> { let builder = LogicalPlanBuilder::from(self.plan.clone()); let join_type = match how { @@ -182,7 +186,7 @@ impl DataFrame { } }; - let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?; + let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?; let plan = errors::wrap(builder.build())?; diff --git a/python/tests/test_df.py b/python/tests/test_df.py index 5b6cbddbd74b..0b19da48160e 100644 --- a/python/tests/test_df.py +++ b/python/tests/test_df.py @@ -104,7 +104,7 @@ def test_join(): ) df1 = ctx.create_dataframe([[batch]]) - df = df.join(df1, on="a", how="inner") + df = df.join(df1, on=("a", "a"), how="inner") df = df.sort([f.col("a").sort(ascending=True)]) table = pa.Table.from_batches(df.collect())