Skip to content

Commit

Permalink
Add was_valid parameter to NullState callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Jul 22, 2024
1 parent b1e9c36 commit 13601c8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 22 deletions.
7 changes: 3 additions & 4 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,9 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
|group_index, _, new_value| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);

self.counts[group_index] += 1;
},
);
Expand Down Expand Up @@ -533,7 +532,7 @@ where
partial_counts,
opt_filter,
total_num_groups,
|group_index, partial_count| {
|group_index, _, partial_count| {
self.counts[group_index] += partial_count;
},
);
Expand All @@ -545,7 +544,7 @@ where
partial_sums,
opt_filter,
total_num_groups,
|group_index, new_value: <T as ArrowPrimitiveType>::Native| {
|group_index, _, new_value: <T as ArrowPrimitiveType>::Native| {
let sum = &mut self.sums[group_index];
*sum = sum.add_wrapping(new_value);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl NullState {
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
F: FnMut(usize, bool, T::Native) + Send,
{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());
Expand All @@ -147,8 +147,9 @@ impl NullState {
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
}
// nulls, no filter
Expand All @@ -174,8 +175,9 @@ impl NullState {
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
index_mask <<= 1;
},
Expand All @@ -191,8 +193,9 @@ impl NullState {
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
});
}
Expand All @@ -208,8 +211,9 @@ impl NullState {
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -226,8 +230,9 @@ impl NullState {
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
}
}
})
Expand All @@ -253,7 +258,7 @@ impl NullState {
total_num_groups: usize,
mut value_fn: F,
) where
F: FnMut(usize, bool) + Send,
F: FnMut(usize, bool, bool) + Send,
{
let data = values.values();
assert_eq!(data.len(), group_indices.len());
Expand All @@ -271,8 +276,9 @@ impl NullState {
// buffer is big enough (start everything at valid)
group_indices.iter().zip(data.iter()).for_each(
|(&group_index, new_value)| {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
},
)
}
Expand All @@ -285,8 +291,9 @@ impl NullState {
.zip(nulls.iter())
.for_each(|((&group_index, new_value), is_valid)| {
if is_valid {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -300,8 +307,9 @@ impl NullState {
.zip(filter.iter())
.for_each(|((&group_index, new_value), filter_value)| {
if let Some(true) = filter_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
value_fn(group_index, was_valid, new_value);
}
})
}
Expand All @@ -315,8 +323,9 @@ impl NullState {
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let was_valid = seen_values.get_bit(group_index);
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
value_fn(group_index, was_valid, new_value)
}
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let current_value = self.values.get_bit(group_index);
let value = (self.bool_fn)(current_value, new_value);
self.values.set_bit(group_index, value);
|group_index, was_valid, new_value| {
if was_valid {
let current_value = self.values.get_bit(group_index);
let value = (self.bool_fn)(current_value, new_value);
self.values.set_bit(group_index, value)
} else {
self.values.set_bit(group_index, new_value)
}
},
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ where
values,
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value);
|group_index, was_valid, new_value| {
if was_valid {
let value = &mut self.values[group_index];
(self.prim_fn)(value, new_value)
} else {
self.values[group_index] = new_value
}
},
);

Expand Down

0 comments on commit 13601c8

Please sign in to comment.