Skip to content

Commit 55a8c0b

Browse files
kczimm authored and phillipleblanc committed
Infer placeholder datatype for Expr::InSubquery (#80)
* infer placeholder datatype for InSubquery
* update comment
* only infer subquery if exactly 1 field
1 parent bddf7e6 commit 55a8c0b

File tree

1 file changed

+102
-1
lines changed

1 file changed

+102
-1
lines changed

datafusion/expr/src/expr.rs

Lines changed: 102 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1645,6 +1645,26 @@ impl Expr {
16451645
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
16461646
rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?;
16471647
}
1648+
Expr::InSubquery(InSubquery {
1649+
expr,
1650+
subquery,
1651+
negated: _,
1652+
}) => {
1653+
let subquery_schema = subquery.subquery.schema();
1654+
let fields = subquery_schema.fields();
1655+
1656+
// only supports subquery with exactly 1 field
1657+
if let [first_field] = &fields[..] {
1658+
rewrite_placeholder(
1659+
expr.as_mut(),
1660+
&Expr::Column(Column {
1661+
relation: None,
1662+
name: first_field.name().clone(),
1663+
}),
1664+
schema,
1665+
)?;
1666+
}
1667+
}
16481668
Expr::Placeholder(_) => {
16491669
has_placeholder = true;
16501670
}
@@ -2852,7 +2872,8 @@ mod test {
28522872
use crate::expr_fn::col;
28532873
use crate::{
28542874
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
2855-
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
2875+
LogicalPlan, LogicalTableSource, Projection, ScalarFunctionArgs, ScalarUDF,
2876+
ScalarUDFImpl, TableScan, Volatility,
28562877
};
28572878
use arrow::datatypes::{Field, Schema};
28582879
use sqlparser::ast;
@@ -2914,6 +2935,86 @@ mod test {
29142935
}
29152936
}
29162937

2938+
#[test]
fn infer_placeholder_in_subquery() -> Result<()> {
    // Models the query
    //   SELECT * FROM my_table WHERE $1 IN (SELECT A FROM my_table WHERE B > 3);
    // and checks that the untyped placeholder `$1` picks up the type of the
    // subquery's single output column.

    // my_table has two nullable Int32 columns: A and B.
    let table_schema = Arc::new(Schema::new(vec![
        Field::new("A", DataType::Int32, true),
        Field::new("B", DataType::Int32, true),
    ]));

    // Inner scan of my_table carrying the `B > 3` filter.
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: TableReference::from("my_table"),
        source: Arc::new(LogicalTableSource::new(table_schema.clone())),
        projected_schema: Arc::new(DFSchema::try_from(table_schema.clone())?),
        projection: None,
        filters: vec![Expr::BinaryExpr(BinaryExpr {
            left: Box::new(col("B")),
            op: Operator::Gt,
            right: Box::new(Expr::Literal(ScalarValue::Int32(Some(3)))),
        })],
        fetch: None,
    });

    // Output schema of the subquery: exactly one field, A (Int32).
    let single_column_schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![Field::new("A", DataType::Int32, true)].into(),
        Default::default(),
    )?);

    // `$1 IN (<subquery>)`, with the placeholder's data type left unset.
    let predicate = Expr::InSubquery(InSubquery {
        expr: Box::new(Expr::Placeholder(Placeholder {
            id: "$1".to_string(),
            data_type: None,
        })),
        subquery: Subquery {
            subquery: Arc::new(LogicalPlan::Projection(Projection {
                expr: vec![col("A")],
                input: Arc::new(scan),
                schema: single_column_schema,
            })),
            outer_ref_columns: vec![],
        },
        negated: false,
    });

    // Inference runs against the schema of the outer table.
    let outer_schema = DFSchema::try_from(table_schema)?;
    let (rewritten, found_placeholder) =
        predicate.infer_placeholder_types(&outer_schema)?;

    assert!(
        found_placeholder,
        "Expression should contain a placeholder"
    );

    // The rewritten expression must keep its shape, with only the
    // placeholder's data type filled in.
    let Expr::InSubquery(rewritten) = rewritten else {
        panic!("Expected InSubquery expression");
    };
    let Expr::Placeholder(resolved) = *rewritten.expr else {
        panic!("Expected Placeholder expression in InSubquery");
    };
    assert_eq!(
        resolved.data_type,
        Some(DataType::Int32),
        "Placeholder $1 should infer Int32"
    );

    Ok(())
}
3017+
29173018
#[test]
29183019
fn infer_placeholder_like_and_similar_to() {
29193020
// name LIKE $1

0 commit comments

Comments
 (0)