Skip to content

Commit

Permalink
[Minor] Reduce code duplication creating ScalarValue::List (#3197)
Browse files (browse the repository at this point in the history)
* Reduce code duplication creating ScalarValue::List

* clean more
  • Loading branch information
alamb authored Aug 20, 2022
1 parent 3df9f80 commit e0a9fa3
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 235 deletions.
82 changes: 30 additions & 52 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,11 @@ impl ScalarValue {
ScalarValue::IntervalMonthDayNano(Some(val))
}

/// Builds a [`ScalarValue::List`] from `scalars` and an element type.
///
/// The list's element field is always named `"item"`, carries `child_type`
/// as its data type, and is created with the nullable flag set to `true`.
/// `scalars` is stored as-is, so a `None` passed here ends up as the list's
/// value unchanged.
pub fn new_list(scalars: Option<Vec<Self>>, child_type: DataType) -> Self {
    // Describe the list elements once, then box the field for the variant.
    let item_field = Field::new("item", child_type, true);
    Self::List(scalars, Box::new(item_field))
}

/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
Expand Down Expand Up @@ -1506,10 +1511,7 @@ impl ScalarValue {
Some(scalar_vec)
}
};
ScalarValue::List(
value,
Box::new(Field::new("item", nested_type.data_type().clone(), true)),
)
ScalarValue::new_list(value, nested_type.data_type().clone())
}
DataType::Date32 => {
typed_cast!(array, index, Date32Array, Date32)
Expand Down Expand Up @@ -1610,10 +1612,7 @@ impl ScalarValue {
Some(scalar_vec)
}
};
ScalarValue::List(
value,
Box::new(Field::new("item", nested_type.data_type().clone(), true)),
)
ScalarValue::new_list(value, nested_type.data_type().clone())
}
other => {
return Err(DataFusionError::NotImplemented(format!(
Expand Down Expand Up @@ -1951,10 +1950,9 @@ impl TryFrom<&DataType> for ScalarValue {
index_type.clone(),
Box::new(value_type.as_ref().try_into()?),
),
DataType::List(ref nested_type) => ScalarValue::List(
None,
Box::new(Field::new("item", nested_type.data_type().clone(), true)),
),
DataType::List(ref nested_type) => {
ScalarValue::new_list(None, nested_type.data_type().clone())
}
DataType::Struct(fields) => {
ScalarValue::Struct(None, Box::new(fields.clone()))
}
Expand Down Expand Up @@ -3124,20 +3122,12 @@ mod tests {
assert_eq!(array, &expected);

// Define list-of-structs scalars
let nl0 = ScalarValue::List(
Some(vec![s0.clone(), s1.clone()]),
Box::new(Field::new("item", s0.get_datatype(), true)),
);
let nl0 =
ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.get_datatype());

let nl1 = ScalarValue::List(
Some(vec![s2]),
Box::new(Field::new("item", s0.get_datatype(), true)),
);
let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.get_datatype());

let nl2 = ScalarValue::List(
Some(vec![s1]),
Box::new(Field::new("item", s0.get_datatype(), true)),
);
let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype());
// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
let array = array.as_any().downcast_ref::<ListArray>().unwrap();
Expand Down Expand Up @@ -3263,56 +3253,44 @@ mod tests {
#[test]
fn test_nested_lists() {
// Define inner list scalars
let l1 = ScalarValue::List(
let l1 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let l2 = ScalarValue::List(
let l2 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(6i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let l3 = ScalarValue::List(
Some(vec![ScalarValue::List(
let l3 = ScalarValue::new_list(
Some(vec![ScalarValue::new_list(
Some(vec![ScalarValue::from(9i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
)]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
Expand Down
60 changes: 22 additions & 38 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ impl Accumulator for ArrayAggAccumulator {
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::List(
Ok(ScalarValue::new_list(
Some(self.values.clone()),
Box::new(Field::new("item", self.datatype.clone(), true)),
self.datatype.clone(),
))
}
}
Expand All @@ -171,81 +171,65 @@ mod tests {
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

let list = ScalarValue::List(
let list = ScalarValue::new_list(
Some(vec![
ScalarValue::Int32(Some(1)),
ScalarValue::Int32(Some(2)),
ScalarValue::Int32(Some(3)),
ScalarValue::Int32(Some(4)),
ScalarValue::Int32(Some(5)),
]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
);

generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}

#[test]
fn array_agg_nested() -> Result<()> {
let l1 = ScalarValue::List(
let l1 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let l2 = ScalarValue::List(
let l2 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(6i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let l3 = ScalarValue::List(
Some(vec![ScalarValue::List(
let l3 = ScalarValue::new_list(
Some(vec![ScalarValue::new_list(
Some(vec![ScalarValue::from(9i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
)]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let list = ScalarValue::List(
let list = ScalarValue::new_list(
Some(vec![l1.clone(), l2.clone(), l3.clone()]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
Expand Down
64 changes: 24 additions & 40 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ impl DistinctArrayAggAccumulator {

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(vec![AggregateState::Scalar(ScalarValue::List(
Ok(vec![AggregateState::Scalar(ScalarValue::new_list(
Some(self.values.clone().into_iter().collect()),
Box::new(Field::new("item", self.datatype.clone(), true)),
self.datatype.clone(),
))])
}

Expand Down Expand Up @@ -151,9 +151,9 @@ impl Accumulator for DistinctArrayAggAccumulator {
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::List(
Ok(ScalarValue::new_list(
Some(self.values.clone().into_iter().collect()),
Box::new(Field::new("item", self.datatype.clone(), true)),
self.datatype.clone(),
))
}
}
Expand Down Expand Up @@ -206,15 +206,15 @@ mod tests {
fn distinct_array_agg_i32() -> Result<()> {
let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2]));

let out = ScalarValue::List(
let out = ScalarValue::new_list(
Some(vec![
ScalarValue::Int32(Some(1)),
ScalarValue::Int32(Some(2)),
ScalarValue::Int32(Some(7)),
ScalarValue::Int32(Some(4)),
ScalarValue::Int32(Some(5)),
]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
);

check_distinct_array_agg(col, out, DataType::Int32)
Expand All @@ -223,67 +223,51 @@ mod tests {
#[test]
fn distinct_array_agg_nested() -> Result<()> {
// [[1, 2, 3], [4, 5]]
let l1 = ScalarValue::List(
let l1 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

// [[6], [7, 8]]
let l2 = ScalarValue::List(
let l2 = ScalarValue::new_list(
Some(vec![
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(6i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
ScalarValue::List(
ScalarValue::new_list(
Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
),
]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

// [[9]]
let l3 = ScalarValue::List(
Some(vec![ScalarValue::List(
let l3 = ScalarValue::new_list(
Some(vec![ScalarValue::new_list(
Some(vec![ScalarValue::from(9i32)]),
Box::new(Field::new("item", DataType::Int32, true)),
DataType::Int32,
)]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

let list = ScalarValue::List(
let list = ScalarValue::new_list(
Some(vec![l1.clone(), l2.clone(), l3.clone()]),
Box::new(Field::new(
"item",
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
true,
)),
DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
);

// Duplicate l1 in the input array and check that it is deduped in the output.
Expand Down
Loading

0 comments on commit e0a9fa3

Please sign in to comment.