Skip to content

Commit

Permalink
Add support for group by hash of a null column, tests for same
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 29, 2021
1 parent 712bf71 commit b0d834a
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 5 deletions.
60 changes: 57 additions & 3 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,10 @@ fn group_aggregate_batch(
// We can safely unwrap here as we checked we can create an accumulator before
let accumulator_set = create_accumulators(aggr_expr).unwrap();
batch_keys.push(key.clone());
let _ = create_group_by_values(&group_values, row, &mut group_by_values);
// Note it would be nice to make this a real error (rather than panic)
// but it is better than silently ignoring the issue and getting wrong results
create_group_by_values(&group_values, row, &mut group_by_values)
.expect("can not create group by value");
(
key.clone(),
(group_by_values.clone(), accumulator_set, vec![row as u32]),
Expand Down Expand Up @@ -508,7 +511,9 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
}

/// Appends a sequence of [u8] bytes for the value in `col[row]` to
/// `vec` to be used as a key into the hash map
/// `vec` to be used as a key into the hash map.
///
/// NOTE: This function does not check col.is_valid(). Caller must do so
fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<()> {
match col.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -640,14 +645,63 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
}

/// Create a key `Vec<u8>` that is used as key for the hashmap
///
/// The key is the concatenation, one entry per group-by column, of
/// [null_byte][col_value_bytes][null_byte][col_value_bytes]
///
/// Note that relatively uncommon patterns (e.g. not 0x00) are chosen
/// for the null_byte to make debugging easier. The actual values are
/// arbitrary.
///
/// For a NULL value in a column, the key looks like
/// [0xFE]
///
/// For a Non-NULL value in a column, this looks like:
/// [0xFF][byte representation of column value]
///
/// Example of a key with no NULL values:
/// ```text
/// 0xFF byte at the start of each column
/// signifies the value is non-null
/// │
///
/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ┐
///
/// │ string len │ 0x1234
/// { ▼ (as usize le) "foo" ▼(as u16 le)
/// k1: "foo" ╔ ═┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──╦ ═┌──┬──┐
/// k2: 0x1234u16 FF║03│00│00│00│00│00│00│00│"f│"o│"o│FF║34│12│
/// } ╚ ═└──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──╩ ═└──┴──┘
/// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
/// ```
///
/// Example of a key with NULL values:
///
/// ```text
/// 0xFE byte at the start of k1 column
/// ┌ ─ signifies the value is NULL
///
/// └ ┐
/// 0x1234
/// { ▼ (as u16 le)
/// k1: NULL ╔ ═╔ ═┌──┬──┐
/// k2: 0x1234u16 FE║FF║12│34│
/// } ╚ ═╚ ═└──┴──┘
/// 0 1 2 3
/// ```
pub(crate) fn create_key(
    group_by_keys: &[ArrayRef],
    row: usize,
    vec: &mut Vec<u8>,
) -> Result<()> {
    // Reuse the caller-provided buffer so no allocation happens per row.
    vec.clear();
    for col in group_by_keys {
        if !col.is_valid(row) {
            // NULL sentinel: no value bytes follow for this column.
            vec.push(0xFE);
        } else {
            // Non-null sentinel, then the byte representation of the value.
            // (create_key_for_col does not itself check validity; the
            // is_valid check above is required before calling it.)
            vec.push(0xFF);
            create_key_for_col(col, row, vec)?
        }
    }
    Ok(())
}
Expand Down
37 changes: 35 additions & 2 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::{
},
};
use ordered_float::OrderedFloat;
use std::convert::Infallible;
use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

Expand Down Expand Up @@ -796,6 +796,11 @@ impl ScalarValue {

/// Converts a value in `array` at `index` into a ScalarValue
pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
// handle NULL value
if !array.is_valid(index) {
return array.data_type().try_into();
}

Ok(match array.data_type() {
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
Expand Down Expand Up @@ -897,6 +902,7 @@ impl ScalarValue {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
// (note validity was previously checked in `try_from_array`)
let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
Expand Down Expand Up @@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool);
impl TryFrom<&DataType> for ScalarValue {
type Error = DataFusionError;

/// Create a Null instance of ScalarValue for this datatype
fn try_from(datatype: &DataType) -> Result<Self> {
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(None),
Expand Down Expand Up @@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
ScalarValue::TimestampNanosecond(None)
}
DataType::Dictionary(_index_type, value_type) => {
value_type.as_ref().try_into()?
}
DataType::List(ref nested_type) => {
ScalarValue::List(None, Box::new(nested_type.data_type().clone()))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar of type \"{:?}\"",
"Can't create a scalar from data_type \"{:?}\"",
datatype
)))
}
Expand Down Expand Up @@ -1535,6 +1545,29 @@ mod tests {
"{}", result);
}

#[test]
fn scalar_try_from_array_null() {
    // An Int64 array with one valid slot (index 0) and one null slot (index 1).
    let array: ArrayRef = Arc::new(Int64Array::from(vec![Some(33), None]));

    // The valid slot converts to the wrapped value...
    let valid = ScalarValue::try_from_array(&array, 0).unwrap();
    assert_eq!(valid, ScalarValue::Int64(Some(33)));

    // ...while the null slot converts to the typed null variant.
    let null = ScalarValue::try_from_array(&array, 1).unwrap();
    assert_eq!(null, ScalarValue::Int64(None));
}

#[test]
fn scalar_try_from_dict_datatype() {
    // A null ScalarValue for a Dictionary type is the null of its *value*
    // type (Utf8 here), not a dictionary-specific variant.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
    let scalar: ScalarValue = (&dict_type).try_into().unwrap();
    assert_eq!(scalar, ScalarValue::Utf8(None))
}

#[test]
fn size_of_scalar() {
// Since ScalarValues are used in a non trivial number of places,
Expand Down
110 changes: 110 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3014,6 +3014,109 @@ async fn query_count_distinct() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn query_group_on_null() -> Result<()> {
// Single nullable Int32 column; the NULL row must form its own group
// rather than being dropped or merged into another group.
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));

let data = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![
Some(0),
Some(3),
None,
Some(1),
Some(3),
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let mut ctx = ExecutionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1";

let actual = execute_to_batches(&mut ctx, sql).await;

// Note that the results include a row for the NULL group
// (c1=NULL, count = 1), rendered as an empty c1 cell below.
let expected = vec![
"+-----------------+----+",
"| COUNT(UInt8(1)) | c1 |",
"+-----------------+----+",
"| 1 | |",
"| 1 | 0 |",
"| 1 | 1 |",
"| 2 | 3 |",
"+-----------------+----+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn query_group_on_null_multi_col() -> Result<()> {
// Two nullable group-by columns; NULLs in either (or both) columns
// must still produce distinct, correctly-counted groups.
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![
Some(0),
Some(0),
Some(3),
None,
None,
Some(3),
Some(0),
None,
Some(3),
])),
Arc::new(StringArray::from(vec![
None,
None,
Some("foo"),
None,
Some("bar"),
Some("foo"),
None,
Some("bar"),
Some("foo"),
])),
],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let mut ctx = ExecutionContext::new();
ctx.register_table("test", Arc::new(table))?;
let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2";

let actual = execute_to_batches(&mut ctx, sql).await;

// Note that the results include rows for groups where one or both
// group keys are NULL (rendered as empty cells below), e.g.
// (c1=NULL, c2=NULL, count=1) and (c1=NULL, c2="bar", count=2).
let expected = vec![
"+-----------------+----+-----+",
"| COUNT(UInt8(1)) | c1 | c2 |",
"+-----------------+----+-----+",
"| 1 | | |",
"| 2 | | bar |",
"| 3 | 0 | |",
"| 3 | 3 | foo |",
"+-----------------+----+-----+",
];
assert_batches_sorted_eq!(expected, &actual);

// Also run query with group columns reversed (results should be the same)
let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1";
let actual = execute_to_batches(&mut ctx, sql).await;
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn query_on_string_dictionary() -> Result<()> {
// Test to ensure DataFusion can operate on dictionary types
Expand Down Expand Up @@ -3067,6 +3170,13 @@ async fn query_on_string_dictionary() -> Result<()> {
let expected = vec![vec!["2"]];
assert_eq!(expected, actual);

// grouping
let sql = "SELECT d1, COUNT(*) FROM test group by d1";
let mut actual = execute(&mut ctx, sql).await;
actual.sort();
let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]];
assert_eq!(expected, actual);

Ok(())
}

Expand Down

0 comments on commit b0d834a

Please sign in to comment.