diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 09b3fbeef25f..8a682fc678a4 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -220,12 +220,18 @@ impl CommonSubexprEliminate {
                             .into_iter()
                             .zip(window_schemas)
                             .try_rfold(new_input, |plan, (new_window_expr, schema)| {
+                                // Wrap the plan once so the fast path and the fallback share it
+                                // through a cheap `Arc::clone` instead of deep-cloning the plan.
+                                let input = Arc::new(plan);
                                 Window::try_new_with_schema(
-                                    new_window_expr,
-                                    Arc::new(plan),
+                                    new_window_expr.clone(),
+                                    Arc::clone(&input),
                                     schema,
                                 )
+                                // If the precomputed schema is rejected, rebuild the node with
+                                // `Window::try_new` on the same input instead of failing.
+                                .or_else(|_| Window::try_new(new_window_expr, input))
                                 .map(LogicalPlan::Window)
                             })
                     }
                 })
@@ -794,14 +800,16 @@ mod test {
     use std::any::Any;
     use std::iter;
 
-    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use datafusion_expr::logical_plan::{table_scan, JoinType};
+    use datafusion_expr::window_frame::WindowFrame;
     use datafusion_expr::{
         grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
         ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
-        SimpleAggregateUDF, Volatility,
+        SimpleAggregateUDF, TableSource, Volatility,
     };
     use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
+    use datafusion_functions_window::row_number::row_number_udwf;
 
     use super::*;
     use crate::optimizer::OptimizerContext;
@@ -1669,6 +1677,58 @@ mod test {
         Ok(())
     }
 
+    /// Runs the CSE rule over a window -> projection -> filter -> projection
+    /// plan and asserts that the resulting fields are `a` and `b`.
+    #[test]
+    fn test_window_cse_rebuild_preserves_schema() {
+        // Build a plan similar to SELECT ... QUALIFY ROW_NUMBER()
+        let scan = test_table_scan().unwrap();
+        let col0 = col("a");
+        let col1 = col("b");
+
+        let wnd = Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
+            fun: datafusion_expr::expr::WindowFunctionDefinition::WindowUDF(
+                row_number_udwf(),
+            ),
+            params: datafusion_expr::expr::WindowFunctionParams {
+                partition_by: vec![col0.clone()],
+                order_by: vec![col1.clone().sort(true, false)],
+                window_frame: WindowFrame::new(None),
+                args: vec![],
+                null_treatment: None,
+            },
+        });
+
+        let windowed = LogicalPlanBuilder::from(scan)
+            .window(vec![wnd.clone()])
+            .unwrap()
+            .project(vec![col0.clone(), col1.clone(), wnd.clone()])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        // Simulate QUALIFY as a filter on the window output
+        let filtered = LogicalPlanBuilder::from(windowed)
+            .filter(Expr::BinaryExpr(BinaryExpr {
+                left: Box::new(wnd),
+                op: Operator::Eq,
+                right: Box::new(Expr::Literal(datafusion_common::ScalarValue::UInt64(
+                    Some(1),
+                ))),
+            }))
+            .unwrap()
+            .project(vec![col("a"), col("b")])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let rule = CommonSubexprEliminate::new();
+        let cfg = OptimizerContext::new();
+        let res = rule.rewrite(filtered, &cfg).unwrap();
+
+        assert_fields_eq(&res.data, vec!["a", "b"]);
+    }
+
     /// returns a "random" function that is marked volatile (aka each invocation
     /// returns a different value)
     ///
diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index 33994b60b735..63ad54b50832 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -66,9 +66,6 @@ impl SqlToRel<'_, S> {
         if !select.lateral_views.is_empty() {
             return not_impl_err!("LATERAL VIEWS");
         }
-        if select.qualify.is_some() {
-            return not_impl_err!("QUALIFY");
-        }
         if select.top.is_some() {
             return not_impl_err!("TOP");
         }
@@ -148,9 +145,27 @@ impl SqlToRel<'_, S> {
             })
             .transpose()?;
 
+        // Optionally the QUALIFY expression (filters after window functions)
+        let qualify_expr_opt_pre_aggr = select
+            .qualify
+            .map::<Result<Expr>, _>(|qualify_expr| {
+                let qualify_expr = self.sql_expr_to_logical_expr(
+                    qualify_expr,
+                    &combined_schema,
+                    planner_context,
+                )?;
+                let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
+                normalize_col(qualify_expr, &projected_plan)
+            })
+            .transpose()?;
+        let has_qualify = qualify_expr_opt_pre_aggr.is_some();
+
         // The outer expressions we will search through for aggregates.
-        // Aggregates may be sourced from the SELECT list or from the HAVING expression.
-        let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter());
+        // Aggregates may be sourced from the SELECT list, HAVING expression, or QUALIFY expression.
+        let aggr_expr_haystack = select_exprs
+            .iter()
+            .chain(having_expr_opt.iter())
+            .chain(qualify_expr_opt_pre_aggr.iter());
         // All of the aggregate expressions (deduplicated).
         let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);
 
@@ -198,22 +213,30 @@ impl SqlToRel<'_, S> {
                 .collect()
         };
 
-        // Process group by, aggregation or having
-        let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs
-            .is_empty()
-            || !aggr_exprs.is_empty()
-        {
+        // Process group by, aggregation, having (and prepare qualify for post-aggregation)
+        let (
+            plan,
+            mut select_exprs_post_aggr,
+            having_expr_post_aggr,
+            mut qualify_expr_post_aggr,
+        ) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
             self.aggregate(
                 &base_plan,
                 &select_exprs,
                 having_expr_opt.as_ref(),
                 &group_by_exprs,
                 &aggr_exprs,
+                qualify_expr_opt_pre_aggr.as_ref(),
             )?
         } else {
             match having_expr_opt {
                 Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"),
-                None => (base_plan.clone(), select_exprs.clone(), having_expr_opt)
+                None => (
+                    base_plan.clone(),
+                    select_exprs.clone(),
+                    having_expr_opt,
+                    qualify_expr_opt_pre_aggr,
+                )
             }
         };
 
@@ -226,7 +249,19 @@ impl SqlToRel<'_, S> {
         };
 
         // Process window function
-        let window_func_exprs = find_window_exprs(&select_exprs_post_aggr);
+        // Window functions may appear in the SELECT list or only in the QUALIFY predicate
+        let window_search_exprs: Vec<Expr> = select_exprs_post_aggr
+            .iter()
+            .chain(qualify_expr_post_aggr.iter())
+            .cloned()
+            .collect();
+        let window_func_exprs = find_window_exprs(&window_search_exprs);
+
+        if has_qualify && window_func_exprs.is_empty() {
+            return plan_err!(
+                "QUALIFY clause requires at least one window function in the SELECT list or QUALIFY predicate"
+            );
+        }
 
         let plan = if window_func_exprs.is_empty() {
             plan
@@ -239,6 +274,21 @@ impl SqlToRel<'_, S> {
                 .map(|expr| rebase_expr(expr, &window_func_exprs, &plan))
                 .collect::<Result<Vec<Expr>>>()?;
 
+            // Re-write QUALIFY predicate to reference computed window columns
+            if let Some(q) = qualify_expr_post_aggr.take() {
+                qualify_expr_post_aggr =
+                    Some(rebase_expr(&q, &window_func_exprs, &plan)?);
+            }
+
+            plan
+        };
+
+        // Apply QUALIFY filter
+        let plan = if let Some(qualify_expr) = qualify_expr_post_aggr {
+            LogicalPlanBuilder::from(plan)
+                .filter(qualify_expr)?
+                .build()?
+        } else {
             plan
         };
 
@@ -782,7 +832,8 @@ impl SqlToRel<'_, S> {
         having_expr_opt: Option<&Expr>,
         group_by_exprs: &[Expr],
         aggr_exprs: &[Expr],
-    ) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
+        qualify_expr_opt: Option<&Expr>,
+    ) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>, Option<Expr>)> {
         // create the aggregate plan
         let options =
             LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
@@ -866,7 +917,21 @@ impl SqlToRel<'_, S> {
             None
         };
 
-        Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
+        // Rewrite the QUALIFY expression (if any) to use columns produced by the aggregation
+        let qualify_expr_post_aggr = if let Some(qualify_expr) = qualify_expr_opt {
+            let qualify_expr_post_aggr =
+                rebase_expr(qualify_expr, &aggr_projection_exprs, input)?;
+            Some(qualify_expr_post_aggr)
+        } else {
+            None
+        };
+
+        Ok((
+            plan,
+            select_exprs_post_aggr,
+            having_expr_post_aggr,
+            qualify_expr_post_aggr,
+        ))
     }
 
     // If the projection is done over a named window, that window
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 2804a1de0606..d3a6b70bd9d0 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -4207,10 +4207,6 @@ fn test_select_distinct_order_by() {
     "SELECT id, number FROM person LATERAL VIEW explode(numbers) exploded_table AS number",
     "This feature is not implemented: LATERAL VIEWS"
 )]
-#[case::select_qualify_unsupported(
-    "SELECT i, p, o FROM person QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
-    "This feature is not implemented: QUALIFY"
-)]
 #[case::select_top_unsupported(
     "SELECT TOP (5) * FROM person",
     "This feature is not implemented: TOP"
diff --git a/datafusion/sqllogictest/test_files/qualify.slt b/datafusion/sqllogictest/test_files/qualify.slt
new file mode 100644
index 000000000000..e6180894e28f
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/qualify.slt
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Basic table for QUALIFY tests, from Snowflake docs examples
+statement ok
+CREATE TABLE qt (i INT, p VARCHAR, o INT) AS VALUES
+  (1, 'A', 1),
+  (2, 'A', 2),
+  (3, 'B', 1),
+  (4, 'B', 2);
+
+# QUALIFY with window predicate directly
+query ITI
+SELECT i, p, o
+FROM qt
+QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1
+ORDER BY p, o;
+----
+1 A 1
+3 B 1
+
+# QUALIFY referencing window alias from SELECT list
+query ITII
+SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num
+FROM qt
+QUALIFY row_num = 1
+ORDER BY p, o;
+----
+1 A 1 1
+3 B 1 1
+
+# QUALIFY on a window over an aggregate alias from SELECT
+query TI
+SELECT p, SUM(o) AS s
+FROM qt
+GROUP BY p
+QUALIFY RANK() OVER (ORDER BY s DESC) = 1
+ORDER BY p;
+----
+A 3
+B 3
+
+# QUALIFY requires at least one window function (error)
+query error QUALIFY clause requires at least one window function in the SELECT list or QUALIFY predicate
+SELECT i FROM qt QUALIFY i > 1;
+
+# WHERE with scalar aggregate subquery + QUALIFY
+statement ok
+CREATE TABLE bulk_import_entities (
+  id INT,
+  _task_instance INT,
+  _uploaded_at TIMESTAMP
+) AS VALUES
+  (1, 1, '2025-01-01 10:00:00'::timestamp),
+  (1, 2, '2025-01-02 09:00:00'::timestamp),
+  (1, 2, '2025-01-03 08:00:00'::timestamp),
+  (2, 1, '2025-01-01 11:00:00'::timestamp),
+  (2, 2, '2025-01-02 12:00:00'::timestamp),
+  (3, 1, '2025-01-01 13:00:00'::timestamp);
+
+query II
+SELECT id, _task_instance
+FROM bulk_import_entities
+WHERE _task_instance = (
+  SELECT MAX(_task_instance) FROM bulk_import_entities
+)
+QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY _uploaded_at) = 1
+ORDER BY id;
+----
+1 2
+2 2
+
+# Constant filter + QUALIFY with multiple ORDER BY keys
+statement ok
+CREATE TABLE web_base_events_this_run (
+  domain_sessionid VARCHAR,
+  app_id VARCHAR,
+  page_view_id VARCHAR,
+  derived_tstamp TIMESTAMP,
+  dvce_created_tstamp TIMESTAMP,
+  event_id VARCHAR
+) AS SELECT * FROM VALUES
+  ('ds1', 'appA', NULL, '2025-01-01 10:00:00'::timestamp, '2025-01-01 10:05:00'::timestamp, 'e1'),
+  ('ds1', 'appA', NULL, '2025-01-01 11:00:00'::timestamp, '2025-01-01 11:00:00'::timestamp, 'e2'),
+  ('ds1', 'appA', 'pv', '2025-01-01 12:00:00'::timestamp, '2025-01-01 12:00:00'::timestamp, 'e3'),
+  ('ds2', 'appB', NULL, '2025-01-01 09:00:00'::timestamp, '2025-01-01 09:10:00'::timestamp, 'e4'),
+  ('ds2', 'appB', NULL, '2025-01-01 09:05:00'::timestamp, '2025-01-01 09:09:00'::timestamp, 'e5');
+
+query TT
+SELECT domain_sessionid, app_id
+FROM web_base_events_this_run
+WHERE page_view_id IS NULL
+QUALIFY ROW_NUMBER() OVER (
+  PARTITION BY domain_sessionid
+  ORDER BY derived_tstamp, dvce_created_tstamp, event_id
+) = 1
+ORDER BY domain_sessionid;
+----
+ds1 appA
+ds2 appB