Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupsAccumulator for ArrayAgg #229

Open
wants to merge 21 commits into
base: v37
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 119 additions & 3 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array};
use arrow_buffer::i256;

use crate::cast::{
as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array,
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array,
as_primitive_array, as_string_array, as_struct_array,
};
use crate::error::{DataFusionError, Result, _internal_err};

Expand Down Expand Up @@ -207,6 +208,32 @@ fn hash_dictionary<K: ArrowDictionaryKeyType>(
Ok(())
}

fn hash_struct_array(
array: &StructArray,
random_state: &RandomState,
hashes_buffer: &mut [u64],
) -> Result<()> {
let nulls = array.nulls();
let row_len = array.len();

let valid_row_indices: Vec<usize> = if let Some(nulls) = nulls {
nulls.valid_indices().collect()
} else {
(0..row_len).collect()
};

// Create hashes for each row that combines the hashes over all the column at that row.
let mut values_hashes = vec![0u64; row_len];
create_hashes(array.columns(), random_state, &mut values_hashes)?;

for i in valid_row_indices {
let hash = &mut hashes_buffer[i];
*hash = combine_hashes(*hash, values_hashes[i]);
}

Ok(())
}

fn hash_list_array<OffsetSize>(
array: &GenericListArray<OffsetSize>,
random_state: &RandomState,
Expand Down Expand Up @@ -327,12 +354,16 @@ pub fn create_hashes<'a>(
array => hash_dictionary(array, random_state, hashes_buffer, rehash)?,
_ => unreachable!()
}
DataType::Struct(_) => {
let array = as_struct_array(array)?;
hash_struct_array(array, random_state, hashes_buffer)?;
}
DataType::List(_) => {
let array = as_list_array(array);
let array = as_list_array(array)?;
hash_list_array(array, random_state, hashes_buffer)?;
}
DataType::LargeList(_) => {
let array = as_large_list_array(array);
let array = as_large_list_array(array)?;
hash_list_array(array, random_state, hashes_buffer)?;
}
_ => {
Expand Down Expand Up @@ -515,6 +546,91 @@ mod tests {
assert_eq!(hashes[2], hashes[3]);
}

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
fn create_hashes_for_struct_arrays() {
use arrow_buffer::Buffer;

let boolarr = Arc::new(BooleanArray::from(vec![
false, false, true, true, true, true,
]));
let i32arr = Arc::new(Int32Array::from(vec![10, 10, 20, 20, 30, 31]));

let struct_array = StructArray::from((
vec![
(
Arc::new(Field::new("bool", DataType::Boolean, false)),
boolarr.clone() as ArrayRef,
),
(
Arc::new(Field::new("i32", DataType::Int32, false)),
i32arr.clone() as ArrayRef,
),
(
Arc::new(Field::new("i32", DataType::Int32, false)),
i32arr.clone() as ArrayRef,
),
(
Arc::new(Field::new("bool", DataType::Boolean, false)),
boolarr.clone() as ArrayRef,
),
],
Buffer::from(&[0b001011]),
));

assert!(struct_array.is_valid(0));
assert!(struct_array.is_valid(1));
assert!(struct_array.is_null(2));
assert!(struct_array.is_valid(3));
assert!(struct_array.is_null(4));
assert!(struct_array.is_null(5));

let array = Arc::new(struct_array) as ArrayRef;

let random_state = RandomState::with_seeds(0, 0, 0, 0);
let mut hashes = vec![0; array.len()];
create_hashes(&[array], &random_state, &mut hashes).unwrap();
assert_eq!(hashes[0], hashes[1]);
// same value but the third row ( hashes[2] ) is null
assert_ne!(hashes[2], hashes[3]);
// different values but both are null
assert_eq!(hashes[4], hashes[5]);
}

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
fn create_hashes_for_struct_arrays_more_column_than_row() {
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("bool", DataType::Boolean, false)),
Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef,
),
(
Arc::new(Field::new("i32-1", DataType::Int32, false)),
Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef,
),
(
Arc::new(Field::new("i32-2", DataType::Int32, false)),
Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef,
),
(
Arc::new(Field::new("i32-3", DataType::Int32, false)),
Arc::new(Int32Array::from(vec![10, 10])) as ArrayRef,
),
]);

assert!(struct_array.is_valid(0));
assert!(struct_array.is_valid(1));

let array = Arc::new(struct_array) as ArrayRef;
let random_state = RandomState::with_seeds(0, 0, 0, 0);
let mut hashes = vec![0; array.len()];
create_hashes(&[array], &random_state, &mut hashes).unwrap();
assert_eq!(hashes[0], hashes[1]);
}

#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
Expand Down
96 changes: 89 additions & 7 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ use arrow::{
use arrow_array::cast::as_list_array;
use arrow_array::types::ArrowTimestampType;
use arrow_array::{ArrowNativeTypeOp, Scalar};
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

/// A dynamically typed, nullable single value, (the single-valued counter-part
/// to arrow's [`Array`])
Expand Down Expand Up @@ -187,6 +189,11 @@ pub enum ScalarValue {
DurationNanosecond(Option<i64>),
/// struct of nested ScalarValue
Struct(Option<Vec<ScalarValue>>, Fields),
/// A nested datatype that can represent slots of differing types. Components:
/// `.0`: a tuple of union `type_id` and the single value held by this Scalar
/// `.1`: the list of fields, zero-to-one of which will by set in `.0`
/// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came
Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
/// Dictionary type: index type and value
Dictionary(Box<DataType>, Box<ScalarValue>),
}
Expand Down Expand Up @@ -287,6 +294,10 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
(Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
}
(Union(_, _, _), _) => false,
(Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
(Dictionary(_, _), _) => false,
(Null, Null) => true,
Expand Down Expand Up @@ -448,6 +459,14 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
(Union(v1, t1, m1), Union(v2, t2, m2)) => {
if t1.eq(t2) && m1.eq(m2) {
v1.partial_cmp(v2)
} else {
None
}
}
(Union(_, _, _), _) => None,
(Dictionary(k1, v1), Dictionary(k2, v2)) => {
// Don't compare if the key types don't match (it is effectively a different datatype)
if k1 == k2 {
Expand Down Expand Up @@ -546,6 +565,11 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
Union(v, t, m) => {
v.hash(state);
t.hash(state);
m.hash(state);
}
Dictionary(k, v) => {
k.hash(state);
v.hash(state);
Expand Down Expand Up @@ -968,6 +992,7 @@ impl ScalarValue {
DataType::Duration(TimeUnit::Nanosecond)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()),
ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.data_type()))
}
Expand Down Expand Up @@ -1167,6 +1192,7 @@ impl ScalarValue {
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Struct(v, _) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -1992,6 +2018,39 @@ impl ScalarValue {
new_null_array(&dt, size)
}
},
ScalarValue::Union(value, fields, _mode) => match value {
Some((v_id, value)) => {
let mut field_type_ids = Vec::<i8>::with_capacity(fields.len());
let mut child_arrays =
Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
for (f_id, field) in fields.iter() {
let ar = if f_id == *v_id {
value.to_array_of_size(size)?
} else {
let dt = field.data_type();
new_null_array(dt, size)
};
let field = (**field).clone();
child_arrays.push((field, ar));
field_type_ids.push(f_id);
}
let type_ids = repeat(*v_id).take(size).collect::<Vec<_>>();
let type_ids = Buffer::from_slice_ref(type_ids);
let value_offsets: Option<Buffer> = None;
let ar = UnionArray::try_new(
field_type_ids.as_slice(),
type_ids,
value_offsets,
child_arrays,
)
.map_err(|e| DataFusionError::ArrowError(e))?;
Arc::new(ar)
}
None => {
let dt = self.data_type();
new_null_array(&dt, size)
}
},
ScalarValue::Dictionary(key_type, v) => {
// values array is one element long (the value)
match key_type.as_ref() {
Expand Down Expand Up @@ -2492,6 +2551,9 @@ impl ScalarValue {
ScalarValue::Struct(_, _) => {
return _not_impl_err!("Struct is not supported yet")
}
ScalarValue::Union(_, _, _) => {
return _not_impl_err!("Union is not supported yet")
}
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
Expand Down Expand Up @@ -2565,17 +2627,26 @@ impl ScalarValue {
| ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(vals, fields) => {
vals.as_ref()
.map(|vals| {
vals.iter()
.map(|sv| sv.size() - std::mem::size_of_val(sv))
.sum::<usize>()
+ (std::mem::size_of::<ScalarValue>() * vals.capacity())
})
.map(|vals| {
vals.iter()
.map(|sv| sv.size() - std::mem::size_of_val(sv))
.sum::<usize>()
+ (std::mem::size_of::<ScalarValue>() * vals.capacity())
})
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Union(vals, fields, _mode) => {
vals.as_ref()
.map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|field| field.size() - std::mem::size_of_val(field)).sum::<usize>()
+ fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Dictionary(dt, sv) => {
// `dt` and `sv` are boxed, so they are NOT already included in `self`
Expand Down Expand Up @@ -2873,6 +2944,9 @@ impl TryFrom<&DataType> for ScalarValue {
1,
)),
DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
DataType::Union(fields, mode) => {
ScalarValue::Union(None, fields.clone(), *mode)
}
DataType::Null => ScalarValue::Null,
_ => {
return _not_impl_err!(
Expand Down Expand Up @@ -2971,6 +3045,10 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "{}:{}", id, val)?,
None => write!(f, "NULL")?,
},
ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
ScalarValue::Null => write!(f, "NULL")?,
};
Expand Down Expand Up @@ -3069,6 +3147,10 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "Union {}:{}", id, val),
None => write!(f, "Union(NULL)"),
},
ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"),
ScalarValue::Null => write!(f, "NULL"),
}
Expand Down
7 changes: 6 additions & 1 deletion datafusion/execution/src/memory_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug {

/// Return the total amount of memory reserved
fn reserved(&self) -> usize;

/// Return the configured pool size (if any)
fn pool_size(&self) -> Option<usize>;
}

/// A memory consumer that can be tracked by [`MemoryReservation`] in
Expand Down Expand Up @@ -286,7 +289,9 @@ mod tests {

#[test]
fn test_memory_pool_underflow() {
let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(50)) as _;
assert_eq!(pool.pool_size(), Some(50));

let mut a1 = MemoryConsumer::new("a1").register(&pool);
assert_eq!(pool.reserved(), 0);

Expand Down
Loading
Loading