diff --git a/README.md b/README.md index d465ebcd0..65f6ef3e0 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ from having to lock the GIL when running those operations. Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which makes strong assumptions about thread safety and lack of memory leaks. -There is also experimental support for executing SQL against other DataFrame libraries, such as Polars, Pandas, and any +There is also experimental support for executing SQL against other DataFrame libraries, such as Polars, Pandas, and any drop-in replacements for Pandas. Technically, zero-copy is achieved via the [c data interface](https://arrow.apache.org/docs/format/CDataInterface.html). @@ -70,17 +70,11 @@ df = ctx.sql("select passenger_count, count(*) " "group by passenger_count " "order by passenger_count") -# collect as list of pyarrow.RecordBatch -results = df.collect() - -# get first batch -batch = results[0] - # convert to Pandas -df = batch.to_pandas() +pandas_df = df.to_pandas() # create a chart -fig = df.plot(kind="bar", title="Trip Count by Number of Passengers").get_figure() +fig = pandas_df.plot(kind="bar", title="Trip Count by Number of Passengers").get_figure() fig.savefig('chart.png') ``` diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 18946888f..292a4b00c 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -533,3 +533,14 @@ def test_cache(df): def test_count(df): # Get number of rows assert df.count() == 3 + + +def test_to_pandas(df): + # Skip test if pandas is not installed + pd = pytest.importorskip("pandas") + + # Convert datafusion dataframe to pandas dataframe + pandas_df = df.to_pandas() + assert type(pandas_df) == pd.DataFrame + assert pandas_df.shape == (3, 3) + assert set(pandas_df.columns) == {"a", "b", "c"} diff --git a/examples/sql-to-pandas.py b/examples/sql-to-pandas.py index 3569e6d8c..3e99b22de 100644 --- a/examples/sql-to-pandas.py +++ b/examples/sql-to-pandas.py @@ -33,17 +33,11 @@ "order by passenger_count" ) -# collect as list of pyarrow.RecordBatch -results = df.collect() - -# get first batch -batch = results[0] - # convert to Pandas -df = batch.to_pandas() +pandas_df = df.to_pandas() # create a chart -fig = df.plot( +fig = pandas_df.plot( kind="bar", title="Trip Count by Number of Passengers" ).get_figure() fig.savefig("chart.png") diff --git a/src/dataframe.rs b/src/dataframe.rs index 4b9fbca6c..a1c68dd1c 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -313,6 +313,24 @@ impl PyDataFrame { Ok(()) } + /// Convert to pandas dataframe with pyarrow + /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame + fn to_pandas(&self, py: Python) -> PyResult { + let batches = self.collect(py); + + Python::with_gil(|py| { + // Instantiate pyarrow Table object and use its from_batches method + let table_class = py.import("pyarrow")?.getattr("Table")?; + let args = PyTuple::new(py, batches); + let table: PyObject = table_class.call_method1("from_batches", args)?.into(); + + // Use Table.to_pandas() method to convert batches to pandas dataframe + // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas + let result = table.call_method0(py, "to_pandas")?; + Ok(result) + }) + } + // Executes this DataFrame to get the total number of rows. fn count(&self, py: Python) -> PyResult { Ok(wait_for_future(py, self.df.as_ref().clone().count())?)