From 674d2f93efd709d9665321206049d2448c4c15eb Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Mon, 21 Oct 2024 12:46:41 -0700
Subject: [PATCH 1/6] add over for mssql

---
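Reviewer note: SQL Server implements percentile_disc as a window function, so
it requires an empty OVER () clause, whereas Postgres and most other dialects
accept the plain ordered-set aggregate form. Because the windowed form repeats
the same value on every row, limit=1 is added so only a single row of bounds is
fetched. A rough sketch of the two projection strings being built (the column
name, percentile, and index below are illustrative placeholders, not the real
loop values):

    # Illustrative placeholders for the real partition column and percentile list.
    partition_col, percentile, i = "id", 0.5, 1

    # Postgres-style ordered-set aggregate: one row per query.
    agg_form = f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {partition_col}) AS bound_{i}"

    # SQL Server window form: same value on every row, hence the limit=1 above.
    window_form = f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {partition_col}) OVER () AS bound_{i}"
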
 daft/sql/sql_scan.py                              |  3 ++-
 tests/integration/sql/conftest.py                 |  1 +
 .../sql/docker-compose/docker-compose.yml         | 12 ++++++++++++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py
index 4f0f9a35c7..aa2da210c4 100644
--- a/daft/sql/sql_scan.py
+++ b/daft/sql/sql_scan.py
@@ -166,9 +166,10 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part
             percentile_sql = self.conn.construct_sql_query(
                 self.sql,
                 projection=[
-                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
+                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {'OVER ()' if self.conn.dialect in ['mssql', 'tsql'] else ''} AS bound_{i}"
                     for i, percentile in enumerate(percentiles)
                 ],
+                limit=1,
             )
             pa_table = self.conn.execute_sql_query(percentile_sql)
             return pa_table, PartitionBoundStrategy.PERCENTILE

diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py
index f5c01dccc6..e202eed471 100644
--- a/tests/integration/sql/conftest.py
+++ b/tests/integration/sql/conftest.py
@@ -26,6 +26,7 @@
     "trino://user@localhost:8080/memory/default",
     "postgresql://username:password@localhost:5432/postgres",
     "mysql+pymysql://username:password@localhost:3306/mysql",
+    "mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes",
 ]
 TEST_TABLE_NAME = "example"
 EMPTY_TEST_TABLE_NAME = "empty_table"

diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml
index 11c391b0d3..b8eb8c3eba 100644
--- a/tests/integration/sql/docker-compose/docker-compose.yml
+++ b/tests/integration/sql/docker-compose/docker-compose.yml
@@ -31,6 +31,18 @@ services:
     volumes:
       - mysql_data:/var/lib/mysql
 
+  azuresqledge:
+    image: mcr.microsoft.com/azure-sql-edge
+    container_name: azuresqledge
+    environment:
+      ACCEPT_EULA: "Y"
+      MSSQL_SA_PASSWORD: "StrongPassword!"
+    ports:
+      - 1433:1433
+    volumes:
+      - azuresqledge_data:/var/opt/mssql
+
 volumes:
   postgres_data:
   mysql_data:
+  azuresqledge_data:

From 1cdb01cb83dfb209897906035fb947b329cef0db Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Mon, 21 Oct 2024 13:26:35 -0700
Subject: [PATCH 2/6] cleanup

---
 daft/sql/sql_scan.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py
index aa2da210c4..27a3b626e0 100644
--- a/daft/sql/sql_scan.py
+++ b/daft/sql/sql_scan.py
@@ -163,10 +163,12 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part
         try:
             # Try to get percentiles using percentile_cont
             percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
+            # Use the OVER clause for SQL Server
+            over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
             percentile_sql = self.conn.construct_sql_query(
                 self.sql,
                 projection=[
-                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {'OVER ()' if self.conn.dialect in ['mssql', 'tsql'] else ''} AS bound_{i}"
+                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
                     for i, percentile in enumerate(percentiles)
                 ],
                 limit=1,

From 88d0ea919b8fc4c45f7e95fc3d304e9023012afd Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Tue, 22 Oct 2024 13:31:30 -0700
Subject: [PATCH 3/6] odbc

---
 .github/workflows/python-package.yml | 6 ++++++
 requirements-dev.txt                 | 1 +
 2 files changed, 7 insertions(+)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 3affeecc4c..0a7d2de10a 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -582,6 +582,12 @@ jobs:
       run: |
         uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall
         rm -rf daft
+    - name: Install ODBC Driver 18 for SQL Server
+      run: |
+        curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
+        sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod
+        sudo apt-get update
+        sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
     - name: Spin up services
       run: |
         pushd ./tests/integration/sql/docker-compose/

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9c7809ac80..3ab91623eb 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -66,6 +66,7 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8'
 PyMySQL==1.1.0; python_version >= '3.8'
 psycopg2-binary==2.9.9; python_version >= '3.8'
 sqlglot==23.3.0; python_version >= '3.8'
+pyodbc==5.1.0; python_version >= '3.8'
 
 # AWS
 s3fs==2023.12.0; python_version >= '3.8'

From 8d53443e4db1cf2b38b8d89b792bbb208805043a Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Tue, 22 Oct 2024 14:12:40 -0700
Subject: [PATCH 4/6] revise if else pushdowns and test

---
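Reviewer note: the IfElse-to-CASE-WHEN translation is removed from the SQL
pushdown path, presumably because a bare CASE expression is not a valid boolean
predicate in every dialect (T-SQL in particular only allows CASE where a value
is expected, not directly as a WHERE condition). if_else predicates still work;
Daft now evaluates them after the rows are read. A minimal sketch mirroring the
updated test (test_db and TEST_TABLE_NAME come from the test fixtures):

    import daft

    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
    # Not pushed down into the generated SQL; applied by Daft post-read.
    df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50))
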
 src/daft-dsl/src/expr/mod.rs      | 16 ++--------------
 tests/integration/sql/test_sql.py |  3 +--
 2 files changed, 3 insertions(+), 16 deletions(-)

diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs
index 873f9013bd..567a2d35d8 100644
--- a/src/daft-dsl/src/expr/mod.rs
+++ b/src/daft-dsl/src/expr/mod.rs
@@ -990,21 +990,9 @@ impl Expr {
                 to_sql_inner(inner, buffer)?;
                 write!(buffer, ") IS NOT NULL")
             }
-            Expr::IfElse {
-                if_true,
-                if_false,
-                predicate,
-            } => {
-                write!(buffer, "CASE WHEN ")?;
-                to_sql_inner(predicate, buffer)?;
-                write!(buffer, " THEN ")?;
-                to_sql_inner(if_true, buffer)?;
-                write!(buffer, " ELSE ")?;
-                to_sql_inner(if_false, buffer)?;
-                write!(buffer, " END")
-            }
             // TODO: Implement SQL translations for these expressions if possible
-            Expr::Agg(..)
+            Expr::IfElse { .. }
+            | Expr::Agg(..)
             | Expr::Cast(..)
             | Expr::IsIn(..)
             | Expr::Between(..)

diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index ff02ebaac4..e668c36d1f 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -134,7 +134,6 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None:
         ("id", 100),
         ("float_col", 100.0123),
         ("string_col", "row_100"),
-        ("bool_col", True),
         ("date_col", datetime.date(2021, 1, 1)),
         ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)),
     ],
@@ -204,7 +203,7 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) -
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2])
-def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None:
+def test_sql_read_with_if_else_filter(test_db, num_partitions, pdf) -> None:
     df = daft.read_sql(
         f"SELECT * FROM {TEST_TABLE_NAME}",
         test_db,

From 579e3593d4a7c6d18101791ba85423d63934043a Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Tue, 22 Oct 2024 14:33:59 -0700
Subject: [PATCH 5/6] bool test

---
 tests/integration/sql/test_sql.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index e668c36d1f..7983be00c7 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -134,12 +134,17 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None:
         ("id", 100),
         ("float_col", 100.0123),
         ("string_col", "row_100"),
+        ("bool_col", True),
         ("date_col", datetime.date(2021, 1, 1)),
         ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)),
     ],
 )
 @pytest.mark.parametrize("num_partitions", [1, 2])
 def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None:
+    # Skip invalid comparisons for bool_col
+    if column == "bool_col" and operator not in ("=", "!="):
+        pytest.skip(f"Operator {operator} not valid for bool_col")
+
     df = daft.read_sql(
         f"SELECT * FROM {TEST_TABLE_NAME}",
         test_db,
@@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) -
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2])
-def test_sql_read_with_if_else_filter(test_db, num_partitions, pdf) -> None:
+def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None:
     df = daft.read_sql(
         f"SELECT * FROM {TEST_TABLE_NAME}",
         test_db,
         partition_col="id",
         num_partitions=num_partitions,
     )
+
+    # If_else is not supported as a pushdown to read_sql, but it should still work
     df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50))
     pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)]
 

From 86e1f62e627fdb41ab38e3750c80196e07d934da Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Tue, 22 Oct 2024 17:00:15 -0700
Subject: [PATCH 6/6] comment

---
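Reviewer note: percentile_cont interpolates, so it can produce a value that
never appears in the column (and is undefined for strings and dates), while
percentile_disc always returns an actual column value, which is what the
<= / >= bound comparisons need. A small worked example of the difference:

    values = [1, 2, 3, 4]

    # percentile_cont(0.5) interpolates between the two middle values:
    cont = (values[1] + values[2]) / 2  # 2.5 -- never appears in the data
    # percentile_disc(0.5) returns the smallest value whose cumulative
    # distribution reaches 0.5:
    disc = values[1]                    # 2 -- an actual row value
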
 daft/sql/sql_scan.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py
index 27a3b626e0..4d3156ae80 100644
--- a/daft/sql/sql_scan.py
+++ b/daft/sql/sql_scan.py
@@ -161,7 +161,8 @@ def _get_num_rows(self) -> int:
 
     def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
         try:
-            # Try to get percentiles using percentile_cont
+            # Try to get percentiles using percentile_disc.
+            # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
             percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
             # Use the OVER clause for SQL Server
             over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""