From bdd25b687e561651b918fe720cca2c99ea64ff4e Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 8 Nov 2024 10:24:41 -0600 Subject: [PATCH] [BUG]: orderby with aggs (#3190) --- src/daft-sql/src/planner.rs | 160 ++++++++++++++++++++++++------------ tests/sql/test_orderbys.py | 103 +++++++++++++++++++++++ tests/sql/test_sql.py | 34 +++----- 3 files changed, 223 insertions(+), 74 deletions(-) create mode 100644 tests/sql/test_orderbys.py diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index a8025133fc..fdecb669bb 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -8,7 +8,7 @@ use daft_core::prelude::*; use daft_dsl::{ col, functions::utf8::{ilike, like, to_date, to_datetime}, - has_agg, lit, literals_to_series, null_lit, Expr, ExprRef, LiteralValue, Operator, + has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, }; use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; @@ -393,10 +393,16 @@ impl SQLPlanner { projection_schema: &Schema, ) -> Result<(), PlannerError> { let mut final_projection = Vec::with_capacity(projections.len()); - let mut orderby_projection = Vec::with_capacity(projections.len()); let mut aggs = Vec::with_capacity(projections.len()); - let mut orderby_exprs = None; - let mut orderby_desc = None; + + // these are orderbys that are part of the final projection + let mut orderbys_after_projection = Vec::new(); + let mut orderbys_after_projection_desc = Vec::new(); + + // these are orderbys that are not part of the final projection + let mut orderbys_before_projection = Vec::new(); + let mut orderbys_before_projection_desc = Vec::new(); + for p in projections { let fld = p.to_field(schema)?; @@ -435,14 +441,101 @@ impl SQLPlanner { let (exprs, desc) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; - orderby_exprs = Some(exprs.clone()); - orderby_desc = Some(desc); + for (i, expr) in 
exprs.iter().enumerate() { + // the orderby is ordered by a column of the projection + // ex: SELECT a as b FROM t ORDER BY b + // so we don't need an additional projection + + if let Ok(fld) = expr.to_field(projection_schema) { + // check if it's an aggregate + // ex: SELECT sum(a) FROM t ORDER BY sum(a) + + // special handling for count(*) + // TODO: this is a hack, we should handle this better + // + // Since count(*) will always be `Ok` for `to_field(schema)` + // we need to manually check if it's in the final schema or not + if let Expr::Alias(e, alias) = expr.as_ref() { + if alias.as_ref() == "count" + && matches!(e.as_ref(), Expr::Agg(AggExpr::Count(_, CountMode::All))) + { + if let Some(alias) = aggs.iter().find_map(|agg| { + if let Expr::Alias(e, alias) = agg.as_ref() { + if e == expr { + Some(alias) + } else { + None + } + } else { + None + } + }) { + // its a count(*) that is already in the final projection + // ex: SELECT count(*) as c FROM t ORDER BY count(*) + orderbys_after_projection.push(col(alias.as_ref())); + orderbys_after_projection_desc.push(desc[i]); + } else { + // its a count(*) that is not in the final projection + // ex: SELECT sum(n) FROM t ORDER BY count(*); + aggs.push(expr.clone()); + orderbys_before_projection.push(col(fld.name.as_ref())); + orderbys_before_projection_desc.push(desc[i]); + } + } + } else if has_agg(expr) { + // aggregates part of the final projection are already resolved + // so we just need to push the column name + orderbys_after_projection.push(col(fld.name.as_ref())); + orderbys_after_projection_desc.push(desc[i]); + } else { + orderbys_after_projection.push(expr.clone()); + orderbys_after_projection_desc.push(desc[i]); + } - for expr in &exprs { - // if the orderby references a column that is not in the final projection - // then we need an additional projection - if let Err(DaftError::FieldNotFound(_)) = expr.to_field(projection_schema) { - orderby_projection.push(expr.clone()); + // the orderby is ordered by 
an expr from the original schema + // ex: SELECT sum(b) FROM t ORDER BY sum(a) + } else if let Ok(fld) = expr.to_field(schema) { + // check if it's an aggregate + if has_agg(expr) { + // check if it's an alias of something in the aggs + // if so, we can just use that column + // This way we avoid computing the aggregate twice + // + // ex: SELECT sum(a) as b FROM t ORDER BY sum(a); + if let Some(alias) = aggs.iter().find_map(|p| { + if let Expr::Alias(e, alias) = p.as_ref() { + if e == expr { + Some(alias) + } else { + None + } + } else { + None + } + }) { + orderbys_after_projection.push(col(alias.as_ref())); + orderbys_after_projection_desc.push(desc[i]); + } else { + // it's an aggregate that is not part of the final projection + // ex: SELECT sum(a) FROM t ORDER BY sum(b) + // so we need to add it to the aggs list + aggs.push(expr.clone()); + + // then add it to the orderbys that are not part of the final projection + orderbys_before_projection.push(col(fld.name.as_ref())); + orderbys_before_projection_desc.push(desc[i]); + } + } else { + // we know it's a column of the original schema + // and it's not part of the final projection + // so we need an additional projection + // ex: SELECT sum(a) FROM t ORDER BY b + + orderbys_before_projection.push(col(fld.name.as_ref())); + orderbys_before_projection_desc.push(desc[i]); + } + } else { + panic!("unexpected order by expr"); + } + } } @@ -450,52 +543,13 @@ let rel = self.relation_mut(); rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?; - let needs_projection = !orderby_projection.is_empty(); - if needs_projection { - let orderby_projection = rel - .schema() - .names() - .iter() - .map(|n| col(n.as_str())) - .chain(orderby_projection) - .collect::<HashSet<_>>() // dedup - .into_iter() - .collect::<Vec<_>>(); - - rel.inner = rel.inner.select(orderby_projection)?; - } - - // these are orderbys that are part of the final projection - let mut 
orderbys_after_projection_desc = Vec::new(); - - // these are orderbys that are not part of the final projection - let mut orderbys_before_projection = Vec::new(); - let mut orderbys_before_projection_desc = Vec::new(); - - if let Some(orderby_exprs) = orderby_exprs { - // this needs to be done after the aggregation and any projections - // because the orderby may reference an alias, or an intermediate column that is not in the final projection - let schema = rel.schema(); - for (i, expr) in orderby_exprs.iter().enumerate() { - if let Err(DaftError::FieldNotFound(_)) = expr.to_field(&schema) { - orderbys_after_projection.push(expr.clone()); - let desc = orderby_desc.clone().map(|o| o[i]).unwrap(); - orderbys_after_projection_desc.push(desc); - } else { - let desc = orderby_desc.clone().map(|o| o[i]).unwrap(); - - orderbys_before_projection.push(expr.clone()); - orderbys_before_projection_desc.push(desc); - } - } - } - let has_orderby_before_projection = !orderbys_before_projection.is_empty(); let has_orderby_after_projection = !orderbys_after_projection.is_empty(); + // ---------------- // PERF(cory): if there are order bys from both parts, can we combine them into a single sort instead of two? // or can we optimize them into a single sort? 
+ // ---------------- // order bys that are not in the final projection if has_orderby_before_projection { @@ -504,6 +558,7 @@ impl SQLPlanner { .sort(orderbys_before_projection, orderbys_before_projection_desc)?; } + // apply the final projection rel.inner = rel.inner.select(final_projection)?; // order bys that are in the final projection @@ -541,7 +596,6 @@ impl SQLPlanner { fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult { if from.len() > 1 { - // todo!("cross join") let mut from_iter = from.iter(); let first = from_iter.next().unwrap(); diff --git a/tests/sql/test_orderbys.py b/tests/sql/test_orderbys.py new file mode 100644 index 0000000000..b2c99a20b4 --- /dev/null +++ b/tests/sql/test_orderbys.py @@ -0,0 +1,103 @@ +import pytest + +import daft + + +@pytest.fixture() +def df(): + return daft.from_pydict( + { + "text": ["g1", "g1", "g2", "g3", "g3", "g1"], + "n": [1, 2, 3, 3, 4, 100], + } + ) + + +def test_orderby_basic(df): + df = daft.sql(""" + SELECT * from df order by n + """) + + assert df.collect().to_pydict() == { + "text": ["g1", "g1", "g2", "g3", "g3", "g1"], + "n": [1, 2, 3, 3, 4, 100], + } + + +def test_orderby_compound(df): + df = daft.sql(""" + SELECT * from df order by n, text + """) + + assert df.collect().to_pydict() == { + "text": ["g1", "g1", "g2", "g3", "g3", "g1"], + "n": [1, 2, 3, 3, 4, 100], + } + + +def test_orderby_desc(df): + df = daft.sql(""" + SELECT n from df order by n desc + """) + + assert df.collect().to_pydict() == { + "n": [100, 4, 3, 3, 2, 1], + } + + +def test_orderby_groupby(df): + df = daft.sql(""" + SELECT + text, + count(*) as count_star + from df + group by text + order by count_star DESC + """) + + assert df.collect().to_pydict() == { + "text": ["g1", "g3", "g2"], + "count_star": [3, 2, 1], + } + + +def test_orderby_groupby_expr(df): + df = daft.sql(""" +SELECT + text, + count(*) as count_star +from df +group by text +order by count(*) DESC + """) + + assert df.collect().to_pydict() == {"text": 
["g1", "g3", "g2"], "count_star": [3, 2, 1]} + + +def test_groupby_orderby_non_final_expr(df): + df = daft.sql(""" + SELECT + text, + count(*) as count_star + from df + group by text + order by sum(n) ASC + """) + + assert df.collect().to_pydict() == { + "text": ["g2", "g3", "g1"], + "count_star": [1, 2, 3], + } + + +def test_groupby_orderby_count_star(df): + df = daft.sql(""" + SELECT + text, + sum(n) as n + from df + group by text + order by count(*) ASC + """) + + assert df.collect().to_pydict() == {"text": ["g2", "g3", "g1"], "n": [3, 7, 103]} diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 2973580d05..f80aec4de7 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -110,29 +110,21 @@ def test_sql_global_agg(): daft.sql("SELECT n,max(n) max_n FROM test", catalog=catalog) -def test_sql_groupby_agg(): +@pytest.mark.parametrize( + "query,expected", + [ + ("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", {"sum": [3, 7]}), + ("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"n": [1, 2], "sum": [3, 7]}), + ("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"max": [2, 4], "sum": [3, 7]}), + ("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"n_alias": [1, 2], "sum": [3, 7]}), + ("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", {"n": [1, 2], "sum": [3, 7]}), + ], +) +def test_sql_groupby_agg(query, expected): df = daft.from_pydict({"n": [1, 1, 2, 2], "v": [1, 2, 3, 4]}) catalog = SQLCatalog({"test": df}) - actual = daft.sql("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) - assert actual.collect().to_pydict() == {"sum": [3, 7]} - - # test with grouping column - actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) - assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]} - - # test with multiple columns - actual = daft.sql("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", 
catalog=catalog) - assert actual.collect().to_pydict() == {"max": [2, 4], "sum": [3, 7]} - - # test with aliased grouping key - actual = daft.sql("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) - assert actual.collect().to_pydict() == {"n_alias": [1, 2], "sum": [3, 7]} - - actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY -n", catalog=catalog) - assert actual.collect().to_pydict() == {"n": [2, 1], "sum": [7, 3]} - - actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", catalog=catalog) - assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]} + actual = daft.sql(query, catalog=catalog) + assert actual.collect().to_pydict() == expected def test_sql_count_star():