Skip to content

Commit

Permalink
Add decimal support to substrait serde (#5054)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Jan 26, 2023
1 parent 14e153e commit 552eea7
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 42 deletions.
100 changes: 63 additions & 37 deletions datafusion/substrait/src/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,44 +606,70 @@ pub async fn from_substrait_rex(
))),
}
}
Some(RexType::Literal(lit)) => match &lit.literal_type {
Some(LiteralType::I8(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
}
Some(LiteralType::I16(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
}
Some(LiteralType::I32(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
}
Some(LiteralType::I64(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
}
Some(LiteralType::Boolean(b)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
}
Some(LiteralType::Date(d)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
}
Some(LiteralType::Fp32(f)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
}
Some(LiteralType::Fp64(f)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
}
Some(LiteralType::String(s)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
}
Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
ScalarValue::Binary(Some(b.clone())),
))),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal_type: {:?}",
lit.literal_type
)))
Some(RexType::Literal(lit)) => {
match &lit.literal_type {
Some(LiteralType::I8(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
}
Some(LiteralType::I16(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
}
Some(LiteralType::I32(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
}
Some(LiteralType::I64(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
}
Some(LiteralType::Boolean(b)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
}
Some(LiteralType::Date(d)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
}
Some(LiteralType::Fp32(f)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
}
Some(LiteralType::Fp64(f)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
}
// A decimal literal arrives as raw bytes plus a precision and a scale; it is
// rebuilt here as a `ScalarValue::Decimal128`.
Some(LiteralType::Decimal(d)) => {
    // Exactly 16 bytes are needed to fill an `i128`. Converting from the
    // borrowed slice avoids cloning the whole `Vec`, and `map_err` only
    // builds the error message when the conversion actually fails.
    let value: [u8; 16] = d.value.as_slice().try_into().map_err(|_| {
        DataFusionError::Substrait("Failed to parse decimal value".to_string())
    })?;
    // Precision and scale are narrowed with checked conversions so that
    // out-of-range values surface as errors instead of being truncated.
    let p = d.precision.try_into().map_err(|e| {
        DataFusionError::Substrait(format!(
            "Failed to parse decimal precision: {}",
            e
        ))
    })?;
    let s = d.scale.try_into().map_err(|e| {
        DataFusionError::Substrait(format!(
            "Failed to parse decimal scale: {}",
            e
        ))
    })?;
    // The byte buffer is interpreted as a little-endian integer.
    Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128(
        Some(i128::from_le_bytes(value)),
        p,
        s,
    ))))
}
Some(LiteralType::String(s)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
}
Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
ScalarValue::Binary(Some(b.clone())),
))),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Unsupported literal_type: {:?}",
lit.literal_type
)))
}
}
},
}
_ => Err(DataFusionError::NotImplemented(
"unsupported rex_type".to_string(),
)),
Expand Down
9 changes: 8 additions & 1 deletion datafusion/substrait/src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use substrait::proto::{
expression::{
field_reference::ReferenceType,
if_then::IfClause,
literal::LiteralType,
literal::{Decimal, LiteralType},
mask_expression::{StructItem, StructSelect},
reference_segment, FieldReference, IfThen, Literal, MaskExpression,
ReferenceSegment, RexType, ScalarFunction,
Expand Down Expand Up @@ -579,6 +579,13 @@ pub fn to_substrait_rex(
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
// Non-null decimals only: binding the inner value in the pattern replaces the
// `is_some()` guard plus `unwrap()`, so this arm has no panic path. A `None`
// value still falls through to the arms that follow, exactly as before.
ScalarValue::Decimal128(Some(v), p, s) => {
    Some(LiteralType::Decimal(Decimal {
        // The unscaled value is written as 16 little-endian bytes.
        value: v.to_le_bytes().to_vec(),
        precision: *p as i32,
        scale: *s as i32,
    }))
}
ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
Expand Down
18 changes: 16 additions & 2 deletions datafusion/substrait/tests/roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use datafusion_substrait::producer;
mod tests {

use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::prelude::*;
use substrait::proto::extensions::simple_extension_declaration::MappingType;
Expand Down Expand Up @@ -95,6 +96,11 @@ mod tests {
roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
}

// Passes a filter query comparing column `b` against the literal `2.5` to
// `roundtrip` and returns whatever result it produces.
#[tokio::test]
async fn decimal_literal() -> Result<()> {
    let sql = "SELECT * FROM data WHERE b > 2.5";
    roundtrip(sql).await
}

#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
Expand Down Expand Up @@ -290,9 +296,17 @@ mod tests {

async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
let mut explicit_options = CsvReadOptions::new();
let schema = Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
]);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options.clone())
.await?;
ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
ctx.register_csv("data2", "tests/testdata/data.csv", explicit_options)
.await?;
Ok(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/substrait/tests/testdata/data.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
a,b,c,d
1,2,2020-01-01,false
3,4,2020-01-01,true
1,2.0,2020-01-01,false
3,4.5,2020-01-01,true

0 comments on commit 552eea7

Please sign in to comment.