Skip to content

Commit ca5dec0

Browse files
committed
Add support for Arrow Dictionary type in Substrait
This commit adds support for the Arrow Dictionary type in Substrait plans. Resolves #16273
1 parent aadb79b commit ca5dec0

File tree

3 files changed

+47
-19
lines changed

3 files changed

+47
-19
lines changed

datafusion/substrait/src/logical_plan/consumer/types.rs

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::variation_const::{
2222
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
2323
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
2424
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF,
25-
DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF,
25+
DEFAULT_TYPE_VARIATION_REF,DICTIONARY_CONTAINER_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF,
2626
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME,
2727
INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
2828
LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF,
@@ -177,24 +177,32 @@ pub fn from_substrait_type(
177177
let value_type = map.value.as_ref().ok_or_else(|| {
178178
substrait_datafusion_err!("Map type must have value type")
179179
})?;
180-
let key_field = Arc::new(Field::new(
181-
"key",
182-
from_substrait_type(consumer, key_type, dfs_names, name_idx)?,
183-
false,
184-
));
185-
let value_field = Arc::new(Field::new(
186-
"value",
187-
from_substrait_type(consumer, value_type, dfs_names, name_idx)?,
188-
true,
189-
));
190-
Ok(DataType::Map(
191-
Arc::new(Field::new_struct(
192-
"entries",
193-
[key_field, value_field],
194-
false, // The inner map field is always non-nullable (Arrow #1697),
180+
let key_type =
181+
from_substrait_type(consumer, key_type, dfs_names, name_idx)?;
182+
let value_type =
183+
from_substrait_type(consumer, value_type, dfs_names, name_idx)?;
184+
185+
match map.type_variation_reference {
186+
DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
187+
let key_field = Arc::new(Field::new("key", key_type, false));
188+
let value_field = Arc::new(Field::new("value", value_type, true));
189+
Ok(DataType::Map(
190+
Arc::new(Field::new_struct(
191+
"entries",
192+
[key_field, value_field],
193+
false, // The inner map field is always non-nullable (Arrow #1697),
194+
)),
195+
false, // whether keys are sorted
196+
))
197+
}
198+
DICTIONARY_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Dictionary(
199+
Box::new(key_type),
200+
Box::new(value_type),
195201
)),
196-
false, // whether keys are sorted
197-
))
202+
v => not_impl_err!(
203+
"Unsupported Substrait type variation {v} of type {s_kind:?}"
204+
),
205+
}
198206
}
199207
r#type::Kind::Decimal(d) => match d.type_variation_reference {
200208
DECIMAL_128_TYPE_VARIATION_REF => {

datafusion/substrait/src/logical_plan/producer/types.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::variation_const::{
2121
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
2222
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
2323
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF,
24-
DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF,
24+
DEFAULT_TYPE_VARIATION_REF,DICTIONARY_CONTAINER_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF,
2525
LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF,
2626
TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
2727
VIEW_CONTAINER_TYPE_VARIATION_REF,
@@ -283,6 +283,18 @@ pub(crate) fn to_substrait_type(
283283
}
284284
_ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
285285
},
286+
DataType::Dictionary(key_type, value_type) => {
287+
let key_type = to_substrait_type(key_type, nullable)?;
288+
let value_type = to_substrait_type(value_type, nullable)?;
289+
Ok(substrait::proto::Type {
290+
kind: Some(r#type::Kind::Map(Box::new(r#type::Map {
291+
key: Some(Box::new(key_type)),
292+
value: Some(Box::new(value_type)),
293+
type_variation_reference: DICTIONARY_CONTAINER_TYPE_VARIATION_REF,
294+
nullability,
295+
}))),
296+
})
297+
}
286298
DataType::Struct(fields) => {
287299
let field_types = fields
288300
.iter()
@@ -407,6 +419,10 @@ mod tests {
407419
.into(),
408420
false,
409421
))?;
422+
round_trip_type(DataType::Dictionary(
423+
Box::new(DataType::Utf8),
424+
Box::new(DataType::Int32),
425+
))?;
410426

411427
round_trip_type(DataType::Struct(
412428
vec![

datafusion/substrait/src/variation_const.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ pub const TIME_64_TYPE_VARIATION_REF: u32 = 1;
5555
pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0;
5656
pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1;
5757
pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2;
58+
/// Used for the arrow type [`DataType::Dictionary`].
59+
///
60+
/// [`DataType::Dictionary`]: datafusion::arrow::datatypes::DataType::Dictionary
61+
pub const DICTIONARY_CONTAINER_TYPE_VARIATION_REF: u32 = 3;
5862
pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0;
5963
pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1;
6064
/// Used for the arrow type [`DataType::Interval`] with [`IntervalUnit::DayTime`].

0 commit comments

Comments
 (0)