Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,15 @@ impl CommonSubexprEliminate {
.into_iter()
.zip(window_schemas)
.try_rfold(new_input, |plan, (new_window_expr, schema)| {
Window::try_new_with_schema(
new_window_expr,
Arc::new(plan),
match Window::try_new_with_schema(
new_window_expr.clone(),
Arc::new(plan.clone()),
schema,
)
.map(LogicalPlan::Window)
) {
Ok(win) => Ok(LogicalPlan::Window(win)),
Err(_) => Window::try_new(new_window_expr, Arc::new(plan))
.map(LogicalPlan::Window),
}
})
}
})
Expand Down Expand Up @@ -794,14 +797,16 @@ mod test {
use std::any::Any;
use std::iter;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::window_frame::WindowFrame;
use datafusion_expr::{
grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
SimpleAggregateUDF, Volatility,
SimpleAggregateUDF, TableSource, Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
use datafusion_functions_window::row_number::row_number_udwf;

use super::*;
use crate::optimizer::OptimizerContext;
Expand Down Expand Up @@ -1669,6 +1674,56 @@ mod test {
Ok(())
}

/// Runs `CommonSubexprEliminate` over a plan shaped like
/// `SELECT a, b ... QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1`
/// and checks the rewritten plan against the field list `["a", "b"]`.
#[test]
fn test_window_cse_rebuild_preserves_schema() {
    let partition_col = col("a");
    let order_col = col("b");

    // ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)
    let row_number_params = datafusion_expr::expr::WindowFunctionParams {
        args: vec![],
        partition_by: vec![partition_col.clone()],
        order_by: vec![order_col.clone().sort(true, false)],
        window_frame: WindowFrame::new(None),
        null_treatment: None,
    };
    let row_number = Expr::WindowFunction(datafusion_expr::expr::WindowFunction {
        fun: datafusion_expr::expr::WindowFunctionDefinition::WindowUDF(
            row_number_udwf(),
        ),
        params: row_number_params,
    });

    // Window node over the test table, then a projection that still exposes
    // the window expression alongside the two plain columns.
    let table = test_table_scan().unwrap();
    let window_plan = LogicalPlanBuilder::from(table)
        .window(vec![row_number.clone()])
        .unwrap()
        .project(vec![partition_col, order_col, row_number.clone()])
        .unwrap()
        .build()
        .unwrap();

    // QUALIFY is modelled as a filter `row_number = 1` on the window output,
    // followed by a projection back down to the plain columns.
    let first_row_only = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(row_number),
        op: Operator::Eq,
        right: Box::new(Expr::Literal(datafusion_common::ScalarValue::UInt64(
            Some(1),
        ))),
    });
    let qualified_plan = LogicalPlanBuilder::from(window_plan)
        .filter(first_row_only)
        .unwrap()
        .project(vec![col("a"), col("b")])
        .unwrap()
        .build()
        .unwrap();

    let optimizer_config = OptimizerContext::new();
    let optimized = CommonSubexprEliminate::new()
        .rewrite(qualified_plan, &optimizer_config)
        .unwrap();

    assert_fields_eq(&optimized.data, vec!["a", "b"]);
}

/// returns a "random" function that is marked volatile (aka each invocation
/// returns a different value)
///
Expand Down
95 changes: 81 additions & 14 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
if !select.lateral_views.is_empty() {
return not_impl_err!("LATERAL VIEWS");
}
if select.qualify.is_some() {
return not_impl_err!("QUALIFY");
}
if select.top.is_some() {
return not_impl_err!("TOP");
}
Expand Down Expand Up @@ -148,9 +145,27 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
})
.transpose()?;

// Optionally the QUALIFY expression (filters after window functions)
let qualify_expr_opt_pre_aggr = select
.qualify
.map::<Result<Expr>, _>(|qualify_expr| {
let qualify_expr = self.sql_expr_to_logical_expr(
qualify_expr,
&combined_schema,
planner_context,
)?;
let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
normalize_col(qualify_expr, &projected_plan)
})
.transpose()?;
let has_qualify = qualify_expr_opt_pre_aggr.is_some();

// The outer expressions we will search through for aggregates.
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter());
// Aggregates may be sourced from the SELECT list, HAVING expression, or QUALIFY expression.
let aggr_expr_haystack = select_exprs
.iter()
.chain(having_expr_opt.iter())
.chain(qualify_expr_opt_pre_aggr.iter());
// All of the aggregate expressions (deduplicated).
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);

Expand Down Expand Up @@ -198,22 +213,30 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.collect()
};

// Process group by, aggregation or having
let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs
.is_empty()
|| !aggr_exprs.is_empty()
{
// Process group by, aggregation, having (and prepare qualify for post-aggregation)
let (
plan,
mut select_exprs_post_aggr,
having_expr_post_aggr,
mut qualify_expr_post_aggr,
) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
self.aggregate(
&base_plan,
&select_exprs,
having_expr_opt.as_ref(),
&group_by_exprs,
&aggr_exprs,
qualify_expr_opt_pre_aggr.as_ref(),
)?
} else {
match having_expr_opt {
Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"),
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt)
None => (
base_plan.clone(),
select_exprs.clone(),
having_expr_opt,
qualify_expr_opt_pre_aggr,
)
}
};

Expand All @@ -226,7 +249,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
};

// Process window function
let window_func_exprs = find_window_exprs(&select_exprs_post_aggr);
let window_search_exprs: Vec<Expr> =
if let Some(ref qualify_expr) = qualify_expr_post_aggr {
let mut v = select_exprs_post_aggr.clone();
v.push(qualify_expr.clone());
v
} else {
select_exprs_post_aggr.clone()
};
let window_func_exprs = find_window_exprs(&window_search_exprs);

if has_qualify && window_func_exprs.is_empty() {
return plan_err!(
"QUALIFY clause requires at least one window function in the SELECT list or QUALIFY predicate"
);
}

let plan = if window_func_exprs.is_empty() {
plan
Expand All @@ -239,6 +276,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.map(|expr| rebase_expr(expr, &window_func_exprs, &plan))
.collect::<Result<Vec<Expr>>>()?;

// Re-write QUALIFY predicate to reference computed window columns
if let Some(q) = qualify_expr_post_aggr.take() {
qualify_expr_post_aggr =
Some(rebase_expr(&q, &window_func_exprs, &plan)?);
}

plan
};

// Apply QUALIFY filter
let plan = if let Some(qualify_expr) = qualify_expr_post_aggr {
LogicalPlanBuilder::from(plan)
.filter(qualify_expr)?
.build()?
} else {
plan
};

Expand Down Expand Up @@ -782,7 +834,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
having_expr_opt: Option<&Expr>,
group_by_exprs: &[Expr],
aggr_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
qualify_expr_opt: Option<&Expr>,
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>, Option<Expr>)> {
// create the aggregate plan
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
Expand Down Expand Up @@ -866,7 +919,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
None
};

Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
// Rewrite the QUALIFY expression (if any) to use columns produced by the aggregation
let qualify_expr_post_aggr = if let Some(qualify_expr) = qualify_expr_opt {
let qualify_expr_post_aggr =
rebase_expr(qualify_expr, &aggr_projection_exprs, input)?;
Some(qualify_expr_post_aggr)
} else {
None
};

Ok((
plan,
select_exprs_post_aggr,
having_expr_post_aggr,
qualify_expr_post_aggr,
))
}

// If the projection is done over a named window, that window
Expand Down
4 changes: 0 additions & 4 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4207,10 +4207,6 @@ fn test_select_distinct_order_by() {
"SELECT id, number FROM person LATERAL VIEW explode(numbers) exploded_table AS number",
"This feature is not implemented: LATERAL VIEWS"
)]
#[case::select_qualify_unsupported(
"SELECT i, p, o FROM person QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
"This feature is not implemented: QUALIFY"
)]
#[case::select_top_unsupported(
"SELECT TOP (5) * FROM person",
"This feature is not implemented: TOP"
Expand Down
114 changes: 114 additions & 0 deletions datafusion/sqllogictest/test_files/qualify.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Basic table for QUALIFY tests, from Snowflake docs examples
# Holds two rows for each value of p ('A' and 'B'), with o = 1 and o = 2 in
# each group, so a per-p ROW_NUMBER() ordered by o is unambiguous.
statement ok
CREATE TABLE qt (i INT, p VARCHAR, o INT) AS VALUES
(1, 'A', 1),
(2, 'A', 2),
(3, 'B', 1),
(4, 'B', 2);

# QUALIFY with window predicate directly
# The window function appears only in QUALIFY, not in the SELECT list.
# Expected: the o = 1 row of each p partition.
query ITI
SELECT i, p, o
FROM qt
QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1
ORDER BY p, o;
----
1 A 1
3 B 1

# QUALIFY referencing window alias from SELECT list
# row_num is defined in the SELECT list and referenced by name in QUALIFY;
# it is also part of the output, so every returned row ends in 1.
query ITII
SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num
FROM qt
QUALIFY row_num = 1
ORDER BY p, o;
----
1 A 1 1
3 B 1 1

# QUALIFY on a window over an aggregate alias from SELECT
# Both groups have SUM(o) = 1 + 2 = 3, so they tie under RANK() and both
# receive rank 1; that is why two rows are expected rather than one.
query TI
SELECT p, SUM(o) AS s
FROM qt
GROUP BY p
QUALIFY RANK() OVER (ORDER BY s DESC) = 1
ORDER BY p;
----
A 3
B 3

# QUALIFY requires at least one window function (error)
# The expected message is pinned so that an unrelated failure (for example a
# missing table or a parser error) cannot satisfy this test.
query error QUALIFY clause requires at least one window function
SELECT i FROM qt QUALIFY i > 1;

# WHERE with scalar aggregate subquery + QUALIFY
statement ok
CREATE TABLE bulk_import_entities (
id INT,
_task_instance INT,
_uploaded_at TIMESTAMP
) AS VALUES
(1, 1, '2025-01-01 10:00:00'::timestamp),
(1, 2, '2025-01-02 09:00:00'::timestamp),
(1, 2, '2025-01-03 08:00:00'::timestamp),
(2, 1, '2025-01-01 11:00:00'::timestamp),
(2, 2, '2025-01-02 12:00:00'::timestamp),
(3, 1, '2025-01-01 13:00:00'::timestamp);

# MAX(_task_instance) is 2, so WHERE keeps three rows: two for id 1 and one
# for id 2. id 3 only has _task_instance = 1 and drops out entirely.
# QUALIFY then keeps the earliest _uploaded_at row per remaining id.
query II
SELECT id, _task_instance
FROM bulk_import_entities
WHERE _task_instance = (
SELECT MAX(_task_instance) FROM bulk_import_entities
)
QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY _uploaded_at) = 1
ORDER BY id;
----
1 2
2 2

# Constant filter + QUALIFY with multiple ORDER BY keys
statement ok
CREATE TABLE web_base_events_this_run (
domain_sessionid VARCHAR,
app_id VARCHAR,
page_view_id VARCHAR,
derived_tstamp TIMESTAMP,
dvce_created_tstamp TIMESTAMP,
event_id VARCHAR
) AS SELECT * FROM VALUES
('ds1', 'appA', NULL, '2025-01-01 10:00:00'::timestamp, '2025-01-01 10:05:00'::timestamp, 'e1'),
('ds1', 'appA', NULL, '2025-01-01 11:00:00'::timestamp, '2025-01-01 11:00:00'::timestamp, 'e2'),
('ds1', 'appA', 'pv', '2025-01-01 12:00:00'::timestamp, '2025-01-01 12:00:00'::timestamp, 'e3'),
('ds2', 'appB', NULL, '2025-01-01 09:00:00'::timestamp, '2025-01-01 09:10:00'::timestamp, 'e4'),
('ds2', 'appB', NULL, '2025-01-01 09:05:00'::timestamp, '2025-01-01 09:09:00'::timestamp, 'e5');

# WHERE removes the single row with a non-NULL page_view_id (event e3).
# QUALIFY then keeps the first remaining row of each domain_sessionid, ordered
# by derived_tstamp, dvce_created_tstamp, event_id, so one row per session
# is expected.
query TT
SELECT domain_sessionid, app_id
FROM web_base_events_this_run
WHERE page_view_id IS NULL
QUALIFY ROW_NUMBER() OVER (
PARTITION BY domain_sessionid
ORDER BY derived_tstamp, dvce_created_tstamp, event_id
) = 1
ORDER BY domain_sessionid;
----
ds1 appA
ds2 appB
Loading