Fix regression by reverting Materialize dictionaries in group keys (#8740)

* revert eb8aff7 / Materialize dictionaries in group keys

* Update tests

* Update tests
alamb authored Jan 8, 2024
1 parent dd4263f commit ff27d90
Showing 6 changed files with 124 additions and 48 deletions.
15 changes: 12 additions & 3 deletions datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     assert_eq!(min_limit, resulting_limit);
 
     let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
-    let month = match s {
-        ScalarValue::Utf8(Some(month)) => month,
-        s => panic!("Expected month as Utf8 found {s:?}"),
+    let month = match extract_as_utf(&s) {
+        Some(month) => month,
+        s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
     };
 
     let sql_on_partition_boundary = format!(
@@ -191,6 +191,15 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     Ok(())
 }
 
+fn extract_as_utf(v: &ScalarValue) -> Option<String> {
+    if let ScalarValue::Dictionary(_, v) = v {
+        if let ScalarValue::Utf8(v) = v.as_ref() {
+            return v.clone();
+        }
+    }
+    None
+}
+
 #[tokio::test]
 async fn csv_filter_with_file_col() -> Result<()> {
     let ctx = SessionContext::new();
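With the revert in place, distinct partition values come back from the group by as Dictionary(_, Utf8) scalars rather than plain Utf8, which is why the test now unwraps them through extract_as_utf. A minimal standalone sketch of that behavior, assuming only datafusion_common::ScalarValue and the helper above (the sample month value is made up):

use arrow_schema::DataType;
use datafusion_common::ScalarValue;

fn extract_as_utf(v: &ScalarValue) -> Option<String> {
    if let ScalarValue::Dictionary(_, v) = v {
        if let ScalarValue::Utf8(v) = v.as_ref() {
            return v.clone();
        }
    }
    None
}

fn main() {
    // A dictionary-encoded partition value, as group keys now surface it.
    let dict = ScalarValue::Dictionary(
        Box::new(DataType::Int32),
        Box::new(ScalarValue::Utf8(Some("2021-09".to_string()))),
    );
    assert_eq!(extract_as_utf(&dict), Some("2021-09".to_string()));

    // A plain Utf8 scalar is not a dictionary, so the helper yields None.
    let plain = ScalarValue::Utf8(Some("2021-09".to_string()));
    assert_eq!(extract_as_utf(&plain), None);
}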
27 changes: 23 additions & 4 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,18 +17,22 @@
 
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
+use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::ArrayRef;
-use arrow_schema::SchemaRef;
+use arrow_array::{Array, ArrayRef};
+use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_physical_expr::EmitTo;
 use hashbrown::raw::RawTable;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
+    /// The output schema
+    schema: SchemaRef,
+
     /// Converter for the group values
     row_converter: RowConverter,
 
@@ -75,6 +79,7 @@ impl GroupValuesRows {
         let map = RawTable::with_capacity(0);
 
         Ok(Self {
+            schema,
             row_converter,
             map,
             map_size: 0,
@@ -165,7 +170,7 @@ impl GroupValues for GroupValuesRows {
             .take()
             .expect("Can not emit from empty rows");
 
-        let output = match emit_to {
+        let mut output = match emit_to {
             EmitTo::All => {
                 let output = self.row_converter.convert_rows(&group_values)?;
                 group_values.clear();
@@ -198,6 +203,20 @@
             }
         };
 
+        // TODO: Materialize dictionaries in group keys (#7647)
+        for (field, array) in self.schema.fields.iter().zip(&mut output) {
+            let expected = field.data_type();
+            if let DataType::Dictionary(_, v) = expected {
+                let actual = array.data_type();
+                if v.as_ref() != actual {
+                    return Err(DataFusionError::Internal(format!(
+                        "Converted group rows expected dictionary of {v} got {actual}"
+                    )));
+                }
+                *array = cast(array.as_ref(), expected)?;
+            }
+        }
+
         self.group_values = Some(group_values);
         Ok(output)
     }
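The loop added above is the heart of the fix: group keys still travel through the RowConverter as plain values, and each emitted column is cast back to the dictionary type the output schema declares. A standalone sketch of that cast, assuming the arrow crate's cast kernel (array contents are made up):

use std::sync::Arc;

use arrow::compute::cast;
use arrow::error::ArrowError;
use arrow_array::{ArrayRef, StringArray};
use arrow_schema::DataType;

fn main() -> Result<(), ArrowError> {
    // The RowConverter emits group keys as plain values (here Utf8)...
    let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));

    // ...while the aggregate's output schema still declares a dictionary type.
    let expected =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

    // Casting re-encodes the column so the emitted batch matches the schema.
    let dict = cast(values.as_ref(), &expected)?;
    assert_eq!(dict.data_type(), &expected);
    Ok(())
}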
35 changes: 3 additions & 32 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -36,7 +36,6 @@ use crate::{
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
-use arrow_schema::DataType;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_execution::TaskContext;
@@ -254,9 +253,6 @@ pub struct AggregateExec {
     limit: Option<usize>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
     pub input: Arc<dyn ExecutionPlan>,
-    /// Original aggregation schema, could be different from `schema` before dictionary group
-    /// keys get materialized
-    original_schema: SchemaRef,
     /// Schema after the aggregate is applied
     schema: SchemaRef,
     /// Input schema before any aggregation is applied. For partial aggregate this will be the
@@ -287,19 +283,15 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let original_schema = create_schema(
+        let schema = create_schema(
             &input.schema(),
             &group_by.expr,
             &aggr_expr,
             group_by.contains_null(),
             mode,
         )?;
 
-        let schema = Arc::new(materialize_dict_group_keys(
-            &original_schema,
-            group_by.expr.len(),
-        ));
-        let original_schema = Arc::new(original_schema);
+        let schema = Arc::new(schema);
         AggregateExec::try_new_with_schema(
             mode,
             group_by,
@@ -308,7 +300,6 @@ impl AggregateExec {
             input,
             input_schema,
             schema,
-            original_schema,
         )
     }
 
@@ -329,7 +320,6 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
         schema: SchemaRef,
-        original_schema: SchemaRef,
     ) -> Result<Self> {
         let input_eq_properties = input.equivalence_properties();
         // Get GROUP BY expressions:
@@ -382,7 +372,6 @@ impl AggregateExec {
             aggr_expr,
             filter_expr,
             input,
-            original_schema,
             schema,
             input_schema,
             projection_mapping,
@@ -693,7 +682,7 @@ impl ExecutionPlan for AggregateExec {
             children[0].clone(),
             self.input_schema.clone(),
             self.schema.clone(),
-            self.original_schema.clone(),
+            //self.original_schema.clone(),
         )?;
         me.limit = self.limit;
         Ok(Arc::new(me))
@@ -800,24 +789,6 @@ fn create_schema(
     Ok(Schema::new(fields))
 }
 
-/// returns schema with dictionary group keys materialized as their value types
-/// The actual convertion happens in `RowConverter` and we don't do unnecessary
-/// conversion back into dictionaries
-fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {
-    let fields = schema
-        .fields
-        .iter()
-        .enumerate()
-        .map(|(i, field)| match field.data_type() {
-            DataType::Dictionary(_, value_data_type) if i < group_count => {
-                Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
-            }
-            _ => Field::clone(field),
-        })
-        .collect::<Vec<_>>();
-    Schema::new(fields)
-}
-
 fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
     let group_fields = schema.fields()[0..group_count].to_vec();
     Arc::new(Schema::new(group_fields))
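With materialize_dict_group_keys deleted, group_schema now just slices the leading group_count fields off the aggregate schema, so dictionary-typed group keys flow through unchanged. A minimal sketch of that convention, reusing the function above (field names are made up):

use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema, SchemaRef};

// Group keys are the leading fields of the aggregate output schema.
fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
    let group_fields = schema.fields()[0..group_count].to_vec();
    Arc::new(Schema::new(group_fields))
}

fn main() {
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let agg_schema = Schema::new(vec![
        Field::new("tag_id", dict_type.clone(), true),
        Field::new("COUNT(value)", DataType::Int64, true),
    ]);

    // After the revert the dictionary type is preserved in the group schema.
    let group = group_schema(&agg_schema, 1);
    assert_eq!(group.field(0).data_type(), &dict_type);
}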
4 changes: 1 addition & 3 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,9 +324,7 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<_>>()?;
 
-        // we need to use original schema so RowConverter in group_values below
-        // will do the proper coversion of dictionaries into value types
-        let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
        let spill_expr = group_schema
             .fields
             .into_iter()
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -2469,11 +2469,11 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
 query T
 select arrow_typeof(x_dict) from value_dict group by x_dict;
 ----
-Int32
-Int32
-Int32
-Int32
-Int32
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
+Dictionary(Int64, Int32)
 
 statement ok
 drop table value
81 changes: 80 additions & 1 deletion datafusion/sqllogictest/test_files/dictionary.slt
@@ -169,7 +169,7 @@ order by date_bin('30 minutes', time) DESC
 
 # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738
 # This query should work correctly
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Utf8 got Dictionary\(Int32, Utf8\)
+query P?TT rowsort
 SELECT
     "data"."timestamp" as "time",
     "data"."tag_id",
@@ -201,3 +201,82 @@ ORDER BY
     "time",
     "data"."tag_id"
 ;
+----
+2023-12-20T00:00:00 1000 f1 32.0
+2023-12-20T00:00:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:10:00 1000 f2 foo
+2023-12-20T00:20:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f2 foo
+2023-12-20T00:30:00 1000 f1 32.0
+2023-12-20T00:30:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:40:00 1000 f2 foo
+2023-12-20T00:50:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f2 foo
+2023-12-20T01:00:00 1000 f1 32.0
+2023-12-20T01:00:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:10:00 1000 f2 foo
+2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
+2023-12-20T01:30:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
+
+
+# deterministic sort (so we can avoid rowsort)
+query P?TT
+SELECT
+    "data"."timestamp" as "time",
+    "data"."tag_id",
+    "data"."field",
+    "data"."value"
+FROM (
+  (
+    SELECT "m2"."time" as "timestamp", "m2"."tag_id", 'active_power' as "field", "m2"."f5" as "value"
+    FROM "m2"
+    WHERE "m2"."time" >= '2023-12-05T14:46:35+01:00' AND "m2"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m2"."f5" IS NOT NULL
+    AND "m2"."type" IN ('active')
+    AND "m2"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f1' as "field", "m1"."f1" as "value"
+    FROM "m1"
+    WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m1"."f1" IS NOT NULL
+    AND "m1"."tag_id" IN ('1000')
+  ) UNION (
+    SELECT "m1"."time" as "timestamp", "m1"."tag_id", 'f2' as "field", "m1"."f2" as "value"
+    FROM "m1"
+    WHERE "m1"."time" >= '2023-12-05T14:46:35+01:00' AND "m1"."time" < '2024-01-03T14:46:35+01:00'
+    AND "m1"."f2" IS NOT NULL
+    AND "m1"."tag_id" IN ('1000')
+  )
+) as "data"
+ORDER BY
+    "time",
+    "data"."tag_id",
+    "data"."field",
+    "data"."value"
+;
+----
+2023-12-20T00:00:00 1000 f1 32.0
+2023-12-20T00:00:00 1000 f2 foo
+2023-12-20T00:10:00 1000 f1 32.0
+2023-12-20T00:10:00 1000 f2 foo
+2023-12-20T00:20:00 1000 f1 32.0
+2023-12-20T00:20:00 1000 f2 foo
+2023-12-20T00:30:00 1000 f1 32.0
+2023-12-20T00:30:00 1000 f2 foo
+2023-12-20T00:40:00 1000 f1 32.0
+2023-12-20T00:40:00 1000 f2 foo
+2023-12-20T00:50:00 1000 f1 32.0
+2023-12-20T00:50:00 1000 f2 foo
+2023-12-20T01:00:00 1000 f1 32.0
+2023-12-20T01:00:00 1000 f2 foo
+2023-12-20T01:10:00 1000 f1 32.0
+2023-12-20T01:10:00 1000 f2 foo
+2023-12-20T01:20:00 1000 f1 32.0
+2023-12-20T01:20:00 1000 f2 foo
+2023-12-20T01:30:00 1000 f1 32.0
+2023-12-20T01:30:00 1000 f2 foo
