Skip to content

Commit

Permalink
Handle (partially) dictionary values in ScalarValue serde (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkharderdev authored May 17, 2024
1 parent e35bb28 commit 98d2c6e
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 3 deletions.
6 changes: 6 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,15 @@ message Union{

// Used for List/FixedSizeList/LargeList/Struct
//
// The nested scalar is shipped as an Arrow IPC-encoded record batch, together
// with the schema needed to read it back and any dictionary batches that the
// encoder emitted alongside it.
message ScalarNestedValue {
// One IPC-encoded dictionary batch. The reader parses `ipc_message` as an
// IPC message whose header must be a dictionary batch, and uses the
// dictionary id in that header to find the matching field in `schema`.
message Dictionary {
// IPC message (header/metadata) describing the dictionary batch.
bytes ipc_message = 1;
// Raw Arrow buffer data holding the dictionary values.
bytes arrow_data = 2;
}

// IPC message (header/metadata) describing the record batch for the value.
bytes ipc_message = 1;
// Raw Arrow buffer data for the record batch.
bytes arrow_data = 2;
// Schema used to decode the record batch (and to resolve dictionary ids).
Schema schema = 3;
// Dictionary batches produced when encoding the record batch; empty when the
// value has no dictionary-encoded columns.
repeated Dictionary dictionaries = 4;
}

message ScalarTime32Value {
Expand Down
133 changes: 133 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 48 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use crate::protobuf::{
Expand All @@ -29,6 +30,7 @@ use crate::protobuf::{
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
};

use arrow::array::ArrayRef;
use arrow::{
array::AsArray,
buffer::Buffer,
Expand Down Expand Up @@ -587,6 +589,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let protobuf::ScalarNestedValue {
ipc_message,
arrow_data,
dictionaries,
schema,
} = &v;

Expand All @@ -613,11 +616,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;

// Decode every dictionary batch that was serialized next to the nested value
// and index the decoded values array by its dictionary id, so that the record
// batch read below can resolve its dictionary-encoded columns.
//
// Any malformed dictionary message yields an `Error::General` (or a converted
// Arrow error) instead of panicking, since the bytes come from outside.
let dict_by_id: HashMap<i64, ArrayRef> = dictionaries
    .iter()
    .map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
        let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
            Error::General(format!(
                "Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
            ))
        })?;
        let buffer = Buffer::from(arrow_data);

        // The message header must be a dictionary batch; anything else is a
        // corrupt or mismatched payload.
        let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| {
            Error::General(
                "Unexpected message type deserializing ScalarValue::List dictionary message"
                    .to_string(),
            )
        })?;

        let id = dict_batch.id();

        // The first schema field referencing this dictionary id is used to
        // learn the dictionary's value type.
        let fields_using_this_dictionary = schema.fields_with_dict_id(id);
        let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
            Error::General(
                "dictionary id not found in schema while deserializing ScalarValue::List"
                    .to_string(),
            )
        })?;

        let values: ArrayRef = match first_field.data_type() {
            DataType::Dictionary(_, value_type) => {
                // Previously `.unwrap()`: a dictionary message without a
                // record-batch payload would panic the deserializer.
                let dict_data = dict_batch.data().ok_or_else(|| {
                    Error::General(
                        "Missing record batch data in ScalarValue::List dictionary message"
                            .to_string(),
                    )
                })?;

                // Make a fake single-column schema for the dictionary batch.
                let value = value_type.as_ref().clone();
                let schema = Schema::new(vec![Field::new("", value, true)]);
                // Read a single column
                let record_batch = read_record_batch(
                    &buffer,
                    dict_data,
                    Arc::new(schema),
                    &Default::default(),
                    None,
                    &message.version(),
                )?;
                Ok(record_batch.column(0).clone())
            }
            // NOTE(review): presumably `fields_with_dict_id` only returns
            // dictionary-typed fields, making this arm defensive — confirm.
            _ => Err(Error::General(
                "dictionary id not found in schema while deserializing ScalarValue::List"
                    .to_string(),
            )),
        }?;

        Ok((id, values))
    })
    .collect::<Result<HashMap<_, _>>>()?;

let record_batch = read_record_batch(
&buffer,
ipc_batch,
Arc::new(schema),
&Default::default(),
&dict_by_id,
None,
&message.version(),
)
Expand Down
9 changes: 8 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ fn encode_scalar_nested_value(

let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let (_, encoded_message) = gen
let (encoded_dictionaries, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC: {e}"))
Expand All @@ -1615,6 +1615,13 @@ fn encode_scalar_nested_value(
let scalar_list_value = protobuf::ScalarNestedValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
dictionaries: encoded_dictionaries
.into_iter()
.map(|data| protobuf::scalar_nested_value::Dictionary {
ipc_message: data.ipc_message,
arrow_data: data.arrow_data,
})
.collect(),
schema: Some(schema),
};

Expand Down
49 changes: 49 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,11 +1092,60 @@ fn round_trip_scalar_values() {
)
.build()
.unwrap(),
ScalarStructBuilder::new()
.with_scalar(
Field::new("a", DataType::Int32, true),
ScalarValue::from(23i32),
)
.with_scalar(
Field::new("b", DataType::Boolean, false),
ScalarValue::from(false),
)
.with_scalar(
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Utf8),
),
false,
),
ScalarValue::Dictionary(
Box::new(DataType::UInt16),
Box::new("value".into()),
),
)
.build()
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
])))
.unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
Field::new(
"c",
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
Field::new(
"d",
DataType::new_list(
DataType::Dictionary(
Box::new(DataType::UInt16),
Box::new(DataType::Binary),
),
false,
),
false,
),
])))
.unwrap(),
ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())),
ScalarValue::FixedSizeBinary(0, None),
ScalarValue::FixedSizeBinary(5, None),
Expand Down
Loading

0 comments on commit 98d2c6e

Please sign in to comment.