Skip to content

Commit

Permalink
update datafusion to 5.1.0 for python binding (#967)
Browse files Browse the repository at this point in the history
* update datafusion to 5.1.0 for python binding
  • Loading branch information
QP Hou authored Sep 8, 2021
1 parent 50cce1a commit bb616bf
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
8 changes: 6 additions & 2 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

[package]
name = "datafusion"
name = "datafusion-python"
version = "0.3.0"
homepage = "https://github.com/apache/arrow"
repository = "https://github.com/apache/arrow"
Expand All @@ -31,7 +31,11 @@ libc = "0.2"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.7"
pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
datafusion = { path = "../datafusion", version = "5.1.0" }
# workaround for a bug introduced in
# https://github.com/dtolnay/proc-macro2/pull/286
# TODO: remove this version pin after upstream releases a fix
proc-macro2 = { version = "=1.0.28" }

[lib]
name = "datafusion"
Expand Down
10 changes: 7 additions & 3 deletions python/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,13 @@ impl DataFrame {
Ok(pretty::print_batches(&batches).unwrap())
}


/// Returns the join of two DataFrames `on`.
fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
fn join(
&self,
right: &DataFrame,
join_keys: (Vec<&str>, Vec<&str>),
how: &str,
) -> PyResult<Self> {
let builder = LogicalPlanBuilder::from(self.plan.clone());

let join_type = match how {
Expand All @@ -182,7 +186,7 @@ impl DataFrame {
}
};

let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?;

let plan = errors::wrap(builder.build())?;

Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_join():
)
df1 = ctx.create_dataframe([[batch]])

df = df.join(df1, on="a", how="inner")
df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
df = df.sort([f.col("a").sort(ascending=True)])
table = pa.Table.from_batches(df.collect())

Expand Down
10 changes: 5 additions & 5 deletions python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_register_csv(ctx, tmp_path):
for table in ["csv", "csv1", "csv2"]:
result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(int)": [4]}
assert result.to_pydict() == {f"COUNT({table}.int)": [4]}

result = ctx.sql("SELECT * FROM csv3").collect()
result = pa.Table.from_batches(result)
Expand All @@ -88,7 +88,7 @@ def test_register_parquet(ctx, tmp_path):

result = ctx.sql("SELECT COUNT(a) FROM t").collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"COUNT(a)": [100]}
assert result.to_pydict() == {"COUNT(t.a)": [100]}


def test_execute(ctx, tmp_path):
Expand Down Expand Up @@ -123,8 +123,8 @@ def test_execute(ctx, tmp_path):
result_values = []
for result in results:
pydict = result.to_pydict()
result_keys.extend(pydict["CAST(a AS Int32)"])
result_values.extend(pydict["COUNT(a)"])
result_keys.extend(pydict["CAST(t.a AS Int32)"])
result_values.extend(pydict["COUNT(t.a)"])

result_keys, result_values = (
list(t) for t in zip(*sorted(zip(result_keys, result_values)))
Expand All @@ -141,7 +141,7 @@ def test_execute(ctx, tmp_path):
expected_cast = pa.array([50, 50], pa.int32())
expected = [
pa.RecordBatch.from_arrays(
[expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
[expected_a, expected_cast], ["a", "CAST(t.a AS Int32)"]
)
]
np.testing.assert_equal(expected[0].column(1), expected[0].column(1))
Expand Down

0 comments on commit bb616bf

Please sign in to comment.