Skip to content

Commit 83811c9

Browse files
kczimmpeasee
authored andcommitted
Infer placeholder datatype for Expr::InSubquery (#80)
UPSTREAM NOTE: Upstream PR has been created but not merged yet. Should be available in DF49 apache#15980
1 parent 24b7c1b commit 83811c9

File tree

1 file changed

+104
-1
lines changed

1 file changed

+104
-1
lines changed

datafusion/expr/src/expr.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,27 @@ impl Expr {
20712071
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
20722072
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
20732073
}
2074+
Expr::InSubquery(InSubquery {
2075+
expr,
2076+
subquery,
2077+
negated: _,
2078+
}) => {
2079+
let subquery_schema = subquery.subquery.schema();
2080+
let fields = subquery_schema.fields();
2081+
2082+
// only supports subquery with exactly 1 field
2083+
if let [first_field] = &fields[..] {
2084+
rewrite_placeholder(
2085+
expr.as_mut(),
2086+
&Expr::Column(Column {
2087+
relation: None,
2088+
name: first_field.name().clone(),
2089+
spans: Spans::new(),
2090+
}),
2091+
schema,
2092+
)?;
2093+
}
2094+
}
20742095
Expr::Placeholder(_) => {
20752096
has_placeholder = true;
20762097
}
@@ -3555,7 +3576,8 @@ mod test {
35553576
use crate::expr_fn::col;
35563577
use crate::{
35573578
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
3558-
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
3579+
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
3580+
ScalarUDFImpl, TableScan, Volatility,
35593581
};
35603582
use arrow::datatypes::{Field, Schema};
35613583
use sqlparser::ast;
@@ -3617,6 +3639,87 @@ mod test {
36173639
}
36183640
}
36193641

3642+
#[test]
3643+
fn infer_placeholder_in_subquery() -> Result<()> {
3644+
// Schema for my_table: A (Int32), B (Int32)
3645+
let schema = Arc::new(Schema::new(vec![
3646+
Field::new("A", DataType::Int32, true),
3647+
Field::new("B", DataType::Int32, true),
3648+
]));
3649+
3650+
let source = Arc::new(LogicalTableSource::new(schema.clone()));
3651+
3652+
// Simulate: SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
3653+
let placeholder = Expr::Placeholder(Placeholder {
3654+
id: "$1".to_string(),
3655+
data_type: None,
3656+
});
3657+
3658+
// Subquery: SELECT A FROM my_table WHERE B > 3
3659+
let subquery_filter = Expr::BinaryExpr(BinaryExpr {
3660+
left: Box::new(col("B")),
3661+
op: Operator::Gt,
3662+
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
3663+
});
3664+
3665+
let subquery_scan = LogicalPlan::TableScan(TableScan {
3666+
table_name: TableReference::from("my_table"),
3667+
source,
3668+
projected_schema: Arc::new(DFSchema::try_from(schema.clone())?),
3669+
projection: None,
3670+
filters: vec![subquery_filter.clone()],
3671+
fetch: None,
3672+
});
3673+
3674+
let projected_fields = vec![Field::new("A", DataType::Int32, true)];
3675+
let projected_schema = Arc::new(DFSchema::from_unqualified_fields(
3676+
projected_fields.into(),
3677+
Default::default(),
3678+
)?);
3679+
3680+
let subquery = Subquery {
3681+
subquery: Arc::new(LogicalPlan::Projection(Projection {
3682+
expr: vec![col("A")],
3683+
input: Arc::new(subquery_scan),
3684+
schema: projected_schema,
3685+
})),
3686+
outer_ref_columns: vec![],
3687+
spans: Spans::new(),
3688+
};
3689+
3690+
let in_subquery = Expr::InSubquery(InSubquery {
3691+
expr: Box::new(placeholder),
3692+
subquery,
3693+
negated: false,
3694+
});
3695+
3696+
let df_schema = DFSchema::try_from(schema)?;
3697+
3698+
let (inferred_expr, contains_placeholder) =
3699+
in_subquery.infer_placeholder_types(&df_schema)?;
3700+
3701+
assert!(
3702+
contains_placeholder,
3703+
"Expression should contain a placeholder"
3704+
);
3705+
3706+
match inferred_expr {
3707+
Expr::InSubquery(in_subquery) => match *in_subquery.expr {
3708+
Expr::Placeholder(placeholder) => {
3709+
assert_eq!(
3710+
placeholder.data_type,
3711+
Some(DataType::Int32),
3712+
"Placeholder $1 should infer Int32"
3713+
);
3714+
}
3715+
_ => panic!("Expected Placeholder expression in InSubquery"),
3716+
},
3717+
_ => panic!("Expected InSubquery expression"),
3718+
}
3719+
3720+
Ok(())
3721+
}
3722+
36203723
#[test]
36213724
fn infer_placeholder_like_and_similar_to() {
36223725
// name LIKE $1

0 commit comments

Comments
 (0)