Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support GetIndexedExpr on dynamic expressions #26

Merged
merged 1 commit into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ExprRewritable for Expr {
}
Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
expr: rewrite_boxed(expr, rewriter)?,
key,
key: rewrite_boxed(key, rewriter)?,
},
};

Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ impl ExprVisitable for Expr {
| Expr::Negative(expr)
| Expr::Cast { expr, .. }
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
| Expr::Sort { expr, .. } => expr.accept(visitor),
Expr::GetIndexedField { expr, key } => {
let visitor = expr.accept(visitor)?;
key.accept(visitor)
}
Expr::Column(_)
| Expr::OuterColumn(_, _)
| Expr::ScalarVariable(_, _)
Expand Down
10 changes: 6 additions & 4 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::Alias(expr, ..)
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
| Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
Expr::GetIndexedField { expr, key } => {
Ok(vec![expr.as_ref().to_owned(), key.as_ref().to_owned()])
}
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::TableUDF { args, .. }
Expand Down Expand Up @@ -547,9 +549,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
"QualifiedWildcard expressions are not valid in a logical query plan"
.to_owned(),
)),
Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField {
Expr::GetIndexedField { .. } => Ok(Expr::GetIndexedField {
expr: Box::new(expressions[0].clone()),
key: key.clone(),
key: Box::new(expressions[1].clone()),
}),
}
}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Expr::GetIndexedField { expr, key } => {
let expr = create_physical_name(expr, false)?;
let key = create_physical_name(key, false)?;
Ok(format!("{}[{}]", expr, key))
}
Expr::ScalarFunction { fun, args, .. } => {
Expand Down Expand Up @@ -1093,7 +1094,7 @@ pub fn create_physical_expr(
)?),
Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new(
create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,
key.clone(),
create_physical_expr(key, input_dfschema, input_schema, execution_props)?,
))),

Expr::ScalarFunction { fun, args } => {
Expand Down
50 changes: 9 additions & 41 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,7 @@ impl SqlToRelContext {
}
}

fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
let scalar = match key {
SQLExpr::Value(Value::Number(s, _)) => {
ScalarValue::Int64(Some(s.parse().unwrap()))
}
SQLExpr::Value(Value::SingleQuotedString(s)) => ScalarValue::Utf8(Some(s)),
SQLExpr::Identifier(ident) => ScalarValue::Utf8(Some(ident.value)),
_ => {
return Err(DataFusionError::SQL(ParserError(format!(
"Unsuported index key expression: {:?}",
key
))))
}
};

Ok(scalar)
}

fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
fn plan_indexed(expr: Expr, mut keys: Vec<Expr>) -> Result<Expr> {
let key = keys.pop().ok_or_else(|| {
DataFusionError::SQL(ParserError(
"Internal error: Missing index key expression".to_string(),
Expand All @@ -166,7 +148,7 @@ fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {

Ok(Expr::GetIndexedField {
expr: Box::new(expr),
key: plan_key(key)?,
key: Box::new(key),
})
}

Expand Down Expand Up @@ -1704,26 +1686,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

SQLExpr::MapAccess { ref column, keys } => {
if let SQLExpr::Identifier(ref id) = column.as_ref() {
plan_indexed(col(&id.value), keys)
} else {
Err(DataFusionError::NotImplemented(format!(
"map access requires an identifier, found column {} instead",
column
)))
}
}

SQLExpr::ArrayIndex { obj, indexs } => {
if let SQLExpr::Identifier(ref id) = obj.as_ref() {
plan_indexed(col(&id.value), indexs)
} else {
Err(DataFusionError::NotImplemented(format!(
"array index access requires an identifier, found column {} instead",
obj
)))
}
let expr = self.sql_expr_to_logical_expr(*obj, schema)?;

plan_indexed(expr, indexs.into_iter()
.map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?)
}

SQLExpr::CompoundIdentifier(ids) => {
Expand Down Expand Up @@ -1754,7 +1722,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Access to a field of a column which is a structure, example: SELECT my_struct.key
Ok(Expr::GetIndexedField {
expr: Box::new(Expr::Column(field.qualified_column())),
key: ScalarValue::Utf8(Some(name)),
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(name)))),
})
} else {
// table.column identifier
Expand Down Expand Up @@ -2104,7 +2072,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::DotExpr { expr, field } => {
Ok(Expr::GetIndexedField {
expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
key: ScalarValue::Utf8(Some(field.value)),
key: Box::new(Expr::Literal(ScalarValue::Utf8(Some(field.value)))),
})
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ where
Expr::QualifiedWildcard { .. } => Ok(expr.clone()),
Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField {
expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?),
key: key.clone(),
key: Box::new(clone_with_replacement(key.as_ref(), replacement_fn)?),
}),
},
}
Expand Down
84 changes: 82 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,92 @@ async fn query_nested_get_indexed_field() -> Result<()> {
"+----------+",
];
assert_batches_eq!(expected, &actual);

// nested with scalar values
let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
];
assert_batches_eq!(expected, &actual);

// nested with dynamic expr in key
assert_batches_eq!(expected, &actual);
let sql = "SELECT some_list[1 - 1][1 - 1] as i0 FROM ints LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}

#[tokio::test]
async fn query_get_indexed_array_dynamic_key() -> Result<()> {
let ctx = SessionContext::new();

let list_dt = Box::new(Field::new("item", DataType::Int64, true));
let schema = Arc::new(Schema::new(vec![
Field::new("arr", DataType::List(list_dt), false),
Field::new("key", DataType::Int64, false),
]));

let array_ints_builder = PrimitiveBuilder::<Int64Type>::new(3);
let mut arr_builder = ListBuilder::new(array_ints_builder);
let mut key_builder = PrimitiveBuilder::<Int64Type>::new(3);

for (int_vec, key) in vec![
(vec![0, 1, 2, 3], 1),
(vec![4, 5, 6, 7], 2),
(vec![8, 9, 10, 11], 3),
] {
for n in int_vec {
arr_builder.values().append_value(n)?;
}

key_builder.append_value(key)?;
arr_builder.append(true)?;
}

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arr_builder.finish()),
Arc::new(key_builder.finish()),
],
)?;
let table = MemTable::try_new(schema, vec![vec![data]])?;
let table_a = Arc::new(table);

ctx.register_table("array_and_keys", table_a)?;

let sql = "SELECT arr[key], key FROM array_and_keys";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------------------------------+-----+",
"| array_and_keys.arr[array_and_keys.key] | key |",
"+----------------------------------------+-----+",
"| 0 | 1 |",
"| 4 | 2 |",
"| 8 | 3 |",
"+----------------------------------------+-----+",
];
assert_batches_eq!(expected, &actual);

// All dynamic
let sql = "SELECT r.value[r.key] FROM (SELECT array[1,2,3] as value, 1 as key UNION ALL SELECT array[4,5,6] as value, 2 as key) as r";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------+",
"| r.value[r.key] |",
"+----------------+",
"| 1 |",
"| 5 |",
"+----------------+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}

Expand Down Expand Up @@ -634,7 +714,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
ctx.register_table("structs", table_a)?;

// Original column is micros, convert to millis and check timestamp
let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
let sql = "SELECT some_struct['bar'] as l0 FROM structs LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----------------+",
Expand All @@ -661,7 +741,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
let sql = "SELECT some_struct['bar'][0] as i0 FROM structs LIMIT 3";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pub enum Expr {
/// the expression to take the field from
expr: Box<Expr>,
/// The name of the field to take
key: ScalarValue,
key: Box<Expr>,
},
/// Whether an expression is between a given range.
Between {
Expand Down
Loading