diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs index 80b400d523a5..7293ad9676e9 100644 --- a/datafusion/substrait/src/consumer.rs +++ b/datafusion/substrait/src/consumer.rs @@ -606,44 +606,70 @@ pub async fn from_substrait_rex( ))), } } - Some(RexType::Literal(lit)) => match &lit.literal_type { - Some(LiteralType::I8(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8))))) - } - Some(LiteralType::I16(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16))))) - } - Some(LiteralType::I32(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n))))) - } - Some(LiteralType::I64(n)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n))))) - } - Some(LiteralType::Boolean(b)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b))))) - } - Some(LiteralType::Date(d)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d))))) - } - Some(LiteralType::Fp32(f)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f))))) - } - Some(LiteralType::Fp64(f)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) - } - Some(LiteralType::String(s)) => { - Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone()))))) - } - Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal( - ScalarValue::Binary(Some(b.clone())), - ))), - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported literal_type: {:?}", - lit.literal_type - ))) + Some(RexType::Literal(lit)) => { + match &lit.literal_type { + Some(LiteralType::I8(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8))))) + } + Some(LiteralType::I16(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16))))) + } + Some(LiteralType::I32(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n))))) + } + Some(LiteralType::I64(n)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n))))) + } + Some(LiteralType::Boolean(b)) => { + 
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b))))) + } + Some(LiteralType::Date(d)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d))))) + } + Some(LiteralType::Fp32(f)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f))))) + } + Some(LiteralType::Fp64(f)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f))))) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d.value.as_slice().try_into().map_err(|_| { + DataFusionError::Substrait( + "Failed to parse decimal value".to_string(), + ) + })?; + let p = d.precision.try_into().map_err(|e| { + DataFusionError::Substrait(format!( + "Failed to parse decimal precision: {}", + e + )) + })?; + let s = d.scale.try_into().map_err(|e| { + DataFusionError::Substrait(format!( + "Failed to parse decimal scale: {}", + e + )) + })?; + Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128( + Some(std::primitive::i128::from_le_bytes(value)), + p, + s, + )))) + } + Some(LiteralType::String(s)) => { + Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone()))))) + } + Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal( + ScalarValue::Binary(Some(b.clone())), + ))), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Unsupported literal_type: {:?}", + lit.literal_type + ))) + } } - }, + } _ => Err(DataFusionError::NotImplemented( "unsupported rex_type".to_string(), )), diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs index a1748c3ff53e..163abbaa9aa7 100644 --- a/datafusion/substrait/src/producer.rs +++ b/datafusion/substrait/src/producer.rs @@ -35,7 +35,7 @@ use substrait::proto::{ expression::{ field_reference::ReferenceType, if_then::IfClause, - literal::LiteralType, + literal::{Decimal, LiteralType}, mask_expression::{StructItem, StructSelect}, reference_segment, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, ScalarFunction, @@ -579,6 +579,13 @@ pub fn to_substrait_rex( 
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)), ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)), ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)), + ScalarValue::Decimal128(Some(v), p, s) => { + Some(LiteralType::Decimal(Decimal { + value: v.to_le_bytes().to_vec(), + precision: i32::from(*p), + scale: i32::from(*s), + })) + } ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())), ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())), ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())), diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs index a819b2ba5799..141f4eb6b17e 100644 --- a/datafusion/substrait/tests/roundtrip.rs +++ b/datafusion/substrait/tests/roundtrip.rs @@ -22,6 +22,7 @@ use datafusion_substrait::producer; mod tests { use crate::{consumer::from_substrait_plan, producer::to_substrait_plan}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::Result; use datafusion::prelude::*; use substrait::proto::extensions::simple_extension_declaration::MappingType; @@ -95,6 +96,11 @@ mod tests { roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await } + #[tokio::test] + async fn decimal_literal() -> Result<()> { + roundtrip("SELECT * FROM data WHERE b > 2.5").await + } + #[tokio::test] async fn simple_distinct() -> Result<()> { test_alias( @@ -290,9 +296,17 @@ mod tests { async fn create_context() -> Result<SessionContext> { let ctx = SessionContext::new(); - ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new()) + let mut explicit_options = CsvReadOptions::new(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(5, 2), true), + Field::new("c", DataType::Date32, true), + Field::new("d", DataType::Boolean, true), + ]); + explicit_options.schema = Some(&schema); + ctx.register_csv("data", "tests/testdata/data.csv", 
explicit_options.clone()) .await?; - ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) + ctx.register_csv("data2", "tests/testdata/data.csv", explicit_options) .await?; Ok(ctx) } diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 4394789bcda6..b0fc71024fd6 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ a,b,c,d -1,2,2020-01-01,false -3,4,2020-01-01,true \ No newline at end of file +1,2.0,2020-01-01,false +3,4.5,2020-01-01,true \ No newline at end of file