Skip to content

Commit

Permalink
[DataFrame] - Add cache function for DataFrame (#3512)
Browse files Browse the repository at this point in the history
* feat: add cache function for dataframe

* fix: function doc typo

* fix: test issue
  • Loading branch information
francis-du authored Sep 18, 2022
1 parent 12f047e commit 86a8236
Showing 1 changed file with 62 additions and 3 deletions.
65 changes: 62 additions & 3 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::arrow::datatypes::Schema;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::pretty;
use crate::datasource::TableProvider;
use crate::datasource::{MemTable, TableProvider};
use crate::error::Result;
use crate::execution::{
context::{SessionState, TaskContext},
Expand All @@ -35,6 +35,7 @@ use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parqu
use crate::physical_plan::SendableRecordBatchStream;
use crate::physical_plan::{collect, collect_partitioned};
use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
use crate::prelude::SessionContext;
use crate::scalar::ScalarValue;
use async_trait::async_trait;
use parking_lot::RwLock;
Expand Down Expand Up @@ -733,6 +734,29 @@ impl DataFrame {
)))
}
}

/// Executes this DataFrame and returns a new DataFrame backed by the
/// materialized results held in a [`MemTable`].
///
/// The batches are gathered with `collect_partitioned` and stored together
/// with this DataFrame's schema. The returned DataFrame is obtained from a
/// [`SessionContext`] built from a clone of the current session state.
///
/// # Errors
///
/// Propagates any error raised while collecting the partitions, while
/// constructing the `MemTable`, or while reading the table back.
///
/// # Examples
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.cache().await?;
/// # Ok(())
/// # }
/// ```
pub async fn cache(&self) -> Result<Arc<DataFrame>> {
    // Schema first, then execution: same evaluation order as passing both
    // directly as arguments.
    let schema = SchemaRef::from(self.schema().clone());
    let partitions = self.collect_partitioned().await?;
    let cached_table = Arc::new(MemTable::try_new(schema, partitions)?);

    // The session state is cloned only after execution has finished.
    let ctx = SessionContext::with_state(self.session_state.read().clone());
    ctx.read_table(cached_table)
}
}

// TODO: This will introduce a ref cycle (#2659)
Expand Down Expand Up @@ -1082,6 +1106,7 @@ mod tests {
);
Ok(())
}

/// Compare the formatted string representation of two plans for equality
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
Expand Down Expand Up @@ -1265,7 +1290,7 @@ mod tests {
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1\
\n TableScan: t2",
format!("{:?}", df_renamed.to_unoptimized_plan())
format!("{:?}", df_renamed.to_unoptimized_plan())
);

assert_eq!("\
Expand All @@ -1275,7 +1300,7 @@ mod tests {
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
\n TableScan: t2 projection=[c1, c2, c3]",
format!("{:?}", df_renamed.to_logical_plan()?)
format!("{:?}", df_renamed.to_logical_plan()?)
);

let df_results = df_renamed.collect().await?;
Expand Down Expand Up @@ -1303,6 +1328,7 @@ mod tests {
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;

let df_results = df.collect().await?;
df.show().await?;
assert_batches_sorted_eq!(
vec![
"+----+----+-----+",
Expand Down Expand Up @@ -1353,4 +1379,37 @@ mod tests {

Ok(())
}

/// Verifies that `DataFrame::cache` yields a DataFrame that plans as a plain
/// table scan and returns the same batches as the DataFrame it was built from.
#[tokio::test]
async fn cache_test() -> Result<()> {
    // One-row projection with a derived `sum` column.
    let source_df = test_table()
        .await?
        .select_columns(&["c2", "c3"])?
        .limit(0, Some(1))?
        .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;

    let cached_df = source_df.cache().await?;

    // The cached DataFrame should be a scan over the cached table, not a
    // replay of the original projection/limit plan.
    assert_eq!(
        "TableScan: ?table? projection=[c2, c3, sum]",
        format!("{:?}", cached_df.to_logical_plan()?)
    );

    let source_batches = source_df.collect().await?;
    let cached_batches = cached_df.collect().await?;

    // Cells are padded to the column width by the pretty printer, so the data
    // row must line up with the border and header rows.
    assert_batches_sorted_eq!(
        vec![
            "+----+----+-----+",
            "| c2 | c3 | sum |",
            "+----+----+-----+",
            "| 2  | 1  | 3   |",
            "+----+----+-----+",
        ],
        &cached_batches
    );

    // Cached results must be identical to executing the original plan.
    assert_eq!(&source_batches, &cached_batches);

    Ok(())
}
}

0 comments on commit 86a8236

Please sign in to comment.