ARROW-10722: [Rust][DataFusion] Reduce overhead of some data types in aggregations / joins, improve benchmarks

This PR reduces the size of `GroupByScalar` from 32 bytes to 16 bytes by using `Box<String>` for the `Utf8` variant. This shrinks a `Vec<GroupByScalar>`, and thus the keys of the hash maps used for aggregates / joins.
It also changes the key type to `Box<[GroupByScalar]>`, saving a further 8 bytes per key that a `Vec` would otherwise need to hold its capacity.
Finally, we can remove a `Box` around the `Vec` holding the indices.
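As a minimal sketch of the layout arithmetic (using hypothetical stand-in enums rather than the real `GroupByScalar`, and assuming a typical 64-bit target):

```rust
use std::mem::size_of;

// A String is stored inline as (ptr, len, cap) = 24 bytes, so the enum
// needs a tag word plus 24 bytes of payload.
enum WideKey {
    Int64(i64),
    Utf8(String),
}

// A Box<String> is a single pointer (8 bytes), shrinking the payload.
enum NarrowKey {
    Int64(i64),
    Utf8(Box<String>),
}

fn main() {
    assert_eq!(size_of::<WideKey>(), 32);
    assert_eq!(size_of::<NarrowKey>(), 16);
    // A boxed slice also drops the Vec's capacity field:
    assert_eq!(size_of::<Vec<NarrowKey>>(), 24); // (ptr, len, cap)
    assert_eq!(size_of::<Box<[NarrowKey]>>(), 16); // (ptr, len)
}
```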

The difference in speed seems minimal, at least in the current state.

I think it could be nice, in the future, to see whether the data could be packed efficiently into one `Box<[T]>` (where `T` is a primitive value) when there are no dynamically sized types, by using the schema instead of creating "dynamic" values. That should also make hashing faster. Currently, when grouping on multiple `i32` values, we need 32 bytes per value (plus 24 bytes for the `Vec` holding the values) instead of just 4! Using const generics (https://rust-lang.github.io/rfcs/2000-const-generics.html#:~:text=Rust%20currently%20has%20one%20type,implement%20traits%20for%20all%20arrays) could provide a further improvement, by not having to store the length of the slice. A rough sketch of that direction follows.
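A hedged sketch of that direction (hypothetical helper, not part of this PR): with const generics, a group key over N primitive columns can be a plain fixed-size array, storing neither a slice length nor per-value tags:

```rust
use std::collections::HashMap;

// Hypothetical: group row indices by a key of N i32 columns. The key type
// [i32; N] is Copy + Hash and occupies exactly 4 * N bytes.
fn group_indices<const N: usize>(rows: &[[i32; N]]) -> HashMap<[i32; N], Vec<u32>> {
    let mut groups: HashMap<[i32; N], Vec<u32>> = HashMap::new();
    for (row, key) in rows.iter().enumerate() {
        groups.entry(*key).or_default().push(row as u32);
    }
    groups
}

fn main() {
    // Grouping on two i32 columns: each key is 8 bytes, with no heap allocation.
    let rows = [[1, 2], [1, 2], [3, 4]];
    let groups = group_indices(&rows);
    assert_eq!(groups[&[1, 2]], vec![0, 1]);
    assert_eq!(std::mem::size_of::<[i32; 2]>(), 8);
}
```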

This PR also tries to improve the reproducibility of the benchmarks a bit by seeding the random number generator (results are still quite noisy on my machine, though).
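For illustration, a minimal sketch of what the seeding buys (mirroring the `seedable_rng` helper added in the benchmark diff below; rand 0.7-style API assumed):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};

fn main() {
    // Two generators built from the same seed produce identical streams,
    // so benchmark input data is the same on every run.
    let mut a = StdRng::seed_from_u64(42);
    let mut b = StdRng::seed_from_u64(42);
    let xs: Vec<u64> = (0..3).map(|_| a.gen()).collect();
    let ys: Vec<u64> = (0..3).map(|_| b.gen()).collect();
    assert_eq!(xs, ys);
}
```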

Closes #8765 from Dandandan/reduce_key_size

Lead-authored-by: Heres, Daniel <danielheres@gmail.com>
Co-authored-by: Daniël Heres <danielheres@gmail.com>
Signed-off-by: Jorge C. Leitao <jorgecarleitao@gmail.com>
Dandandan authored and jorgecarleitao committed Dec 7, 2020
1 parent 49f23a1 commit 8711ca9
Showing 4 changed files with 56 additions and 22 deletions.
42 changes: 32 additions & 10 deletions rust/datafusion/benches/aggregate_query_sql.rs
@@ -19,8 +19,7 @@
extern crate criterion;
use criterion::Criterion;

-use rand::seq::SliceRandom;
-use rand::Rng;
+use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime;

@@ -40,6 +39,10 @@ use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;

+pub fn seedable_rng() -> StdRng {
+    StdRng::seed_from_u64(42)
+}

fn query(ctx: Arc<Mutex<ExecutionContext>>, sql: &str) {
let mut rt = Runtime::new().unwrap();

@@ -50,7 +53,7 @@ fn query(ctx: Arc<Mutex<ExecutionContext>>, sql: &str) {

fn create_data(size: usize, null_density: f64) -> Vec<Option<f64>> {
// use random numbers to avoid spurious compiler optimizations wrt to branching
-let mut rng = rand::thread_rng();
+let mut rng = seedable_rng();

(0..size)
.map(|_| {
@@ -65,7 +68,7 @@ fn create_integer_data(size: usize, value_density: f64) -> Vec<Option<u64>> {

fn create_integer_data(size: usize, value_density: f64) -> Vec<Option<u64>> {
// use random numbers to avoid spurious compiler optimizations wrt to branching
-let mut rng = rand::thread_rng();
+let mut rng = seedable_rng();

(0..size)
.map(|_| {
@@ -98,6 +101,8 @@ fn create_context(
Field::new("u64_narrow", DataType::UInt64, false),
]));

+let mut rng = seedable_rng();

// define data.
let partitions = (0..partitions_len)
.map(|_| {
@@ -109,7 +114,7 @@
let keys: Vec<String> = (0..batch_size)
.map(
// use random numbers to avoid spurious compiler optimizations wrt to branching
-|_| format!("hi{:?}", vs.choose(&mut rand::thread_rng())),
+|_| format!("hi{:?}", vs.choose(&mut rng)),
)
.collect();
let keys: Vec<&str> = keys.iter().map(|e| &**e).collect();
@@ -122,11 +127,7 @@
// Integer values between [0, 9].
let integer_values_narrow_choices = (0..10).collect::<Vec<u64>>();
let integer_values_narrow = (0..batch_size)
-.map(|_| {
-*integer_values_narrow_choices
-.choose(&mut rand::thread_rng())
-.unwrap()
-})
+.map(|_| *integer_values_narrow_choices.choose(&mut rng).unwrap())
.collect::<Vec<u64>>();

RecordBatch::try_new(
@@ -216,6 +217,27 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

c.bench_function("aggregate_query_group_by_u64 15 12", |b| {
b.iter(|| {
query(
ctx.clone(),
"SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \
FROM t GROUP BY u64_narrow",
)
})
});

c.bench_function("aggregate_query_group_by_with_filter_u64 15 12", |b| {
b.iter(|| {
query(
ctx.clone(),
"SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \
FROM t \
WHERE f32 > 10 AND f32 < 20 GROUP BY u64_narrow",
)
})
});
}

criterion_group!(benches, criterion_benchmark);
11 changes: 8 additions & 3 deletions rust/datafusion/src/physical_plan/group_scalar.rs
@@ -34,7 +34,7 @@ pub(crate) enum GroupByScalar {
Int16(i16),
Int32(i32),
Int64(i64),
-Utf8(String),
+Utf8(Box<String>),
}

impl TryFrom<&ScalarValue> for GroupByScalar {
@@ -50,7 +50,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v),
ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
-ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone()),
+ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
@@ -86,7 +86,7 @@ impl From<&GroupByScalar> for ScalarValue {
GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)),
GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
-GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone())),
+GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
}
}
}
@@ -131,4 +131,9 @@ mod tests {

Ok(())
}

+#[test]
+fn size_of_group_by_scalar() {
+assert_eq!(std::mem::size_of::<GroupByScalar>(), 16);
+}
}
16 changes: 10 additions & 6 deletions rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -250,6 +250,8 @@ fn group_aggregate_batch(
key.push(GroupByScalar::UInt32(0));
}

+let mut key = key.into_boxed_slice();

// 1.1 construct the key from the group values
// 1.2 construct the mapping key if it does not exist
// 1.3 add the row' index to `indices`
@@ -270,7 +272,7 @@
.or_insert_with(|| {
// We can safely unwrap here as we checked we can create an accumulator before
let accumulator_set = create_accumulators(aggr_expr).unwrap();
-(key.clone(), (accumulator_set, Box::new(vec![row as u32])))
+(key.clone(), (accumulator_set, vec![row as u32]))
});
}

@@ -296,7 +298,7 @@
// 2.3
compute::take(
array,
-&UInt32Array::from(*indices.clone()),
+&UInt32Array::from(indices.clone()),
None, // None: no index check
)
.unwrap()
@@ -389,7 +391,7 @@ impl GroupedHashAggregateStream {

type AccumulatorSet = Vec<Box<dyn Accumulator>>;
type Accumulators =
-HashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>), RandomState>;
+HashMap<Box<[GroupByScalar]>, (AccumulatorSet, Vec<u32>), RandomState>;

impl Stream for GroupedHashAggregateStream {
type Item = ArrowResult<RecordBatch>;
@@ -658,7 +660,9 @@ fn create_batch_from_map(
GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])),
GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])),
GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])),
-GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&**str])),
+GroupByScalar::Utf8(str) => {
+Arc::new(StringArray::from(vec![&***str]))
+}
})
.collect::<Vec<ArrayRef>>();

@@ -726,7 +730,7 @@ fn finalize_aggregation(
pub(crate) fn create_key(
group_by_keys: &[ArrayRef],
row: usize,
-vec: &mut Vec<GroupByScalar>,
+vec: &mut Box<[GroupByScalar]>,
) -> Result<()> {
for i in 0..group_by_keys.len() {
let col = &group_by_keys[i];
@@ -765,7 +769,7 @@ pub(crate) fn create_key(
}
DataType::Utf8 => {
let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-vec[i] = GroupByScalar::Utf8(String::from(array.value(row)))
+vec[i] = GroupByScalar::Utf8(Box::new(array.value(row).into()))
}
_ => {
// This is internal because we should have caught this before.
9 changes: 6 additions & 3 deletions rust/datafusion/src/physical_plan/hash_join.rs
@@ -52,7 +52,7 @@ type JoinIndex = Option<(usize, usize)>;
// Maps ["on" value] -> [list of indices with this key's value]
// E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true
// for rows 3 and 8 from batch 0 and row 6 from batch 1.
-type JoinHashMap = HashMap<Vec<GroupByScalar>, Vec<Index>, RandomState>;
+type JoinHashMap = HashMap<Box<[GroupByScalar]>, Vec<Index>, RandomState>;
type JoinLeftData = (JoinHashMap, Vec<RecordBatch>);

/// join execution plan executes partitions in parallel and combines them into a set of
@@ -209,6 +209,8 @@ fn update_hash(
key.push(GroupByScalar::UInt32(0));
}

+let mut key = key.into_boxed_slice();

// update the hash map
for row in 0..batch.num_rows() {
create_key(&keys_values, row, &mut key)?;
@@ -368,8 +370,9 @@ fn build_join_indexes(
JoinType::Inner => {
// inner => key intersection
// unfortunately rust does not support intersection of map keys :(
-let left_set: HashSet<Vec<GroupByScalar>> = left.keys().cloned().collect();
-let left_right: HashSet<Vec<GroupByScalar>> = right.keys().cloned().collect();
+let left_set: HashSet<Box<[GroupByScalar]>> = left.keys().cloned().collect();
+let left_right: HashSet<Box<[GroupByScalar]>> =
+right.keys().cloned().collect();
let inner = left_set.intersection(&left_right);

let mut indexes = Vec::new(); // unknown a prior size
