Skip to content

Commit 0d6f49f

Browse files
committed
feat: Support ArrayIndex (GetIndexedExpr) on dynamic key expressions
1 parent 3dcf38b commit 0d6f49f

File tree

10 files changed

+163
-86
lines changed

10 files changed

+163
-86
lines changed

datafusion/core/src/logical_plan/expr_rewriter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ impl ExprRewritable for Expr {
228228
}
229229
Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
230230
expr: rewrite_boxed(expr, rewriter)?,
231-
key,
231+
key: rewrite_boxed(key, rewriter)?,
232232
},
233233
};
234234

datafusion/core/src/logical_plan/expr_visitor.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,11 @@ impl ExprVisitable for Expr {
101101
| Expr::Negative(expr)
102102
| Expr::Cast { expr, .. }
103103
| Expr::TryCast { expr, .. }
104-
| Expr::Sort { expr, .. }
105-
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
104+
| Expr::Sort { expr, .. } => expr.accept(visitor),
105+
Expr::GetIndexedField { expr, key } => {
106+
let visitor = expr.accept(visitor)?;
107+
key.accept(visitor)
108+
}
106109
Expr::Column(_)
107110
| Expr::OuterColumn(_, _)
108111
| Expr::ScalarVariable(_, _)

datafusion/core/src/optimizer/utils.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,10 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
312312
| Expr::Alias(expr, ..)
313313
| Expr::Not(expr)
314314
| Expr::Negative(expr)
315-
| Expr::Sort { expr, .. }
316-
| Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
315+
| Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
316+
Expr::GetIndexedField { expr, key } => {
317+
Ok(vec![expr.as_ref().to_owned(), key.as_ref().to_owned()])
318+
}
317319
Expr::ScalarFunction { args, .. }
318320
| Expr::ScalarUDF { args, .. }
319321
| Expr::TableUDF { args, .. }
@@ -547,9 +549,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
547549
"QualifiedWildcard expressions are not valid in a logical query plan"
548550
.to_owned(),
549551
)),
550-
Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
552+
Expr::GetIndexedField { .. } => Ok(Expr::GetIndexedField {
551553
expr: Box::new(expressions[0].clone()),
552-
key: key.clone(),
554+
key: Box::new(expressions[1].clone()),
553555
}),
554556
}
555557
}

datafusion/core/src/physical_plan/planner.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
156156
}
157157
Expr::GetIndexedField { expr, key } => {
158158
let expr = create_physical_name(expr, false)?;
159+
let key = create_physical_name(key, false)?;
159160
Ok(format!("{}[{}]", expr, key))
160161
}
161162
Expr::ScalarFunction { fun, args, .. } => {
@@ -1093,7 +1094,7 @@ pub fn create_physical_expr(
10931094
)?),
10941095
Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new(
10951096
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
1096-
key.clone(),
1097+
create_physical_expr(key, input_dfschema, input_schema, execution_props)?,
10971098
))),
10981099

10991100
Expr::ScalarFunction { fun, args } => {

datafusion/core/src/sql/planner.rs

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,7 @@ impl SqlToRelContext {
133133
}
134134
}
135135

136-
fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
137-
let scalar = match key {
138-
SQLExpr::Value(Value::Number(s, _)) => {
139-
ScalarValue::Int64(Some(s.parse().unwrap()))
140-
}
141-
SQLExpr::Value(Value::SingleQuotedString(s)) => ScalarValue::Utf8(Some(s)),
142-
SQLExpr::Identifier(ident) => ScalarValue::Utf8(Some(ident.value)),
143-
_ => {
144-
return Err(DataFusionError::SQL(ParserError(format!(
145-
"Unsuported index key expression: {:?}",
146-
key
147-
))))
148-
}
149-
};
150-
151-
Ok(scalar)
152-
}
153-
154-
fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
136+
fn plan_indexed(expr: Expr, mut keys: Vec<Expr>) -> Result<Expr> {
155137
let key = keys.pop().ok_or_else(|| {
156138
DataFusionError::SQL(ParserError(
157139
"Internal error: Missing index key expression".to_string(),
@@ -166,7 +148,7 @@ fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
166148

167149
Ok(Expr::GetIndexedField {
168150
expr: Box::new(expr),
169-
key: plan_key(key)?,
151+
key: Box::new(key),
170152
})
171153
}
172154

@@ -1704,26 +1686,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
17041686
}
17051687
}
17061688

1707-
SQLExpr::MapAccess { ref column, keys } => {
1708-
if let SQLExpr::Identifier(ref id) = column.as_ref() {
1709-
plan_indexed(col(&id.value), keys)
1710-
} else {
1711-
Err(DataFusionError::NotImplemented(format!(
1712-
"map access requires an identifier, found column {} instead",
1713-
column
1714-
)))
1715-
}
1716-
}
1717-
17181689
SQLExpr::ArrayIndex { obj, indexs } => {
1719-
if let SQLExpr::Identifier(ref id) = obj.as_ref() {
1720-
plan_indexed(col(&id.value), indexs)
1721-
} else {
1722-
Err(DataFusionError::NotImplemented(format!(
1723-
"array index access requires an identifier, found column {} instead",
1724-
obj
1725-
)))
1726-
}
1690+
let expr = self.sql_expr_to_logical_expr(*obj, schema)?;
1691+
1692+
plan_indexed(expr, indexs.into_iter()
1693+
.map(|e| self.sql_expr_to_logical_expr(e, schema))
1694+
.collect::<Result<Vec<_>>>()?)
17271695
}
17281696

17291697
SQLExpr::CompoundIdentifier(ids) => {
@@ -1754,7 +1722,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
17541722
// Access to a field of a column which is a structure, example: SELECT my_struct.key
17551723
Ok(Expr::GetIndexedField {
17561724
expr: Box::new(Expr::Column(field.qualified_column())),
1757-
key: ScalarValue::Utf8(Some(name)),
1725+
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(name)))),
17581726
})
17591727
} else {
17601728
// table.column identifier
@@ -2104,7 +2072,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
21042072
SQLExpr::DotExpr { expr, field } => {
21052073
Ok(Expr::GetIndexedField {
21062074
expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
2107-
key: ScalarValue::Utf8(Some(field.value)),
2075+
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(field.value)))),
21082076
})
21092077
}
21102078

datafusion/core/src/sql/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ where
391391
Expr::QualifiedWildcard { .. } => Ok(expr.clone()),
392392
Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField {
393393
expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?),
394-
key: key.clone(),
394+
key: Box::new(clone_with_replacement(key.as_ref(), replacement_fn)?),
395395
}),
396396
},
397397
}

datafusion/core/tests/sql/select.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,12 +596,24 @@ async fn query_nested_get_indexed_field() -> Result<()> {
596596
"+----------+",
597597
];
598598
assert_batches_eq!(expected, &actual);
599+
600+
// nested with scalar values
599601
let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3";
600602
let actual = execute_to_batches(&ctx, sql).await;
601603
let expected = vec![
602604
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
603605
];
604606
assert_batches_eq!(expected, &actual);
607+
608+
// nested with dynamic expr in key
609+
assert_batches_eq!(expected, &actual);
610+
let sql = "SELECT some_list[1 - 1][1 - 1] as i0 FROM ints LIMIT 3";
611+
let actual = execute_to_batches(&ctx, sql).await;
612+
let expected = vec![
613+
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
614+
];
615+
assert_batches_eq!(expected, &actual);
616+
605617
Ok(())
606618
}
607619

@@ -634,7 +646,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
634646
ctx.register_table("structs", table_a)?;
635647

636648
// Access a struct field using a single-quoted string literal as the key
637-
let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
649+
let sql = "SELECT some_struct['bar'] as l0 FROM structs LIMIT 3";
638650
let actual = execute_to_batches(&ctx, sql).await;
639651
let expected = vec![
640652
"+----------------+",
@@ -661,7 +673,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
661673
];
662674
assert_batches_eq!(expected, &actual);
663675

664-
let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
676+
let sql = "SELECT some_struct['bar'][0] as i0 FROM structs LIMIT 3";
665677
let actual = execute_to_batches(&ctx, sql).await;
666678
let expected = vec![
667679
"+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+",

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ pub enum Expr {
114114
/// the expression to take the field from
115115
expr: Box<Expr>,
116116
/// The key expression selecting the field to take (e.g. a struct field name or a list index)
117-
key: ScalarValue,
117+
key: Box<Expr>,
118118
},
119119
/// Whether an expression is between a given range.
120120
Between {

0 commit comments

Comments
 (0)