106 changes: 105 additions & 1 deletion datafusion/expr/src/expr.rs
@@ -1775,6 +1775,28 @@ impl Expr {
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
}
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: _,
}) => {
let subquery_schema = subquery.subquery.schema();
let fields = subquery_schema.fields();

// Only subqueries with exactly one output field are supported; see:
// https://github.com/apache/datafusion/blob/main/datafusion/sql/src/expr/subquery.rs#L120
if let [first_field] = &fields[..] {
Review comment (Contributor): What happens if there is more than one field? Will it not rewrite any placeholders?
rewrite_placeholder(
expr.as_mut(),
&Expr::Column(Column {
relation: None,
name: first_field.name().clone(),
spans: Spans::new(),
}),
schema,
)?;
}
}
Expr::Placeholder(_) => {
has_placeholder = true;
}
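
To make the single-field restriction concrete (and to address the review question above): as the diff is written, the `if let [first_field] = &fields[..]` slice pattern only matches a one-element slice, so a subquery that projects zero or several columns simply skips the rewrite and the placeholder keeps `data_type: None`; no error is raised. A minimal standalone sketch of that behavior (hypothetical helper, not part of the patch):

fn infer_from_single_field(fields: &[&str]) -> Option<String> {
    // `[first_field]` only matches when the slice has exactly one element.
    if let [first_field] = fields {
        // Exactly one projected column: use it to infer the placeholder type.
        Some(format!("infer from column `{first_field}`"))
    } else {
        // Zero or many columns: leave the placeholder uninferred.
        None
    }
}

fn main() {
    assert_eq!(
        infer_from_single_field(&["A"]),
        Some("infer from column `A`".to_string())
    );
    assert_eq!(infer_from_single_field(&["A", "B"]), None);
}
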
@@ -3198,7 +3220,8 @@ mod test {
use crate::expr_fn::col;
use crate::{
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, TableScan, Volatility,
};
use arrow::datatypes::{Field, Schema};
use sqlparser::ast;
@@ -3260,6 +3283,87 @@
}
}

#[test]
fn infer_placeholder_in_subquery() -> Result<()> {
// Schema for my_table: A (Int32), B (Int32)
let schema = Arc::new(Schema::new(vec![
Field::new("A", DataType::Int32, true),
Field::new("B", DataType::Int32, true),
]));

let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema)));

// Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
let placeholder = Expr::Placeholder(Placeholder {
id: "$1".to_string(),
data_type: None,
});

// Subquery: SELECT A FROM my_table WHERE B > 3
let subquery_filter = Expr::BinaryExpr(BinaryExpr {
left: Box::new(col("B")),
op: Operator::Gt,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
});

let subquery_scan = LogicalPlan::TableScan(TableScan {
table_name: TableReference::from("my_table"),
source,
projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?),
projection: None,
filters: vec![subquery_filter.clone()],
fetch: None,
});

let projected_fields = vec![Field::new("A", DataType::Int32, true)];
let projected_schema = Arc::new(DFSchema::from_unqualified_fields(
projected_fields.into(),
Default::default(),
)?);

let subquery = Subquery {
subquery: Arc::new(LogicalPlan::Projection(Projection {
expr: vec![col("A")],
input: Arc::new(subquery_scan),
schema: projected_schema,
})),
outer_ref_columns: vec![],
spans: Spans::new(),
};

let in_subquery = Expr::InSubquery(InSubquery {
expr: Box::new(placeholder),
subquery,
negated: false,
});

let df_schema = DFSchema::try_from(schema)?;

let (inferred_expr, contains_placeholder) =
in_subquery.infer_placeholder_types(&df_schema)?;

assert!(
contains_placeholder,
"Expression should contain a placeholder"
);

match inferred_expr {
Expr::InSubquery(in_subquery) => match *in_subquery.expr {
Expr::Placeholder(placeholder) => {
assert_eq!(
placeholder.data_type,
Some(DataType::Int32),
"Placeholder $1 should infer Int32"
);
}
_ => panic!("Expected Placeholder expression in InSubquery"),
},
_ => panic!("Expected InSubquery expression"),
}

Ok(())
}

#[test]
fn infer_placeholder_like_and_similar_to() {
// name LIKE $1
111 changes: 111 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -1494,6 +1494,30 @@ impl LogicalPlan {
let mut param_types: HashMap<String, Option<DataType>> = HashMap::new();

self.apply_with_subqueries(|plan| {
if let LogicalPlan::Limit(Limit {
fetch: Some(f),
skip,
..
}) = plan
{
if let Expr::Placeholder(Placeholder { id, data_type }) = &**f {
// Defaulting to Int64 matches the type coercion analyzer; see:
// https://github.com/apache/datafusion/blob/41e7aed3a943134c40d1b18cb9d424b358b5e5b1/datafusion/optimizer/src/analyzer/type_coercion.rs#L242
param_types.insert(
id.clone(),
Some(data_type.as_ref().cloned().unwrap_or(DataType::Int64)),
);
}

if let Some(s) = skip {
if let Expr::Placeholder(Placeholder { id, data_type }) = &**s {
// Defaulting to Int64 matches the type coercion analyzer; see:
// https://github.com/apache/datafusion/blob/41e7aed3a943134c40d1b18cb9d424b358b5e5b1/datafusion/optimizer/src/analyzer/type_coercion.rs#L242
param_types.insert(
id.clone(),
Some(data_type.as_ref().cloned().unwrap_or(DataType::Int64)),
);
}
}
}
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, data_type }) = expr {
@@ -1507,6 +1531,10 @@
(_, Some(dt)) => {
param_types.insert(id.clone(), Some(dt.clone()));
}
(Some(Some(_)), None) => {
// The data type was already inferred earlier (e.g. by the LIMIT/OFFSET
// handling above), so do not overwrite it with None.
}
_ => {
param_types.insert(id.clone(), None);
}
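
The match arms visible in this hunk implement a small merge rule for the collected parameter types: a concrete type is recorded, an untyped sighting must not clobber a type inferred earlier (such as the Int64 default recorded for LIMIT/OFFSET), and an otherwise unknown placeholder is remembered as untyped. A simplified standalone sketch of that rule, with a hypothetical record_param_type helper and String stand-ins for DataType (it omits the conflicting-type error check handled elsewhere in the surrounding code):

use std::collections::HashMap;

fn record_param_type(
    param_types: &mut HashMap<String, Option<String>>,
    id: &str,
    data_type: Option<String>,
) {
    let already_inferred = matches!(param_types.get(id), Some(Some(_)));
    match (already_inferred, data_type) {
        // A new occurrence carries a concrete type: record it.
        (_, Some(dt)) => {
            param_types.insert(id.to_string(), Some(dt));
        }
        // Type already inferred (e.g. the LIMIT/OFFSET Int64 default): keep it.
        (true, None) => {}
        // Nothing known yet: remember the placeholder as untyped.
        (false, None) => {
            param_types.insert(id.to_string(), None);
        }
    }
}

fn main() {
    let mut types = HashMap::new();
    // The LIMIT handling records the Int64 default first...
    record_param_type(&mut types, "$1", Some("Int64".to_string()));
    // ...and a later untyped sighting of $1 does not erase it.
    record_param_type(&mut types, "$1", None);
    assert_eq!(types.get("$1"), Some(&Some("Int64".to_string())));
}
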
@@ -4029,6 +4057,89 @@ mod tests {
.build()
}

#[test]
fn test_resolved_placeholder_limit() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("A", DataType::Int32, true)]));
let source = Arc::new(LogicalTableSource::new(Arc::clone(&schema)));

let placeholders = ["$1", "$2"];

// SELECT * FROM my_table LIMIT $1 OFFSET $2
let plan = LogicalPlan::Limit(Limit {
skip: Some(Box::new(Expr::Placeholder(Placeholder {
id: placeholders[1].to_string(),
data_type: None,
}))),
fetch: Some(Box::new(Expr::Placeholder(Placeholder {
id: placeholders[0].to_string(),
data_type: None,
}))),
input: Arc::new(LogicalPlan::TableScan(TableScan {
table_name: TableReference::from("my_table"),
source,
projected_schema: Arc::new(DFSchema::try_from(Arc::clone(&schema))?),
projection: None,
filters: vec![],
fetch: None,
})),
});

// try to infer the placeholder datatypes for the plan
let schema = DFSchema::try_from(Arc::clone(&schema))?;
let plan = plan
.map_expressions(|e| {
let (e, has_placeholder) = e.infer_placeholder_types(&schema)?;
Ok(if !has_placeholder {
Transformed::no(e)
} else {
Transformed::yes(e)
})
})
.expect("map expressions")
.data;

let LogicalPlan::Limit(Limit {
fetch: Some(f),
skip: Some(s),
..
}) = &plan
else {
panic!("plan is not Limit with fetch and skip");
};

if !matches!(
(&**f, &**s),
(
Expr::Placeholder(Placeholder {
data_type: None,
..
}),
Expr::Placeholder(Placeholder {
data_type: None,
..
})
)
) {
panic!(
"expected fetch and skip to be placeholders with datatypes uninferred"
);
}

let params = plan.get_parameter_types().expect("to infer type");
assert_eq!(params.len(), 2);

for placeholder in placeholders {
let parameter_type = params
.get(placeholder)
.expect("placeholder type to be recorded")
.clone();
assert_eq!(parameter_type, Some(DataType::Int64));
}

Ok(())
}
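
As a follow-up usage note, not part of the patch: once get_parameter_types reports both placeholders as Int64, the plan can be bound with concrete values. A hedged sketch, reusing the plan built in the test above and assuming LogicalPlan::with_param_values accepts a Vec<ScalarValue> in $1, $2 order:

// Hedged usage sketch (not from the patch). `plan` is the Limit plan built
// in the test above.
fn bind_limit_params(plan: LogicalPlan) -> Result<()> {
    let bound = plan.with_param_values(vec![
        ScalarValue::Int64(Some(10)), // $1 -> LIMIT 10
        ScalarValue::Int64(Some(5)),  // $2 -> OFFSET 5
    ])?;
    // After binding, no unresolved parameters remain.
    assert!(bound.get_parameter_types()?.is_empty());
    Ok(())
}
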

#[test]
fn test_display_indent() -> Result<()> {
let plan = display_plan()?;