diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index 7bc30e433d86..87bebe8ddd82 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -21,7 +21,8 @@ use super::SubstraitConsumer; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_MAP_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DICTIONARY_MAP_TYPE_VARIATION_REF, INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, @@ -180,24 +181,32 @@ pub fn from_substrait_type( let value_type = map.value.as_ref().ok_or_else(|| { substrait_datafusion_err!("Map type must have value type") })?; - let key_field = Arc::new(Field::new( - "key", - from_substrait_type(consumer, key_type, dfs_names, name_idx)?, - false, - )); - let value_field = Arc::new(Field::new( - "value", - from_substrait_type(consumer, value_type, dfs_names, name_idx)?, - true, - )); - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), + let key_type = + from_substrait_type(consumer, key_type, dfs_names, name_idx)?; + let value_type = + from_substrait_type(consumer, value_type, dfs_names, name_idx)?; + + match map.type_variation_reference { + DEFAULT_MAP_TYPE_VARIATION_REF => { + let key_field = Arc::new(Field::new("key", key_type, false)); + let value_field = Arc::new(Field::new("value", value_type, true)); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } + DICTIONARY_MAP_TYPE_VARIATION_REF => Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), )), - false, // whether keys are sorted - )) + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 5762cc76b0c8..0c466dd2233a 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -21,7 +21,8 @@ use crate::variation_const::TIMESTAMP_NANO_TYPE_VARIATION_REF; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_MAP_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DICTIONARY_MAP_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; @@ -235,13 +236,25 @@ pub(crate) fn to_substrait_type( kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), value: Some(Box::new(value_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + type_variation_reference: DEFAULT_MAP_TYPE_VARIATION_REF, nullability, }))), }) } _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, + DataType::Dictionary(key_type, value_type) => { + let key_type = to_substrait_type(key_type, nullable)?; + let value_type = to_substrait_type(value_type, nullable)?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DICTIONARY_MAP_TYPE_VARIATION_REF, + nullability, + }))), + }) + } DataType::Struct(fields) => { let field_types = fields .iter() @@ -271,8 +284,6 @@ pub(crate) fn to_substrait_type( precision: *p as i32, })), }), - // TODO: DataDog-specific workaround, don't commit upstream - DataType::Dictionary(_, dt) => to_substrait_type(dt, nullable), _ => not_impl_err!("Unsupported cast type: {dt:?}"), } } @@ -365,6 +376,10 @@ mod tests { .into(), false, ))?; + round_trip_type(DataType::Dictionary( + Box::new(DataType::Utf8), + Box::new(DataType::Int32), + ))?; round_trip_type(DataType::Struct( vec![ diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index e5bebf8e1181..49ea918980f7 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -53,6 +53,8 @@ pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; +pub const DEFAULT_MAP_TYPE_VARIATION_REF: u32 = 0; +pub const DICTIONARY_MAP_TYPE_VARIATION_REF: u32 = 1; pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1;