Skip to content

Commit

Permalink
feat: express unsigned literal in substrait (#5448)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
  • Loading branch information
waynexia authored Mar 4, 2023
1 parent 53a638e commit c37ddf7
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 11 deletions.
50 changes: 46 additions & 4 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,16 +615,58 @@ pub async fn from_substrait_rex(
Some(RexType::Literal(lit)) => {
match &lit.literal_type {
Some(LiteralType::I8(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
if lit.type_variation_reference == 0 {
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
} else if lit.type_variation_reference == 1 {
Ok(Arc::new(Expr::Literal(ScalarValue::UInt8(Some(*n as u8)))))
} else {
Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {}",
lit.type_variation_reference
)))
}
}
Some(LiteralType::I16(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
if lit.type_variation_reference == 0 {
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
} else if lit.type_variation_reference == 1 {
Ok(Arc::new(Expr::Literal(ScalarValue::UInt16(Some(
*n as u16,
)))))
} else {
Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {}",
lit.type_variation_reference
)))
}
}
Some(LiteralType::I32(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
if lit.type_variation_reference == 0 {
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
} else if lit.type_variation_reference == 1 {
Ok(Arc::new(Expr::Literal(ScalarValue::UInt32(Some(unsafe {
std::mem::transmute_copy::<i32, u32>(n)
})))))
} else {
Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {}",
lit.type_variation_reference
)))
}
}
Some(LiteralType::I64(n)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
if lit.type_variation_reference == 0 {
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
} else if lit.type_variation_reference == 1 {
Ok(Arc::new(Expr::Literal(ScalarValue::UInt64(Some(unsafe {
std::mem::transmute_copy::<i64, u64>(n)
})))))
} else {
Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {}",
lit.type_variation_reference
)))
}
}
Some(LiteralType::Boolean(b)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
Expand Down
16 changes: 13 additions & 3 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, mem, sync::Arc};

use datafusion::{
error::{DataFusionError, Result},
Expand Down Expand Up @@ -580,10 +580,17 @@ pub fn to_substrait_rex(
Expr::Literal(value) => {
let literal_type = match value {
ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
ScalarValue::UInt8(Some(n)) => Some(LiteralType::I8(*n as i32)),
ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
ScalarValue::UInt16(Some(n)) => Some(LiteralType::I16(*n as i32)),
ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
ScalarValue::UInt32(Some(n)) => Some(LiteralType::I32(unsafe {
mem::transmute_copy::<u32, i32>(n)
})),
ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as i32)), // Substrait currently does not support unsigned integer
ScalarValue::UInt64(Some(n)) => Some(LiteralType::I64(unsafe {
mem::transmute_copy::<u64, i64>(n)
})),
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
Expand All @@ -601,10 +608,13 @@ pub fn to_substrait_rex(
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
_ => Some(try_to_substrait_null(value)?),
};

let type_variation_reference = if value.is_unsigned() { 1 } else { 0 };

Ok(Expression {
rex_type: Some(RexType::Literal(Literal {
nullable: true,
type_variation_reference: 0,
type_variation_reference,
literal_type,
})),
})
Expand Down
8 changes: 7 additions & 1 deletion datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ mod tests {
roundtrip("SELECT * FROM data WHERE b = NULL").await
}

#[tokio::test]
async fn u32_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE e > 4294967295").await
}

#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
Expand Down Expand Up @@ -226,7 +231,7 @@ mod tests {
async fn simple_intersect() -> Result<()> {
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int16(1))]]\
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n LeftSemi Join: data.a = data2.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
Expand Down Expand Up @@ -335,6 +340,7 @@ mod tests {
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
Field::new("e", DataType::UInt32, true),
]);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
Expand Down
6 changes: 3 additions & 3 deletions datafusion/substrait/tests/testdata/data.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
a,b,c,d
1,2.0,2020-01-01,false
3,4.5,2020-01-01,true
a,b,c,d,e
1,2.0,2020-01-01,false,4294967296
3,4.5,2020-01-01,true,2147483648

0 comments on commit c37ddf7

Please sign in to comment.