Skip to content

Commit 774ea70

Browse files
authored
Implement to_pandas() (#197)
* Implement to_pandas() * Update documentation * Write unit test
1 parent bb004ee commit 774ea70

File tree

4 files changed

+34
-17
lines changed

4 files changed

+34
-17
lines changed

README.md

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ from having to lock the GIL when running those operations.
3636
Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which makes strong assumptions
3737
about thread safety and lack of memory leaks.
3838

39-
There is also experimental support for executing SQL against other DataFrame libraries, such as Polars, Pandas, and any
39+
There is also experimental support for executing SQL against other DataFrame libraries, such as Polars, Pandas, and any
4040
drop-in replacements for Pandas.
4141

4242
Technically, zero-copy is achieved via the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html).
@@ -70,17 +70,11 @@ df = ctx.sql("select passenger_count, count(*) "
7070
"group by passenger_count "
7171
"order by passenger_count")
7272

73-
# collect as list of pyarrow.RecordBatch
74-
results = df.collect()
75-
76-
# get first batch
77-
batch = results[0]
78-
7973
# convert to Pandas
80-
df = batch.to_pandas()
74+
pandas_df = df.to_pandas()
8175

8276
# create a chart
83-
fig = df.plot(kind="bar", title="Trip Count by Number of Passengers").get_figure()
77+
fig = pandas_df.plot(kind="bar", title="Trip Count by Number of Passengers").get_figure()
8478
fig.savefig('chart.png')
8579
```
8680

datafusion/tests/test_dataframe.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,14 @@ def test_cache(df):
533533
def test_count(df):
534534
# Get number of rows
535535
assert df.count() == 3
536+
537+
538+
def test_to_pandas(df):
    """Check that ``df.to_pandas()`` returns a pandas DataFrame of the expected shape.

    ``df`` is supplied by the caller (presumably a pytest fixture defined
    elsewhere in this test module). The assertions expect a (3, 3) frame
    with columns ``a``, ``b`` and ``c``.
    """
    # Skip test if pandas is not installed
    pd = pytest.importorskip("pandas")

    # Convert datafusion dataframe to pandas dataframe
    pandas_df = df.to_pandas()

    # isinstance() is the idiomatic type check; comparing type() objects
    # directly is flagged by linters (E721).
    assert isinstance(pandas_df, pd.DataFrame)
    assert pandas_df.shape == (3, 3)
    assert set(pandas_df.columns) == {"a", "b", "c"}

examples/sql-to-pandas.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,11 @@
3333
"order by passenger_count"
3434
)
3535

36-
# collect as list of pyarrow.RecordBatch
37-
results = df.collect()
38-
39-
# get first batch
40-
batch = results[0]
41-
4236
# convert to Pandas
43-
df = batch.to_pandas()
37+
pandas_df = df.to_pandas()
4438

4539
# create a chart
46-
fig = df.plot(
40+
fig = pandas_df.plot(
4741
kind="bar", title="Trip Count by Number of Passengers"
4842
).get_figure()
4943
fig.savefig("chart.png")

src/dataframe.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,24 @@ impl PyDataFrame {
313313
Ok(())
314314
}
315315

316+
/// Convert to pandas dataframe with pyarrow
317+
/// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
318+
fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
319+
let batches = self.collect(py);
320+
321+
Python::with_gil(|py| {
322+
// Instantiate pyarrow Table object and use its from_batches method
323+
let table_class = py.import("pyarrow")?.getattr("Table")?;
324+
let args = PyTuple::new(py, batches);
325+
let table: PyObject = table_class.call_method1("from_batches", args)?.into();
326+
327+
// Use Table.to_pandas() method to convert batches to pandas dataframe
328+
// See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
329+
let result = table.call_method0(py, "to_pandas")?;
330+
Ok(result)
331+
})
332+
}
333+
316334
// Executes this DataFrame to get the total number of rows.
317335
fn count(&self, py: Python) -> PyResult<usize> {
318336
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)

0 commit comments

Comments
 (0)