Skip to content

Commit

Permalink
fix window aggregation with alias and add integration test case (#454)
Browse files Browse the repository at this point in the history
* fix window expression with alias

* add integration test
  • Loading branch information
jimexist authored Jun 2, 2021
1 parent 1601112 commit c3fc0c7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ impl LogicalPlanBuilder {
// FIXME: implement next
// window_frame: Option<WindowFrame>,
) -> Result<Self> {
let window_expr = window_expr.into_iter().collect::<Vec<Expr>>();
let window_expr = window_expr.into_iter().collect::<Vec<_>>();
// FIXME: implement next
// let partition_by_expr = partition_by_expr.into_iter().collect::<Vec<Expr>>();
// FIXME: implement next
Expand Down
25 changes: 16 additions & 9 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@

//! SQL Query Planner (produces logical plan from SQL AST)
use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};

use crate::catalog::TableReference;
use crate::datasource::TableProvider;
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType,
StringifiedPlan, ToDFSchema,
};
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
use crate::{
error::{DataFusionError, Result},
Expand All @@ -38,11 +35,8 @@ use crate::{
physical_plan::{aggregates, functions, window_functions},
sql::parser::{CreateExternalTable, FileType, Statement as DFStatement},
};

use arrow::datatypes::*;
use hashbrown::HashMap;

use crate::prelude::JoinType;
use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
Ident, Join, JoinConstraint, JoinOperator, ObjectName, Query, Select, SelectItem,
Expand All @@ -52,6 +46,9 @@ use sqlparser::ast::{
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{OrderByExpr, Statement};
use sqlparser::parser::ParserError::ParserError;
use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};

use super::{
parser::DFParser,
Expand Down Expand Up @@ -678,11 +675,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
select_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>)> {
let plan = LogicalPlanBuilder::from(input)
.window(window_exprs)?
.window(window_exprs.clone())?
.build()?;
let select_exprs = select_exprs
.iter()
.map(|expr| expr_as_column_expr(&expr, &plan))
.map(|expr| rebase_expr(expr, &window_exprs, &plan))
.into_iter()
.collect::<Result<Vec<_>>>()?;
Ok((plan, select_exprs))
Expand Down Expand Up @@ -2710,6 +2707,16 @@ mod tests {
quick_test(sql, expected);
}

#[test]
fn empty_over_with_alias() {
let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders";
let expected = "\
Projection: #order_id AS oid, #MAX(order_id) AS max_oid\
\n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

#[test]
fn empty_over_plus() {
let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders";
Expand Down
25 changes: 25 additions & 0 deletions integration-tests/sqls/simple_window_full_aggregation.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language gOVERning permissions and
-- limitations under the License.

SELECT
row_number() OVER () AS row_number,
count(c3) OVER () AS count_c3,
avg(c3) OVER () AS avg,
sum(c3) OVER () AS sum,
max(c3) OVER () AS max,
min(c3) OVER () AS min
FROM test
ORDER BY row_number;
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
self.assertEqual(len(files), 4, msg="tests are missed")
self.assertEqual(len(files), 5, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(
Expand Down

0 comments on commit c3fc0c7

Please sign in to comment.