From 96ca0d6fb1033b8e22cf7a1e1864058ae1bfc7a8 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 19:58:47 +0200 Subject: [PATCH 1/6] Rebase changes --- .../src/physical_plan/hash_aggregate.rs | 157 ++++++++---------- datafusion/src/scalar.rs | 60 ++++++- 2 files changed, 119 insertions(+), 98 deletions(-) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index c9d268619cad..b039cb651d09 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -32,6 +32,7 @@ use crate::physical_plan::{ Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, SQLMetric, }; +use crate::scalar::ScalarValue; use arrow::{ array::{Array, UInt32Builder}, @@ -623,10 +624,12 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec) -> Result<( DataType::UInt64 => { dictionary_create_key_for_col::(col, row, vec)?; } - _ => return Err(DataFusionError::Internal(format!( + _ => { + return Err(DataFusionError::Internal(format!( "Unsupported GROUP BY type (dictionary index type not supported creating key) {}", col.data_type(), - ))), + ))) + } }, _ => { // This is internal because we should have caught this before. @@ -956,20 +959,6 @@ impl RecordBatchStream for HashAggregateStream { } } -/// Given Vec>, concatenates the inners `Vec` into `ArrayRef`, returning `Vec` -/// This assumes that `arrays` is not empty. -fn concatenate(arrays: Vec>) -> ArrowResult> { - (0..arrays[0].len()) - .map(|column| { - let array_list = arrays - .iter() - .map(|a| a[column].as_ref()) - .collect::>(); - compute::concat(&array_list) - }) - .collect::>>() -} - /// Create a RecordBatch with all group keys and accumulator' states or values. fn create_batch_from_map( mode: &AggregateMode, @@ -977,84 +966,74 @@ fn create_batch_from_map( num_group_expr: usize, output_schema: &Schema, ) -> ArrowResult { - // 1. for each key - // 2. create single-row ArrayRef with all group expressions - // 3. create single-row ArrayRef with all aggregate states or values - // 4. collect all in a vector per key of vec, vec[i][j] - // 5. concatenate the arrays over the second index [j] into a single vec. - let arrays = accumulators - .iter() - .map(|(_, (group_by_values, accumulator_set, _))| { - // 2. - let mut groups = (0..num_group_expr) - .map(|i| match &group_by_values[i] { - GroupByScalar::Float32(n) => { - Arc::new(Float32Array::from(vec![(*n).into()] as Vec)) - as ArrayRef - } - GroupByScalar::Float64(n) => { - Arc::new(Float64Array::from(vec![(*n).into()] as Vec)) - as ArrayRef - } - GroupByScalar::Int8(n) => { - Arc::new(Int8Array::from(vec![*n])) as ArrayRef - } - GroupByScalar::Int16(n) => Arc::new(Int16Array::from(vec![*n])), - GroupByScalar::Int32(n) => Arc::new(Int32Array::from(vec![*n])), - GroupByScalar::Int64(n) => Arc::new(Int64Array::from(vec![*n])), - GroupByScalar::UInt8(n) => Arc::new(UInt8Array::from(vec![*n])), - GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])), - GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])), - GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])), - GroupByScalar::Utf8(str) => { - Arc::new(StringArray::from(vec![&***str])) - } - GroupByScalar::LargeUtf8(str) => { - Arc::new(LargeStringArray::from(vec![&***str])) - } - GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])), - GroupByScalar::TimeMillisecond(n) => { - Arc::new(TimestampMillisecondArray::from(vec![*n])) - } - GroupByScalar::TimeMicrosecond(n) => { - Arc::new(TimestampMicrosecondArray::from(vec![*n])) - } - GroupByScalar::TimeNanosecond(n) => { - Arc::new(TimestampNanosecondArray::from_vec(vec![*n], None)) - } - GroupByScalar::Date32(n) => Arc::new(Date32Array::from(vec![*n])), - }) - .collect::>(); + if accumulators.is_empty() { + return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); + } + let (_, (_, accs, _)) = accumulators.iter().nth(0).unwrap(); + let mut acc_data_types: Vec = vec![]; - // 3. - groups.extend( - finalize_aggregation(accumulator_set, mode) - .map_err(DataFusionError::into_arrow_external_error)?, - ); + // Calculate number/shape of state arrays + match mode { + AggregateMode::Partial => { + for acc in accs.iter() { + let state = acc + .state() + .map_err(DataFusionError::into_arrow_external_error)?; + acc_data_types.push(state.len()); + } + } + AggregateMode::Final | AggregateMode::FinalPartitioned => { + for _ in accs { + acc_data_types.push(1); + } + } + } - Ok(groups) + let mut columns = (0..num_group_expr) + .map(|i| { + ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]), + )) }) - // 4. - .collect::>>>()?; + .collect::>>() + .map_err(|x| x.into_arrow_external_error())?; + + // add state / evaluated arrays + for (x, &state_len) in acc_data_types.iter().enumerate() { + for y in 0..state_len { + match mode { + AggregateMode::Partial => { + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| { + let x = accumulator[x].state().unwrap(); + x[y].clone() + }, + )) + .map_err(DataFusionError::into_arrow_external_error)?; + + columns.push(res); + } + AggregateMode::Final | AggregateMode::FinalPartitioned => { + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| accumulator[x].evaluate().unwrap(), + )) + .map_err(DataFusionError::into_arrow_external_error)?; + columns.push(res); + } + } + } + } - let batch = if !arrays.is_empty() { - // 5. - let columns = concatenate(arrays)?; + // cast output if needed (e.g. for types like Dictionary where + // the intermediate GroupByScalar type was not the same as the + // output + let columns = columns + .iter() + .zip(output_schema.fields().iter()) + .map(|(col, desired_field)| cast(col, desired_field.data_type())) + .collect::>>()?; - // cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - let columns = columns - .iter() - .zip(output_schema.fields().iter()) - .map(|(col, desired_field)| cast(col, desired_field.data_type())) - .collect::>>()?; - - RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)? - } else { - RecordBatch::new_empty(Arc::new(output_schema.to_owned())) - }; - Ok(batch) + RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns) } fn create_accumulators( diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index f3fa5b2c5de5..1374b83766c9 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -311,7 +311,7 @@ impl ScalarValue { /// ]; /// /// // Build an Array from the list of ScalarValues - /// let array = ScalarValue::iter_to_array(scalars.iter()) + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) /// .unwrap(); /// /// let expected: ArrayRef = std::sync::Arc::new( @@ -324,8 +324,8 @@ impl ScalarValue { /// /// assert_eq!(&array, &expected); /// ``` - pub fn iter_to_array<'a>( - scalars: impl IntoIterator, + pub fn iter_to_array( + scalars: impl IntoIterator, ) -> Result { let mut scalars = scalars.into_iter().peekable(); @@ -347,7 +347,7 @@ impl ScalarValue { let values = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(*v) + Ok(v) } else { Err(DataFusionError::Internal(format!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -394,6 +394,24 @@ impl ScalarValue { }}; } + macro_rules! build_array_list { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ + Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + _ => panic!("xxx"), + }) + .collect::>>() + }), + _ => panic!("xxx"), + }), + )) + }}; + } + let array: ArrayRef = match &data_type { DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), @@ -430,6 +448,30 @@ impl ScalarValue { DataType::Interval(IntervalUnit::YearMonth) => { build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list!(Int8Type, Int8, i8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list!(Int16Type, Int16, i16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list!(Int32Type, Int32, i32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list!(Int64Type, Int64, i64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list!(UInt8Type, UInt8, u8) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list!(UInt16Type, UInt16, u16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list!(UInt32Type, UInt32, u32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list!(UInt64Type, UInt64, u64) + } _ => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", @@ -1102,7 +1144,7 @@ mod tests { let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1119,7 +1161,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1136,7 +1178,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); @@ -1210,7 +1252,7 @@ mod tests { fn scalar_iter_to_array_empty() { let scalars = vec![] as Vec; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!( result .to_string() @@ -1226,7 +1268,7 @@ mod tests { // If the scalar values are not all the correct type, error here let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), "{}", result); } From acf81c1ba7d84c4668a5be18a67a0132705572ba Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 21:02:41 +0200 Subject: [PATCH 2/6] Fmt --- datafusion/src/scalar.rs | 86 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 4 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 1374b83766c9..9fc03f2392f6 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -402,12 +402,14 @@ impl ScalarValue { x.iter() .map(|x| match x { ScalarValue::$SCALAR_TY(i) => *i, - _ => panic!("xxx"), + sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", data_type, sv), }) .collect::>>() }), - _ => panic!("xxx"), - }), + sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", data_type, sv), + }), )) }}; } @@ -472,12 +474,88 @@ impl ScalarValue { DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { build_array_list!(UInt64Type, UInt64, u64) } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + let mut builder = ListBuilder::new(StringBuilder::new(0)); + + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + for s in xs { + match s { + ScalarValue::Utf8(Some(val)) => { + builder.values().append_value(val)?; + } + ScalarValue::Utf8(None) => { + builder.values().append_null()?; + } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + builder.append(true)?; + } + ScalarValue::List(None, _) => { + builder.append(false)?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + Arc::new(builder.finish()) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + let mut builder = ListBuilder::new(LargeStringBuilder::new(0)); + + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + for s in xs { + match s { + ScalarValue::Utf8(Some(val)) => { + builder.values().append_value(val)?; + } + ScalarValue::Utf8(None) => { + builder.values().append_null()?; + } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + builder.append(true)?; + } + ScalarValue::List(None, _) => { + builder.append(false)?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + Arc::new(builder.finish()) + } _ => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, scalars.peek() - ))) + ))); } }; From 1100c26a98848b2106219fd6547214ffcce58a09 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 21:11:17 +0200 Subject: [PATCH 3/6] Use macro --- datafusion/src/scalar.rs | 136 ++++++++++++++++----------------------- 1 file changed, 54 insertions(+), 82 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 9fc03f2392f6..7fdcff01a626 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -394,7 +394,7 @@ impl ScalarValue { }}; } - macro_rules! build_array_list { + macro_rules! build_array_list_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( scalars.into_iter().map(|x| match x { @@ -414,6 +414,48 @@ impl ScalarValue { }}; } + macro_rules! build_array_list_string { + ($BUILDER:ident, $SCALAR_TY:ident) => {{ + let mut builder = ListBuilder::new($BUILDER::new(0)); + + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + for s in xs { + match s { + ScalarValue::$SCALAR_TY(Some(val)) => { + builder.values().append_value(val)?; + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null()?; + } + sv => return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))), + } + } + builder.append(true)?; + } + ScalarValue::List(None, _) => { + builder.append(false)?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + Arc::new(builder.finish()) + + }} + } + let array: ArrayRef = match &data_type { DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), @@ -451,111 +493,41 @@ impl ScalarValue { build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) } DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list!(Int8Type, Int8, i8) + build_array_list_primitive!(Int8Type, Int8, i8) } DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list!(Int16Type, Int16, i16) + build_array_list_primitive!(Int16Type, Int16, i16) } DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list!(Int32Type, Int32, i32) + build_array_list_primitive!(Int32Type, Int32, i32) } DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list!(Int64Type, Int64, i64) + build_array_list_primitive!(Int64Type, Int64, i64) } DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list!(UInt8Type, UInt8, u8) + build_array_list_primitive!(UInt8Type, UInt8, u8) } DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list!(UInt16Type, UInt16, u16) + build_array_list_primitive!(UInt16Type, UInt16, u16) } DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list!(UInt32Type, UInt32, u32) + build_array_list_primitive!(UInt32Type, UInt32, u32) } DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list!(UInt64Type, UInt64, u64) + build_array_list_primitive!(UInt64Type, UInt64, u64) } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - let mut builder = ListBuilder::new(StringBuilder::new(0)); - - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(Some(xs), _) => { - for s in xs { - match s { - ScalarValue::Utf8(Some(val)) => { - builder.values().append_value(val)?; - } - ScalarValue::Utf8(None) => { - builder.values().append_null()?; - } - sv => return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ))), - } - } - builder.append(true)?; - } - ScalarValue::List(None, _) => { - builder.append(false)?; - } - sv => { - return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ))) - } - } - } - - Arc::new(builder.finish()) + build_array_list_string!(StringBuilder, Utf8) } DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - let mut builder = ListBuilder::new(LargeStringBuilder::new(0)); - - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(Some(xs), _) => { - for s in xs { - match s { - ScalarValue::Utf8(Some(val)) => { - builder.values().append_value(val)?; - } - ScalarValue::Utf8(None) => { - builder.values().append_null()?; - } - sv => return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ))), - } - } - builder.append(true)?; - } - ScalarValue::List(None, _) => { - builder.append(false)?; - } - sv => { - return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ))) - } - } - } - - Arc::new(builder.finish()) + build_array_list_string!(LargeStringBuilder, LargeUtf8) } _ => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, scalars.peek() - ))); + ))) } }; From 214eb7e4eae5ed933fe228a46b008fbd836b8e24 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 21:50:09 +0200 Subject: [PATCH 4/6] Support floats too --- datafusion/src/scalar.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 7fdcff01a626..d25793998cdb 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -21,10 +21,10 @@ use crate::error::{DataFusionError, Result}; use arrow::{ array::*, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, - Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }, }; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -516,6 +516,12 @@ impl ScalarValue { DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { build_array_list_primitive!(UInt64Type, UInt64, u64) } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!(Float32Type, Float32, f32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!(Float64Type, Float64, f64) + } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { build_array_list_string!(StringBuilder, Utf8) } From 9cc0ea031c175d61ce32f8b8fa048e4b6f6d1a2e Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 22:04:54 +0200 Subject: [PATCH 5/6] Avoid temporary vec for primitive / string --- datafusion/src/scalar.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index d25793998cdb..ac7deeed22c7 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -344,7 +344,7 @@ impl ScalarValue { macro_rules! build_array_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let values = scalars + let array = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) @@ -356,9 +356,8 @@ impl ScalarValue { ))) } }) - .collect::>>()?; + .collect::>()?; - let array: $ARRAY_TY = values.iter().collect(); Arc::new(array) } }}; @@ -369,7 +368,7 @@ impl ScalarValue { macro_rules! build_array_string { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let values = scalars + let array = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) @@ -381,14 +380,7 @@ impl ScalarValue { ))) } }) - .collect::>>()?; - - // it is annoying that one can not create - // StringArray et al directly from iter of &String, - // requiring this map to &str - let values = values.iter().map(|s| s.as_ref()); - - let array: $ARRAY_TY = values.collect(); + .collect::>()?; Arc::new(array) } }}; From f8bfe3bd0da4ca45d3f737120c08387a98ed0ba3 Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Tue, 25 May 2021 22:27:45 +0200 Subject: [PATCH 6/6] Clippy --- datafusion/src/physical_plan/hash_aggregate.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index b039cb651d09..7d9c07ce40d2 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::sync::Arc; use std::task::{Context, Poll}; +use std::vec; use ahash::RandomState; use futures::{ @@ -969,7 +970,7 @@ fn create_batch_from_map( if accumulators.is_empty() { return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); } - let (_, (_, accs, _)) = accumulators.iter().nth(0).unwrap(); + let (_, (_, accs, _)) = accumulators.iter().next().unwrap(); let mut acc_data_types: Vec = vec![]; // Calculate number/shape of state arrays @@ -983,9 +984,7 @@ fn create_batch_from_map( } } AggregateMode::Final | AggregateMode::FinalPartitioned => { - for _ in accs { - acc_data_types.push(1); - } + acc_data_types = vec![1; accs.len()]; } }