diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index cad054392308..b246830b6861 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -537,7 +537,7 @@ impl TryInto for &protobuf::scalar_value::Value } protobuf::scalar_value::Value::ListValue(v) => v.try_into()?, protobuf::scalar_value::Value::NullListValue(v) => { - ScalarValue::List(None, v.try_into()?) + ScalarValue::List(None, Box::new(v.try_into()?)) } protobuf::scalar_value::Value::NullValue(null_enum) => { PrimitiveScalarType::from_i32(*null_enum) @@ -581,8 +581,8 @@ impl TryInto for &protobuf::ScalarListValue { }) .collect::, _>>()?; datafusion::scalar::ScalarValue::List( - Some(typechecked_values), - leaf_scalar_type.into(), + Some(Box::new(typechecked_values)), + Box::new(leaf_scalar_type.into()), ) } Datatype::List(list_type) => { @@ -626,9 +626,9 @@ impl TryInto for &protobuf::ScalarListValue { datafusion::scalar::ScalarValue::List( match typechecked_values.len() { 0 => None, - _ => Some(typechecked_values), + _ => Some(Box::new(typechecked_values)), }, - list_type.try_into()?, + Box::new(list_type.try_into()?), ) } }; @@ -766,14 +766,16 @@ impl TryInto for &protobuf::ScalarValue { .map(|val| val.try_into()) .collect::, _>>()?; let scalar_type: DataType = pb_scalar_type.try_into()?; - ScalarValue::List(Some(typechecked_values), scalar_type) + let scalar_type = Box::new(scalar_type); + ScalarValue::List(Some(Box::new(typechecked_values)), scalar_type) } protobuf::scalar_value::Value::NullListValue(v) => { let pb_datatype = v .datatype .as_ref() .ok_or_else(|| proto_error("Protobuf deserialization error: NullListValue message missing required field 'datatyp'"))?; - ScalarValue::List(None, pb_datatype.try_into()?) 
+ let pb_datatype = Box::new(pb_datatype.try_into()?); + ScalarValue::List(None, pb_datatype) } protobuf::scalar_value::Value::NullValue(v) => { let null_type_enum = protobuf::PrimitiveScalarType::from_i32(*v) diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 0d27c58ac292..f4ff1dee9bdf 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -126,49 +126,57 @@ mod roundtrip_tests { let should_fail_on_seralize: Vec = vec![ //Should fail due to inconsistent types ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::Int16(None), ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_box_field("item", DataType::Int16, true)), + ])), + Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), ), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::Float32(None), ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_box_field("item", DataType::Int16, true)), + ])), + Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), ), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::List( None, - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), + ])), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( None, - DataType::List(new_box_field( + Box::new(DataType::List(new_box_field( "lists are typed inconsistently", DataType::Int16, true, - )), + ))), ), - ]), - DataType::List(new_box_field( + ])), + 
Box::new(DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - )), + ))), ), ]; @@ -200,7 +208,7 @@ mod roundtrip_tests { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(None, DataType::Boolean), + ScalarValue::List(None, Box::new(DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::TimestampMicrosecond(None), ScalarValue::TimestampNanosecond(None), @@ -248,37 +256,49 @@ mod roundtrip_tests { ScalarValue::TimestampMicrosecond(Some(i64::MAX)), ScalarValue::TimestampMicrosecond(None), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_box_field("level1", DataType::Float32, true)), + ])), + Box::new(DataType::List(new_box_field( + "level1", + DataType::Float32, + true, + ))), ), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::List( None, - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), + ])), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), - ]), - DataType::List(new_box_field( + ])), + Box::new(DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - )), + ))), ), ]; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 07d7a59c114c..87f26a118e78 100644 --- 
a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -565,13 +565,13 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.try_into()?), + datatype: Some(datatype.as_ref().try_into()?), values: Vec::new(), }, )), } } else { - let scalar_type = match datatype { + let scalar_type = match datatype.as_ref() { DataType::List(field) => field.as_ref().data_type(), _ => todo!("Proper error handling"), }; @@ -579,16 +579,23 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { let type_checked_values: Vec = values .iter() .map(|scalar| match (scalar, scalar_type) { - (scalar::ScalarValue::List(_, DataType::List(list_field)), DataType::List(field)) => { - let scalar_datatype = field.data_type(); - let list_datatype = list_field.data_type(); - if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) { - return Err(proto_error(format!( - "Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list", - list_datatype, scalar_datatype - ))); + (scalar::ScalarValue::List(_, list_type), DataType::List(field)) => { + if let DataType::List(list_field) = list_type.as_ref() { + let scalar_datatype = field.data_type(); + let list_datatype = list_field.data_type(); + if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) { + return Err(proto_error(format!( + "Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list", + list_datatype, scalar_datatype + ))); + } + scalar.try_into() + } else { + Err(proto_error(format!( + "Protobuf serialization error, {:?} was inconsistent with designated type {:?}", + scalar, datatype + ))) } - scalar.try_into() } (scalar::ScalarValue::Boolean(_), DataType::Boolean) => scalar.try_into(), 
(scalar::ScalarValue::Float32(_), DataType::Float32) => scalar.try_into(), @@ -612,7 +619,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.try_into()?), + datatype: Some(datatype.as_ref().try_into()?), values: type_checked_values, }, )), @@ -621,7 +628,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue { } None => protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::NullListValue( - datatype.try_into()?, + datatype.as_ref().try_into()?, )), }, } diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs index f3513c2950e4..90c0836f7077 100644 --- a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/distinct_expressions.rs @@ -178,7 +178,9 @@ impl Accumulator for DistinctCountAccumulator { .state_data_types .iter() .map(|state_data_type| { - ScalarValue::List(Some(Vec::new()), state_data_type.clone()) + let values = Box::new(Vec::new()); + let data_type = Box::new(state_data_type.clone()); + ScalarValue::List(Some(values), data_type) }) .collect::>(); @@ -254,8 +256,8 @@ mod tests { macro_rules! 
state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ match $LIST { - ScalarValue::List(_, data_type) => match data_type { - DataType::$DATA_TYPE => (), + ScalarValue::List(_, data_type) => match data_type.as_ref() { + &DataType::$DATA_TYPE => (), _ => panic!("Unexpected DataType for list"), }, _ => panic!("Expected a ScalarValue::List"), diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index ab0836424242..129b4166a4e8 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -65,8 +65,9 @@ pub enum ScalarValue { Binary(Option>), /// large binary LargeBinary(Option>), - /// list of nested ScalarValue - List(Option>, DataType), + /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_vec)] + List(Option>>, Box), /// Date stored as a signed 32bit int Date32(Option), /// Date stored as a signed 64bit int @@ -110,7 +111,7 @@ macro_rules! build_list { ) } Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE) + build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE) } } }}; @@ -130,32 +131,35 @@ macro_rules! 
build_timestamp_list { $SIZE, ) } - Some(values) => match $TIME_UNIT { - TimeUnit::Second => build_values_list!( - TimestampSecondBuilder, - TimestampSecond, - values, - $SIZE - ), - TimeUnit::Microsecond => build_values_list!( - TimestampMillisecondBuilder, - TimestampMillisecond, - values, - $SIZE - ), - TimeUnit::Millisecond => build_values_list!( - TimestampMicrosecondBuilder, - TimestampMicrosecond, - values, - $SIZE - ), - TimeUnit::Nanosecond => build_values_list!( - TimestampNanosecondBuilder, - TimestampNanosecond, - values, - $SIZE - ), - }, + Some(values) => { + let values = values.as_ref(); + match $TIME_UNIT { + TimeUnit::Second => build_values_list!( + TimestampSecondBuilder, + TimestampSecond, + values, + $SIZE + ), + TimeUnit::Microsecond => build_values_list!( + TimestampMillisecondBuilder, + TimestampMillisecond, + values, + $SIZE + ), + TimeUnit::Millisecond => build_values_list!( + TimestampMicrosecondBuilder, + TimestampMicrosecond, + values, + $SIZE + ), + TimeUnit::Nanosecond => build_values_list!( + TimestampNanosecondBuilder, + TimestampNanosecond, + values, + $SIZE + ), + } + } } }}; } @@ -235,9 +239,11 @@ impl ScalarValue { ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, ScalarValue::Binary(_) => DataType::Binary, ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(_, data_type) => { - DataType::List(Box::new(Field::new("item", data_type.clone(), true))) - } + ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::IntervalYearMonth(_) => { @@ -415,6 +421,7 @@ impl ScalarValue { for scalar in scalars.into_iter() { match scalar { ScalarValue::List(Some(xs), _) => { + let xs = *xs; for s in xs { match s { ScalarValue::$SCALAR_TY(Some(val)) => { @@ -627,7 +634,7 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(values, data_type) => 
Arc::new(match data_type { + ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() { DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), DataType::Int8 => build_list!(Int8Builder, Int8, values, size), DataType::Int16 => build_list!(Int16Builder, Int16, values, size), @@ -643,7 +650,7 @@ impl ScalarValue { DataType::Timestamp(unit, tz) => { build_timestamp_list!(unit.clone(), tz.clone(), values, size) } - DataType::LargeUtf8 => { + &DataType::LargeUtf8 => { build_list!(LargeStringBuilder, LargeUtf8, values, size) } dt => panic!("Unexpected DataType for list {:?}", dt), @@ -705,7 +712,9 @@ impl ScalarValue { Some(scalar_vec) } }; - ScalarValue::List(value, nested_type.data_type().clone()) + let value = value.map(Box::new); + let data_type = Box::new(nested_type.data_type().clone()); + ScalarValue::List(value, data_type) } DataType::Date32 => { typed_cast!(array, index, Date32Array, Date32) @@ -965,7 +974,7 @@ impl TryFrom<&DataType> for ScalarValue { ScalarValue::TimestampNanosecond(None) } DataType::List(ref nested_type) => { - ScalarValue::List(None, nested_type.data_type().clone()) + ScalarValue::List(None, Box::new(nested_type.data_type().clone())) } _ => { return Err(DataFusionError::NotImplemented(format!( @@ -1167,7 +1176,8 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::List(None, DataType::UInt64).to_array(); + let list_array_ref = + ScalarValue::List(None, Box::new(DataType::UInt64)).to_array(); let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); assert!(list_array.is_null(0)); @@ -1178,12 +1188,12 @@ mod tests { #[test] fn scalar_list_to_array() { let list_array_ref = ScalarValue::List( - Some(vec![ + Some(Box::new(vec![ ScalarValue::UInt64(Some(100)), ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), - ]), - DataType::UInt64, + ])), + Box::new(DataType::UInt64), ) .to_array(); @@ -1336,4 +1346,12 @@ mod tests { 
assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), "{}", result); } + + #[test] + fn size_of_scalar() { + // Since ScalarValues are used in a non-trivial number of places, + // making it larger means significantly more memory consumption + // per distinct value. + assert_eq!(std::mem::size_of::<ScalarValue>(), 32); + } }