
DataFrame count method #5071

Merged — 5 commits — Jan 26, 2023
datafusion-examples/README.md (1 addition, 1 deletion)

@@ -29,7 +29,7 @@ Run `git submodule update --init` to init test files.

- [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file
- [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file
- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queris against a custom datasource (TableProvider)
- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider)
- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file
- [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory
- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde
datafusion/core/src/dataframe.rs (42 additions, 0 deletions)

@@ -20,7 +20,9 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::Int64Array;
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use parquet::file::properties::WriterProperties;

use datafusion_common::{Column, DFSchema, ScalarValue};
@@ -361,6 +363,39 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, plan))
}

/// Run a count aggregate on the DataFrame and execute it, returning the total
/// number of rows as a `usize`. Note that this consumes the DataFrame in order
/// to execute it.
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let count = df.count().await?;
/// # Ok(())
/// # }
/// ```
pub async fn count(self) -> Result<usize> {
let rows = self
.aggregate(
vec![],
vec![datafusion_expr::count(Expr::Literal(ScalarValue::Null))],
Reviewer comment (Contributor): This was actually working by accident, as a result of a quirk of `NullArray` whereby it doesn't have a null buffer despite all of its values being null. The count reported by this query should be 0, as only non-null values are counted.

Fix in https://github.com/apache/arrow-datafusion/pull/5612/files#diff-932cfd7271917561280a69edabb35cfd109d22ae736f77e85adcf63455918121R630
)?
.collect()
.await?;
let len = *rows
.first()
.and_then(|r| r.columns().first())
.and_then(|c| c.as_any().downcast_ref::<Int64Array>())
.and_then(|a| a.values().first())
.ok_or(DataFusionError::Internal(
"Unexpected output when collecting for count()".to_string(),
))? as usize;
Ok(len)
}
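The `and_then` chain above drills into the first column of the first batch and surfaces an internal error if any step yields nothing. The same `Option`-chaining pattern can be sketched self-contained, using hypothetical stand-in types (`Batch`, `Int64Column`) in place of arrow's `RecordBatch` and `Int64Array`:

```rust
// Hypothetical stand-ins for arrow's RecordBatch / Int64Array, used only to
// illustrate the Option-chaining extraction inside DataFrame::count.
struct Int64Column {
    values: Vec<i64>,
}

struct Batch {
    columns: Vec<Int64Column>,
}

// Pull the single count value out of the collected batches: first batch,
// first column, first value. Any missing step produces an error instead of
// a panic, mirroring the DataFusionError::Internal path above.
fn extract_count(rows: &[Batch]) -> Result<usize, String> {
    rows.first()
        .and_then(|r| r.columns.first())
        .and_then(|c| c.values.first())
        .map(|&v| v as usize)
        .ok_or_else(|| "Unexpected output when collecting for count()".to_string())
}

fn main() {
    let rows = vec![Batch {
        columns: vec![Int64Column { values: vec![100] }],
    }];
    assert_eq!(extract_count(&rows), Ok(100));
    // An empty result set is an error, not a count of zero.
    assert!(extract_count(&[]).is_err());
    println!("ok");
}
```

A count aggregate with no GROUP BY always produces exactly one row and one column, which is why indexing only the first value is sound here.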

/// Convert the logical plan represented by this DataFrame into a physical plan and
/// execute it, collecting all resulting batches into memory
/// Executes this DataFrame and collects all results into a vector of RecordBatch.
@@ -1001,6 +1036,13 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn df_count() -> Result<()> {
let count = test_table().await?.count().await?;
assert_eq!(100, count);
Ok(())
}

#[tokio::test]
async fn explain() -> Result<()> {
// build query using Table API
docs/source/user-guide/dataframe.md (1 addition, 0 deletions)

@@ -86,6 +86,7 @@ These methods execute the logical plan represented by the DataFrame and either c
| -------------------------- | --------------------------------------------------------------------------------------------------------------------------- |
| collect | Executes this DataFrame and collects all results into a vector of RecordBatch. |
| collect_partitioned | Executes this DataFrame and collects all results into a vector of vector of RecordBatch maintaining the input partitioning. |
| count | Executes this DataFrame to get the total number of rows. |
| execute_stream | Executes this DataFrame and returns a stream over a single partition. |
| execute_stream_partitioned | Executes this DataFrame and returns one stream per partition. |
| show | Execute this DataFrame and print the results to stdout. |