Skip to content

Commit 03745da

Browse files
authored
Support more functions and scalar types (#7)
1 parent 61e4cc3 commit 03745da

File tree

4 files changed

+104
-34
lines changed

4 files changed

+104
-34
lines changed

src/consumer.rs

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,37 @@ use substrait::protobuf::{
1616
Expression, Rel,
1717
};
1818

19+
pub fn reference_to_op(reference: u32) -> Result<Operator> {
20+
match reference {
21+
1 => Ok(Operator::Eq),
22+
2 => Ok(Operator::NotEq),
23+
3 => Ok(Operator::Lt),
24+
4 => Ok(Operator::LtEq),
25+
5 => Ok(Operator::Gt),
26+
6 => Ok(Operator::GtEq),
27+
7 => Ok(Operator::Plus),
28+
8 => Ok(Operator::Minus),
29+
9 => Ok(Operator::Multiply),
30+
10 => Ok(Operator::Divide),
31+
11 => Ok(Operator::Modulo),
32+
12 => Ok(Operator::And),
33+
13 => Ok(Operator::Or),
34+
14 => Ok(Operator::Like),
35+
15 => Ok(Operator::NotLike),
36+
16 => Ok(Operator::IsDistinctFrom),
37+
17 => Ok(Operator::IsNotDistinctFrom),
38+
18 => Ok(Operator::RegexMatch),
39+
19 => Ok(Operator::RegexIMatch),
40+
20 => Ok(Operator::RegexNotMatch),
41+
21 => Ok(Operator::RegexNotIMatch),
42+
22 => Ok(Operator::BitwiseAnd),
43+
_ => Err(DataFusionError::NotImplemented(format!(
44+
"Unsupported function_reference: {:?}",
45+
reference
46+
))),
47+
}
48+
}
49+
1950
/// Convert Substrait Rel to DataFusion DataFrame
2051
#[async_recursion]
2152
pub async fn from_substrait_rel(
@@ -95,19 +126,7 @@ pub async fn from_substrait_rex(e: &Expression, input: &dyn DataFrame) -> Result
95126
},
96127
Some(RexType::ScalarFunction(f)) => {
97128
assert!(f.args.len() == 2);
98-
let op = match f.function_reference {
99-
1 => Operator::Eq,
100-
2 => Operator::Lt,
101-
3 => Operator::LtEq,
102-
4 => Operator::Gt,
103-
5 => Operator::GtEq,
104-
_ => {
105-
return Err(DataFusionError::NotImplemented(format!(
106-
"Unsupported function_reference: {:?}",
107-
f.function_reference
108-
)))
109-
}
110-
};
129+
let op = reference_to_op(f.function_reference)?;
111130
Ok(Arc::new(Expr::BinaryExpr {
112131
left: Box::new(
113132
from_substrait_rex(&f.args[0], input)
@@ -124,19 +143,37 @@ pub async fn from_substrait_rex(e: &Expression, input: &dyn DataFrame) -> Result
124143
),
125144
}))
126145
}
127-
Some(RexType::Literal(lit)) => match lit.literal_type {
146+
Some(RexType::Literal(lit)) => match &lit.literal_type {
128147
Some(LiteralType::I8(n)) => {
129-
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(n as i8)))))
148+
Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
130149
}
131150
Some(LiteralType::I16(n)) => {
132-
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(n as i16)))))
151+
Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
133152
}
134153
Some(LiteralType::I32(n)) => {
135-
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(n as i32)))))
154+
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n as i32)))))
136155
}
137156
Some(LiteralType::I64(n)) => {
138-
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(n as i64)))))
157+
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n as i64)))))
158+
}
159+
Some(LiteralType::Boolean(b)) => {
160+
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
161+
}
162+
Some(LiteralType::Date(d)) => {
163+
Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
164+
}
165+
Some(LiteralType::Fp32(f)) => {
166+
Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
167+
}
168+
Some(LiteralType::Fp64(f)) => {
169+
Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
139170
}
171+
Some(LiteralType::String(s)) => Ok(Arc::new(Expr::Literal(ScalarValue::LargeUtf8(
172+
Some(s.clone()),
173+
)))),
174+
Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some(
175+
b.clone(),
176+
))))),
140177
_ => {
141178
return Err(DataFusionError::NotImplemented(format!(
142179
"Unsupported literal_type: {:?}",

src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ mod tests {
2323
roundtrip("SELECT * FROM data WHERE a > 1").await
2424
}
2525

26+
#[tokio::test]
27+
async fn select_with_filter_date() -> Result<()> {
28+
roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await
29+
}
30+
31+
#[tokio::test]
32+
async fn select_with_filter_bool_expr() -> Result<()> {
33+
roundtrip("SELECT * FROM data WHERE d AND a > 1").await
34+
}
35+
2636
async fn roundtrip(sql: &str) -> Result<()> {
2737
let mut ctx = ExecutionContext::new();
2838
ctx.register_csv("data", "testdata/data.csv", CsvReadOptions::new())

src/producer.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,33 @@ pub fn to_substrait_rel(plan: &LogicalPlan) -> Result<Box<Rel>> {
9494
}
9595
}
9696

97+
pub fn operator_to_reference(op: Operator) -> u32 {
98+
match op {
99+
Operator::Eq => 1,
100+
Operator::NotEq => 2,
101+
Operator::Lt => 3,
102+
Operator::LtEq => 4,
103+
Operator::Gt => 5,
104+
Operator::GtEq => 6,
105+
Operator::Plus => 7,
106+
Operator::Minus => 8,
107+
Operator::Multiply => 9,
108+
Operator::Divide => 10,
109+
Operator::Modulo => 11,
110+
Operator::And => 12,
111+
Operator::Or => 13,
112+
Operator::Like => 14,
113+
Operator::NotLike => 15,
114+
Operator::IsDistinctFrom => 16,
115+
Operator::IsNotDistinctFrom => 17,
116+
Operator::RegexMatch => 18,
117+
Operator::RegexIMatch => 19,
118+
Operator::RegexNotMatch => 20,
119+
Operator::RegexNotIMatch => 21,
120+
Operator::BitwiseAnd => 22,
121+
}
122+
}
123+
97124
/// Convert DataFusion Expr to Substrait Rex
98125
pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> {
99126
match expr {
@@ -114,19 +141,7 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression>
114141
Expr::BinaryExpr { left, op, right } => {
115142
let l = to_substrait_rex(left, schema)?;
116143
let r = to_substrait_rex(right, schema)?;
117-
let function_reference: u32 = match op {
118-
Operator::Eq => 1,
119-
Operator::Lt => 2,
120-
Operator::LtEq => 3,
121-
Operator::Gt => 4,
122-
Operator::GtEq => 5,
123-
_ => {
124-
return Err(DataFusionError::NotImplemented(format!(
125-
"Unsupported operator: {:?}",
126-
op
127-
)))
128-
}
129-
};
144+
let function_reference: u32 = operator_to_reference(*op);
130145
Ok(Expression {
131146
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
132147
function_reference,
@@ -141,6 +156,14 @@ pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef) -> Result<Expression>
141156
ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
142157
ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
143158
ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
159+
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
160+
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
161+
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
162+
ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
163+
ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
164+
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
165+
ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())),
166+
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
144167
_ => {
145168
return Err(DataFusionError::NotImplemented(format!(
146169
"Unsupported literal: {:?}",

testdata/data.csv

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
a,b
2-
1,2
3-
3,4
1+
a,b,c,d
2+
1,2,2020-01-01,false
3+
3,4,2020-01-01,true

0 commit comments

Comments
 (0)