From 2fb3c74a6653d1abb4310ccb37792d54b8bfc890 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Thu, 26 Jan 2023 11:44:14 +1100 Subject: [PATCH 1/4] DataFrame len method --- datafusion-examples/README.md | 2 +- datafusion/core/src/dataframe.rs | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 9ad5a49bb428..6d39ceaf059d 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -29,7 +29,7 @@ Run `git submodule update --init` to init test files. - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file -- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queris against a custom datasource (TableProvider) +- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index e9773dbdf372..f8d9df68bbaf 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -20,7 +20,10 @@ use std::any::Any; use std::sync::Arc; +use arrow::array::Int64Array; use async_trait::async_trait; +use datafusion_common::DataFusionError; +use datafusion_expr::count; use parquet::file::properties::WriterProperties; use datafusion_common::{Column, DFSchema, ScalarValue}; @@ -361,6 +364,35 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, 
plan)) } + /// Run a count aggregate on the DataFrame and execute the DataFrame to collect this + /// count and return it as a usize. + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let length = df.len().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn len(self) -> Result<usize> { + let rows = self + .aggregate(vec![], vec![count(Expr::Literal(ScalarValue::Null))])? + .collect() + .await?; + let len = *rows + .first() + .and_then(|r| r.columns().first()) + .and_then(|c| c.as_any().downcast_ref::<Int64Array>()) + .and_then(|a| a.values().first()) + .ok_or(DataFusionError::Internal( + "Unexpected output when collecting for len".to_string(), + ))? as usize; + Ok(len) + } + /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, collecting all resulting batches into memory /// Executes this DataFrame and collects all results into a vector of RecordBatch. 
@@ -1001,6 +1033,13 @@ mod tests { Ok(()) } + #[tokio::test] + async fn len() -> Result<()> { + let len = test_table().await?.len().await?; + assert_eq!(100, len); + Ok(()) + } + #[tokio::test] async fn explain() -> Result<()> { // build query using Table API From 21cb7936896c52cacec64d6b616381bfecb2c566 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Thu, 26 Jan 2023 21:31:30 +1100 Subject: [PATCH 2/4] Rename len to count --- datafusion/core/src/dataframe.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index f8d9df68bbaf..f8fed7016b49 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -365,7 +365,8 @@ impl DataFrame { } /// Run a count aggregate on the DataFrame and execute the DataFrame to collect this - /// count and return it as a usize. + /// count and return it as a usize, to find the total number of rows after executing + /// the DataFrame. /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -373,11 +374,11 @@ impl DataFrame { /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let length = df.len().await?; + /// let count = df.count().await?; /// # Ok(()) /// # } /// ``` - pub async fn len(self) -> Result<usize> { + pub async fn count(self) -> Result<usize> { let rows = self .aggregate(vec![], vec![count(Expr::Literal(ScalarValue::Null))])? .collect() @@ -388,7 +389,7 @@ impl DataFrame { .and_then(|c| c.as_any().downcast_ref::<Int64Array>()) .and_then(|a| a.values().first()) .ok_or(DataFusionError::Internal( - "Unexpected output when collecting for len".to_string(), + "Unexpected output when collecting for count()".to_string(), ))? 
as usize; Ok(len) } @@ -1034,9 +1035,9 @@ mod tests { } #[tokio::test] - async fn len() -> Result<()> { - let len = test_table().await?.len().await?; - assert_eq!(100, len); + async fn count() -> Result<()> { + let count = test_table().await?.count().await?; + assert_eq!(100, count); Ok(()) } From 48b2e9d2052dfc93646413b56a50f31b3e95d079 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Thu, 26 Jan 2023 21:54:32 +1100 Subject: [PATCH 3/4] Fixes --- datafusion/core/src/dataframe.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index f8fed7016b49..f8ee54ca0837 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::array::Int64Array; use async_trait::async_trait; use datafusion_common::DataFusionError; -use datafusion_expr::count; use parquet::file::properties::WriterProperties; use datafusion_common::{Column, DFSchema, ScalarValue}; @@ -380,7 +379,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result<usize> { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(ScalarValue::Null))])? + .aggregate( + vec![], + vec![datafusion_expr::count(Expr::Literal(ScalarValue::Null))], + )? 
.collect() .await?; let len = *rows @@ -1035,7 +1037,7 @@ mod tests { } #[tokio::test] - async fn count() -> Result<()> { + async fn df_count() -> Result<()> { let count = test_table().await?.count().await?; assert_eq!(100, count); Ok(()) From daa63b533018438d728fc5bbb2e70bb45915ca9d Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Fri, 27 Jan 2023 06:35:15 +1100 Subject: [PATCH 4/4] Update user guide --- docs/source/user-guide/dataframe.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 23766cd07bdb..5ba803fce7ef 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -86,6 +86,7 @@ These methods execute the logical plan represented by the DataFrame and either c | -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | | collect | Executes this DataFrame and collects all results into a vector of RecordBatch. | | collect_partitioned | Executes this DataFrame and collects all results into a vector of vector of RecordBatch maintaining the input partitioning. | +| count | Executes this DataFrame to get the total number of rows. | | execute_stream | Executes this DataFrame and returns a stream over a single partition. | | execute_stream_partitioned | Executes this DataFrame and returns one stream per partition. | | show | Execute this DataFrame and print the results to stdout. |