Skip to content

Commit

Permalink
[BUG]: orderby with aggs (#3190)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Nov 8, 2024
1 parent 990149a commit bdd25b6
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 74 deletions.
160 changes: 107 additions & 53 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use daft_core::prelude::*;
use daft_dsl::{
col,
functions::utf8::{ilike, like, to_date, to_datetime},
has_agg, lit, literals_to_series, null_lit, Expr, ExprRef, LiteralValue, Operator,
has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator,
};
use daft_functions::numeric::{ceil::ceil, floor::floor};
use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
Expand Down Expand Up @@ -393,10 +393,16 @@ impl SQLPlanner {
projection_schema: &Schema,
) -> Result<(), PlannerError> {
let mut final_projection = Vec::with_capacity(projections.len());
let mut orderby_projection = Vec::with_capacity(projections.len());
let mut aggs = Vec::with_capacity(projections.len());
let mut orderby_exprs = None;
let mut orderby_desc = None;

// these are orderbys that are part of the final projection
let mut orderbys_after_projection = Vec::new();
let mut orderbys_after_projection_desc = Vec::new();

// these are orderbys that are not part of the final projection
let mut orderbys_before_projection = Vec::new();
let mut orderbys_before_projection_desc = Vec::new();

for p in projections {
let fld = p.to_field(schema)?;

Expand Down Expand Up @@ -435,67 +441,115 @@ impl SQLPlanner {

let (exprs, desc) = self.plan_order_by_exprs(order_by.exprs.as_slice())?;

orderby_exprs = Some(exprs.clone());
orderby_desc = Some(desc);
for (i, expr) in exprs.iter().enumerate() {
// the orderby is ordered by a column of the projection
// ex: SELECT a as b FROM t ORDER BY b
// so we don't need an additional projection

if let Ok(fld) = expr.to_field(projection_schema) {
// check if it's an aggregate
// ex: SELECT sum(a) FROM t ORDER BY sum(a)

// special handling for count(*)
// TODO: this is a hack, we should handle this better
//
// Since count(*) will always be `Ok` for `to_field(schema)`
// we need to manually check if it's in the final schema or not
if let Expr::Alias(e, alias) = expr.as_ref() {
if alias.as_ref() == "count"
&& matches!(e.as_ref(), Expr::Agg(AggExpr::Count(_, CountMode::All)))
{
if let Some(alias) = aggs.iter().find_map(|agg| {
if let Expr::Alias(e, alias) = agg.as_ref() {
if e == expr {
Some(alias)
} else {
None
}
} else {
None
}
}) {
// its a count(*) that is already in the final projection
// ex: SELECT count(*) as c FROM t ORDER BY count(*)
orderbys_after_projection.push(col(alias.as_ref()));
orderbys_after_projection_desc.push(desc[i]);
} else {
// its a count(*) that is not in the final projection
// ex: SELECT sum(n) FROM t ORDER BY count(*);
aggs.push(expr.clone());
orderbys_before_projection.push(col(fld.name.as_ref()));
orderbys_before_projection_desc.push(desc[i]);
}
}
} else if has_agg(expr) {
// aggregates part of the final projection are already resolved
// so we just need to push the column name
orderbys_after_projection.push(col(fld.name.as_ref()));
orderbys_after_projection_desc.push(desc[i]);
} else {
orderbys_after_projection.push(expr.clone());
orderbys_after_projection_desc.push(desc[i]);
}

for expr in &exprs {
// if the orderby references a column that is not in the final projection
// then we need an additional projection
if let Err(DaftError::FieldNotFound(_)) = expr.to_field(projection_schema) {
orderby_projection.push(expr.clone());
// the orderby is ordered by an expr from the original schema
// ex: SELECT sum(b) FROM t ORDER BY sum(a)
} else if let Ok(fld) = expr.to_field(schema) {
// check if it's an aggregate
if has_agg(expr) {
// check if it's an alias of something in the aggs
// if so, we can just use that column
// This way we avoid computing the aggregate twice
//
// ex: SELECT sum(a) as b FROM t ORDER BY sum(a);
if let Some(alias) = aggs.iter().find_map(|p| {
if let Expr::Alias(e, alias) = p.as_ref() {
if e == expr {
Some(alias)
} else {
None
}
} else {
None
}
}) {
orderbys_after_projection.push(col(alias.as_ref()));
orderbys_after_projection_desc.push(desc[i]);
} else {
// its an aggregate that is not part of the final projection
// ex: SELECT sum(a) FROM t ORDER BY sum(b)
// so we need to add it to the aggs list
aggs.push(expr.clone());

// then add it to the orderbys that are not part of the final projection
orderbys_before_projection.push(col(fld.name.as_ref()));
orderbys_before_projection_desc.push(desc[i]);
}
} else {
// we know it's a column of the original schema
// and it's not part of the final projection
// so we need an additional projection
// ex: SELECT sum(a) FROM t ORDER BY b

orderbys_before_projection.push(col(fld.name.as_ref()));
orderbys_before_projection_desc.push(desc[i]);
}
} else {
panic!("unexpected order by expr");
}
}
}

let rel = self.relation_mut();
rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?;

let needs_projection = !orderby_projection.is_empty();
if needs_projection {
let orderby_projection = rel
.schema()
.names()
.iter()
.map(|n| col(n.as_str()))
.chain(orderby_projection)
.collect::<HashSet<_>>() // dedup
.into_iter()
.collect::<Vec<_>>();

rel.inner = rel.inner.select(orderby_projection)?;
}

// these are orderbys that are part of the final projection
let mut orderbys_after_projection = Vec::new();
let mut orderbys_after_projection_desc = Vec::new();

// these are orderbys that are not part of the final projection
let mut orderbys_before_projection = Vec::new();
let mut orderbys_before_projection_desc = Vec::new();

if let Some(orderby_exprs) = orderby_exprs {
// this needs to be done after the aggregation and any projections
// because the orderby may reference an alias, or an intermediate column that is not in the final projection
let schema = rel.schema();
for (i, expr) in orderby_exprs.iter().enumerate() {
if let Err(DaftError::FieldNotFound(_)) = expr.to_field(&schema) {
orderbys_after_projection.push(expr.clone());
let desc = orderby_desc.clone().map(|o| o[i]).unwrap();
orderbys_after_projection_desc.push(desc);
} else {
let desc = orderby_desc.clone().map(|o| o[i]).unwrap();

orderbys_before_projection.push(expr.clone());
orderbys_before_projection_desc.push(desc);
}
}
}

let has_orderby_before_projection = !orderbys_before_projection.is_empty();
let has_orderby_after_projection = !orderbys_after_projection.is_empty();

// ----------------
// PERF(cory): if there are order bys from both parts, can we combine them into a single sort instead of two?
// or can we optimize them into a single sort?
// ----------------

// order bys that are not in the final projection
if has_orderby_before_projection {
Expand All @@ -504,6 +558,7 @@ impl SQLPlanner {
.sort(orderbys_before_projection, orderbys_before_projection_desc)?;
}

// apply the final projection
rel.inner = rel.inner.select(final_projection)?;

// order bys that are in the final projection
Expand Down Expand Up @@ -541,7 +596,6 @@ impl SQLPlanner {

fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult<Relation> {
if from.len() > 1 {
// todo!("cross join")
let mut from_iter = from.iter();

let first = from_iter.next().unwrap();
Expand Down
103 changes: 103 additions & 0 deletions tests/sql/test_orderbys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest

import daft


@pytest.fixture()
def df():
    """A small DataFrame with three groups in `text` (g1 x3, g2 x1, g3 x2) and integer `n`."""
    data = {
        "text": ["g1", "g1", "g2", "g3", "g3", "g1"],
        "n": [1, 2, 3, 3, 4, 100],
    }
    return daft.from_pydict(data)


def test_orderby_basic(df):
    # Plain ORDER BY on a single column; the fixture rows are already
    # sorted by `n`, so the output matches the input ordering.
    result = daft.sql("""
    SELECT * from df order by n
    """)

    expected = {
        "text": ["g1", "g1", "g2", "g3", "g3", "g1"],
        "n": [1, 2, 3, 3, 4, 100],
    }
    assert result.collect().to_pydict() == expected


def test_orderby_compound(df):
    # Multi-key ORDER BY: ties on `n` (the two 3s) are broken by `text`,
    # which for this data preserves the original row order.
    result = daft.sql("""
    SELECT * from df order by n, text
    """)

    expected = {
        "text": ["g1", "g1", "g2", "g3", "g3", "g1"],
        "n": [1, 2, 3, 3, 4, 100],
    }
    assert result.collect().to_pydict() == expected


def test_orderby_desc(df):
    # Descending sort on the projected column only.
    result = daft.sql("""
    SELECT n from df order by n desc
    """)

    expected = {"n": [100, 4, 3, 3, 2, 1]}
    assert result.collect().to_pydict() == expected


def test_orderby_groupby(df):
    # ORDER BY the alias of an aggregate that is part of the projection.
    # Group sizes are g1=3, g3=2, g2=1, so DESC yields g1, g3, g2.
    result = daft.sql("""
    SELECT
    text,
    count(*) as count_star
    from df
    group by text
    order by count_star DESC
    """)

    expected = {
        "text": ["g1", "g3", "g2"],
        "count_star": [3, 2, 1],
    }
    assert result.collect().to_pydict() == expected


def test_orderby_groupby_expr(df):
    # ORDER BY the aggregate expression itself (count(*)) rather than its
    # alias; must resolve to the already-computed `count_star` column.
    result = daft.sql("""
    SELECT
    text,
    count(*) as count_star
    from df
    group by text
    order by count(*) DESC
    """)

    expected = {"text": ["g1", "g3", "g2"], "count_star": [3, 2, 1]}
    assert result.collect().to_pydict() == expected


def test_groupby_orderby_non_final_expr(df):
    # ORDER BY an aggregate (sum(n)) that is NOT in the final projection.
    # Per-group sums are g1=103, g2=3, g3=7, so ASC yields g2, g3, g1.
    result = daft.sql("""
    SELECT
    text,
    count(*) as count_star
    from df
    group by text
    order by sum(n) ASC
    """)

    expected = {
        "text": ["g2", "g3", "g1"],
        "count_star": [1, 2, 3],
    }
    assert result.collect().to_pydict() == expected


def test_groupby_orderby_count_star(df):
    # ORDER BY count(*) when it is not projected; group sizes ASC give
    # g2 (1), g3 (2), g1 (3), with per-group sums 3, 7, 103.
    result = daft.sql("""
    SELECT
    text,
    sum(n) as n
    from df
    group by text
    order by count(*) ASC
    """)

    expected = {"text": ["g2", "g3", "g1"], "n": [3, 7, 103]}
    assert result.collect().to_pydict() == expected
34 changes: 13 additions & 21 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,29 +110,21 @@ def test_sql_global_agg():
daft.sql("SELECT n,max(n) max_n FROM test", catalog=catalog)


def test_sql_groupby_agg():
@pytest.mark.parametrize(
"query,expected",
[
("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", {"sum": [3, 7]}),
("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"n": [1, 2], "sum": [3, 7]}),
("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"max": [2, 4], "sum": [3, 7]}),
("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", {"n_alias": [1, 2], "sum": [3, 7]}),
("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", {"n": [1, 2], "sum": [3, 7]}),
],
)
def test_sql_groupby_agg(query, expected):
df = daft.from_pydict({"n": [1, 1, 2, 2], "v": [1, 2, 3, 4]})
catalog = SQLCatalog({"test": df})
actual = daft.sql("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"sum": [3, 7]}

# test with grouping column
actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]}

# test with multiple columns
actual = daft.sql("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"max": [2, 4], "sum": [3, 7]}

# test with aliased grouping key
actual = daft.sql("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n_alias": [1, 2], "sum": [3, 7]}

actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY -n", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [2, 1], "sum": [7, 3]}

actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]}
actual = daft.sql(query, catalog=catalog)
assert actual.collect().to_pydict() == expected


def test_sql_count_star():
Expand Down

0 comments on commit bdd25b6

Please sign in to comment.