diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 9ad5a49bb428..6d39ceaf059d 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -29,7 +29,7 @@ Run `git submodule update --init` to init test files. - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file -- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queris against a custom datasource (TableProvider) +- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index e9773dbdf372..f8ee54ca0837 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -20,7 +20,9 @@ use std::any::Any; use std::sync::Arc; +use arrow::array::Int64Array; use async_trait::async_trait; +use datafusion_common::DataFusionError; use parquet::file::properties::WriterProperties; use datafusion_common::{Column, DFSchema, ScalarValue}; @@ -361,6 +363,39 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } + /// Run a count aggregate on the DataFrame and execute the DataFrame to collect this + /// count and return it as a usize, to find the total number of rows after executing + /// the DataFrame. + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let count = df.count().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn count(self) -> Result { + let rows = self + .aggregate( + vec![], + vec![datafusion_expr::count(Expr::Literal(ScalarValue::Null))], + )? + .collect() + .await?; + let len = *rows + .first() + .and_then(|r| r.columns().first()) + .and_then(|c| c.as_any().downcast_ref::()) + .and_then(|a| a.values().first()) + .ok_or(DataFusionError::Internal( + "Unexpected output when collecting for count()".to_string(), + ))? as usize; + Ok(len) + } + /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, collecting all resulting batches into memory /// Executes this DataFrame and collects all results into a vector of RecordBatch. @@ -1001,6 +1036,13 @@ mod tests { Ok(()) } + #[tokio::test] + async fn df_count() -> Result<()> { + let count = test_table().await?.count().await?; + assert_eq!(100, count); + Ok(()) + } + #[tokio::test] async fn explain() -> Result<()> { // build query using Table API diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 23766cd07bdb..5ba803fce7ef 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -86,6 +86,7 @@ These methods execute the logical plan represented by the DataFrame and either c | -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | | collect | Executes this DataFrame and collects all results into a vector of RecordBatch. | | collect_partitioned | Executes this DataFrame and collects all results into a vector of vector of RecordBatch maintaining the input partitioning. | +| count | Executes this DataFrame to get the total number of rows. | | execute_stream | Executes this DataFrame and returns a stream over a single partition. | | execute_stream_partitioned | Executes this DataFrame and returns one stream per partition. | | show | Execute this DataFrame and print the results to stdout. |