Skip to content

Commit

Permalink
Use hash repartitioning for group bys on dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Sep 11, 2022
1 parent 8df5496 commit 8d05c9f
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ use crate::{
physical_plan::displayable,
};
use arrow::compute::SortOptions;
use arrow::datatypes::DataType;
use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -651,17 +650,9 @@ impl DefaultPhysicalPlanner {
// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> = initial_aggr.output_group_expr();

// TODO: dictionary type not yet supported in Hash Repartition
let contains_dict = groups
.expr()
.iter()
.flat_map(|x| x.0.data_type(physical_input_schema.as_ref()))
.any(|x| matches!(x, DataType::Dictionary(_, _)));

let can_repartition = !groups.is_empty()
&& session_state.config.target_partitions > 1
&& session_state.config.repartition_aggregations
&& !contains_dict;
&& session_state.config.repartition_aggregations;

let (initial_aggr, next_partition_mode): (
Arc<dyn ExecutionPlan>,
Expand Down Expand Up @@ -1664,6 +1655,7 @@ fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {
mod tests {
use super::*;
use crate::assert_contains;
use crate::datasource::MemTable;
use crate::execution::context::TaskContext;
use crate::execution::options::CsvReadOptions;
use crate::execution::runtime_env::RuntimeEnv;
Expand All @@ -1677,7 +1669,9 @@ mod tests {
use crate::{
logical_plan::LogicalPlanBuilder, physical_plan::SendableRecordBatchStream,
};
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::array::{ArrayRef, DictionaryArray, Int32Array};
use arrow::datatypes::{DataType, Field, Int32Type, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DFField, DFSchema, DFSchemaRef};
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::sum;
Expand Down Expand Up @@ -2087,6 +2081,35 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn hash_agg_group_by_partitioned_on_dicts() -> Result<()> {
let dict_array: DictionaryArray<Int32Type> =
vec!["A", "B", "A", "A", "C", "A"].into_iter().collect();
let val_array: Int32Array = vec![1, 2, 2, 4, 1, 1].into();

let batch = RecordBatch::try_from_iter(vec![
("d1", Arc::new(dict_array) as ArrayRef),
("d2", Arc::new(val_array) as ArrayRef),
])
.unwrap();

let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
let ctx = SessionContext::new();

let logical_plan =
LogicalPlanBuilder::from(ctx.read_table(Arc::new(table))?.to_logical_plan()?)
.aggregate(vec![col("d1")], vec![sum(col("d2"))])?
.build()?;

let execution_plan = plan(&logical_plan).await?;
let formatted = format!("{:?}", execution_plan);

// Make sure the plan contains a FinalPartitioned, which means it will not use the Final
// mode in Aggregate (which is slower)
assert!(formatted.contains("FinalPartitioned"));
Ok(())
}

#[tokio::test]
async fn hash_agg_grouping_set_by_partitioned() -> Result<()> {
let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
Expand Down

0 comments on commit 8d05c9f

Please sign in to comment.