Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Union types in ScalarValue #9683

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub enum DataFusionError {
IoError(io::Error),
/// Error when SQL is syntactically incorrect.
///
/// 2nd argument is for optional backtrace
/// 2nd argument is for optional backtrace
SQL(ParserError, Option<String>),
/// Error when a feature is not yet implemented.
///
Expand Down Expand Up @@ -101,7 +101,7 @@ pub enum DataFusionError {
/// This error can be returned in cases such as when schema inference is not
/// possible and when column names are not unique.
///
/// 2nd argument is for optional backtrace
/// 2nd argument is for optional backtrace
/// Boxing the optional backtrace to prevent <https://rust-lang.github.io/rust-clippy/master/index.html#/result_large_err>
SchemaError(SchemaError, Box<Option<String>>),
/// Error during execution of the query.
Expand Down
82 changes: 82 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ use arrow::{
},
};
use arrow_array::{ArrowNativeTypeOp, Scalar};
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

pub use struct_builder::ScalarStructBuilder;

Expand Down Expand Up @@ -275,6 +277,11 @@ pub enum ScalarValue {
DurationMicrosecond(Option<i64>),
/// Duration in nanoseconds
DurationNanosecond(Option<i64>),
/// A nested datatype that can represent slots of differing types. Components:
avantgardnerio marked this conversation as resolved.
Show resolved Hide resolved
/// `.0`: a tuple of union `type_id` and the single value held by this Scalar
/// `.1`: the list of fields, zero-to-one of which will be set in `.0`
/// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came
Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
/// Dictionary type: index type and value
Dictionary(Box<DataType>, Box<ScalarValue>),
}
Expand Down Expand Up @@ -375,6 +382,10 @@ impl PartialEq for ScalarValue {
(IntervalDayTime(_), _) => false,
(IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2),
(IntervalMonthDayNano(_), _) => false,
(Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
}
(Union(_, _, _), _) => false,
avantgardnerio marked this conversation as resolved.
Show resolved Hide resolved
(Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
(Dictionary(_, _), _) => false,
(Null, Null) => true,
Expand Down Expand Up @@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue {
(DurationMicrosecond(_), _) => None,
(DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2),
(DurationNanosecond(_), _) => None,
(Union(v1, t1, m1), Union(v2, t2, m2)) => {
if t1.eq(t2) && m1.eq(m2) {
v1.partial_cmp(v2)
} else {
None
}
}
(Union(_, _, _), _) => None,
(Dictionary(k1, v1), Dictionary(k2, v2)) => {
// Don't compare if the key types don't match (it is effectively a different datatype)
if k1 == k2 {
Expand Down Expand Up @@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue {
IntervalYearMonth(v) => v.hash(state),
IntervalDayTime(v) => v.hash(state),
IntervalMonthDayNano(v) => v.hash(state),
Union(v, t, m) => {
v.hash(state);
t.hash(state);
m.hash(state);
}
Dictionary(k, v) => {
k.hash(state);
v.hash(state);
Expand Down Expand Up @@ -1093,6 +1117,7 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(_) => {
DataType::Duration(TimeUnit::Nanosecond)
}
ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.data_type()))
}
Expand Down Expand Up @@ -1292,6 +1317,7 @@ impl ScalarValue {
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -2083,6 +2109,39 @@ impl ScalarValue {
e,
size
),
ScalarValue::Union(value, fields, _mode) => match value {
Some((v_id, value)) => {
let mut field_type_ids = Vec::<i8>::with_capacity(fields.len());
let mut child_arrays =
Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
for (f_id, field) in fields.iter() {
let ar = if f_id == *v_id {
value.to_array_of_size(size)?
} else {
let dt = field.data_type();
new_null_array(dt, size)
};
let field = (**field).clone();
child_arrays.push((field, ar));
field_type_ids.push(f_id);
}
let type_ids = repeat(*v_id).take(size).collect::<Vec<_>>();
let type_ids = Buffer::from_slice_ref(type_ids);
let value_offsets: Option<Buffer> = None;
let ar = UnionArray::try_new(
field_type_ids.as_slice(),
type_ids,
value_offsets,
child_arrays,
)
.map_err(|e| DataFusionError::ArrowError(e, None))?;
Arc::new(ar)
}
None => {
let dt = self.data_type();
new_null_array(&dt, size)
}
},
ScalarValue::Dictionary(key_type, v) => {
// values array is one element long (the value)
match key_type.as_ref() {
Expand Down Expand Up @@ -2622,6 +2681,9 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(val) => {
eq_array_primitive!(array, index, DurationNanosecondArray, val)?
}
ScalarValue::Union(_, _, _) => {
return _not_impl_err!("Union is not supported yet")
}
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array, index)?,
Expand Down Expand Up @@ -2699,6 +2761,15 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(arr) => arr.get_array_memory_size(),
ScalarValue::Union(vals, fields, _mode) => {
vals.as_ref()
.map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
.unwrap_or_default()
// `fields` is boxed, so it is NOT already included in `self`
+ std::mem::size_of_val(fields)
+ (std::mem::size_of::<Field>() * fields.len())
+ fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::<usize>()
}
ScalarValue::Dictionary(dt, sv) => {
// `dt` and `sv` are boxed, so they are NOT already included in `self`
dt.size() + sv.size()
Expand Down Expand Up @@ -3044,6 +3115,9 @@ impl TryFrom<&DataType> for ScalarValue {
.to_owned()
.into(),
),
DataType::Union(fields, mode) => {
ScalarValue::Union(None, fields.clone(), *mode)
}
DataType::Null => ScalarValue::Null,
_ => {
return _not_impl_err!(
Expand Down Expand Up @@ -3160,6 +3234,10 @@ impl fmt::Display for ScalarValue {
.join(",")
)?
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "{}:{}", id, val)?,
None => write!(f, "NULL")?,
},
ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
ScalarValue::Null => write!(f, "NULL")?,
};
Expand Down Expand Up @@ -3275,6 +3353,10 @@ impl fmt::Debug for ScalarValue {
ScalarValue::DurationNanosecond(_) => {
write!(f, "DurationNanosecond(\"{self}\")")
}
ScalarValue::Union(val, _fields, _mode) => match val {
Some((id, val)) => write!(f, "Union {}:{}", id, val),
None => write!(f, "Union(NULL)"),
},
ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"),
ScalarValue::Null => write!(f, "NULL"),
}
Expand Down
35 changes: 35 additions & 0 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,9 @@ mod tests {
use crate::test::exec::StatisticsExec;
use crate::ExecutionPlan;

use crate::empty::EmptyExec;
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::{UnionFields, UnionMode};
use datafusion_common::{ColumnStatistics, ScalarValue};
use datafusion_expr::Operator;

Expand Down Expand Up @@ -1090,4 +1092,37 @@ mod tests {
assert_eq!(statistics.total_byte_size, Precision::Inexact(1600));
Ok(())
}

/// Checks that `FilterExec::statistics` succeeds when the input schema
/// contains a `Union`-typed column that the predicate does not reference.
#[test]
fn test_equivalence_properties_union_type() -> Result<()> {
    // Sparse union with two nullable members: type id 0 -> Int32, 1 -> Utf8.
    let union_fields = UnionFields::new(
        vec![0, 1],
        vec![
            Field::new("f1", DataType::Int32, true),
            Field::new("f2", DataType::Utf8, true),
        ],
    );
    let c2_type = DataType::Union(union_fields, UnionMode::Sparse);

    let schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", c2_type, true),
    ]));

    // Predicate: c1 >= 1 AND c1 <= 4.
    let lower_bound = binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?;
    let upper_bound = binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?;
    let predicate = binary(lower_bound, Operator::And, upper_bound, &schema)?;

    let input = Arc::new(EmptyExec::new(schema.clone()));
    let exec = FilterExec::try_new(predicate, input)?;

    // Only success matters here; the returned statistics are discarded.
    exec.statistics().unwrap();

    Ok(())
}
}
15 changes: 15 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue {
int64 nanos = 3;
}

// One member of a union type: a field definition paired with the id that
// selects it.
message UnionField {
// Id identifying this member within the union.
// NOTE(review): presumably mirrors the i8 type id of Arrow's UnionFields,
// widened to int32 for the wire format -- confirm against the
// (de)serialization code.
int32 field_id = 1;
// Definition of the member field.
Field field = 2;
}

// Serialized form of a union scalar: the selected member and its value,
// together with the full union type description (fields and mode).
message UnionValue {
// Note that a null union value must still carry one or more fields, so
// nullness cannot be expressed by leaving the message empty. Instead, a null
// UnionValue is encoded as one with value_id == 128.
// NOTE(review): 128 does not fit in an i8, so it presumably cannot collide
// with a real union type id -- confirm against the (de)serialization code.
int32 value_id = 1;
// The scalar held by the member selected by `value_id`.
ScalarValue value = 2;
// All member fields of the union type.
repeated UnionField fields = 3;
// Union mode of the type (see the UnionMode definition).
UnionMode mode = 4;
}

message ScalarFixedSizeBinary{
bytes values = 1;
int32 length = 2;
Expand Down Expand Up @@ -1042,6 +1056,7 @@ message ScalarValue{
ScalarTime64Value time64_value = 30;
IntervalMonthDayNanoValue interval_month_day_nano = 31;
ScalarFixedSizeBinary fixed_size_binary_value = 34;
UnionValue union_value = 42;
}
}

Expand Down
Loading
Loading