From 26d639b7adbd6716079690d707734052750d4b9a Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 22 Oct 2024 17:32:52 -0500 Subject: [PATCH] [FEAT]: add sql DISTINCT (#3087) --- src/daft-sql/src/planner.rs | 15 +++++++++++---- tests/sql/test_sql.py | 7 +++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index e651b6528f..55823e5843 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -11,7 +11,7 @@ use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, ExcludeSelectItem, + ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions, }, @@ -202,6 +202,15 @@ impl SQLPlanner { } } + match &selection.distinct { + Some(Distinct::Distinct) => { + let rel = self.relation_mut(); + rel.inner = rel.inner.distinct()?; + } + Some(Distinct::On(_)) => unsupported_sql_err!("DISTINCT ON"), + None => {} + } + if let Some(order_by) = &query.order_by { if order_by.interpolate.is_some() { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); @@ -1186,9 +1195,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult if selection.top.is_some() { unsupported_sql_err!("TOP"); } - if selection.distinct.is_some() { - unsupported_sql_err!("DISTINCT"); - } + if selection.into.is_some() { unsupported_sql_err!("INTO"); } diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 8b8cce43b5..6bcd716854 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -214,3 +214,10 @@ def test_sql_tbl_alias(): catalog = SQLCatalog({"df": daft.from_pydict({"n": [1, 2, 3]})}) df = daft.sql("SELECT df_alias.n FROM df AS df_alias where df_alias.n = 2", catalog) assert df.collect().to_pydict() == {"n": [2]} + + +def test_sql_distinct(): + df = daft.from_pydict({"n": [1, 1, 2, 2]}) + actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict() + expected = df.distinct().collect().to_pydict() + assert actual == expected