Skip to content

Commit

Permalink
[BUG] Add over clause in read_sql percentile reads (#3094)
Browse files Browse the repository at this point in the history
Addresses: #3075

SQL Server requires an `OVER` clause to be specified in percentile
queries (because `PERCENTILE_DISC`/`PERCENTILE_CONT` are window functions
there). `read_sql` uses percentiles to determine partition bounds.

Adds Azure SQL Edge as a test database. This is worthwhile since many
people use us to read from SQL Server, and we have had SQL Server-specific
bugs before. It is somewhat of a pain to set up since it requires ODBC and
drivers, etc., but it works. It is also not much of a hit on CI times:
installing the drivers takes ~15s and the extra tests add ~5s.

Additionally, this makes some modifications to tests and pushdowns;
comments explaining the rationale are left inline.

---------

Co-authored-by: Colin Ho <colinho@Colins-MacBook-Pro.local>
Co-authored-by: Colin Ho <colinho@Colins-MBP.localdomain>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent 26d639b commit 459ba82
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 17 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,12 @@ jobs:
run: |
uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall
rm -rf daft
- name: Install ODBC Driver 18 for SQL Server
run: |
curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
- name: Spin up services
run: |
pushd ./tests/integration/sql/docker-compose/
Expand Down
8 changes: 6 additions & 2 deletions daft/sql/sql_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,18 @@ def _get_num_rows(self) -> int:

def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
# Try to get percentiles using percentile_cont
# Try to get percentiles using percentile_disc.
# Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
# Use the OVER clause for SQL Server
over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
percentile_sql = self.conn.construct_sql_query(
self.sql,
projection=[
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
limit=1,
)
pa_table = self.conn.execute_sql_query(percentile_sql)
return pa_table, PartitionBoundStrategy.PERCENTILE
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8'
PyMySQL==1.1.0; python_version >= '3.8'
psycopg2-binary==2.9.9; python_version >= '3.8'
sqlglot==23.3.0; python_version >= '3.8'
pyodbc==5.1.0; python_version >= '3.8'

# AWS
s3fs==2023.12.0; python_version >= '3.8'
Expand Down
16 changes: 2 additions & 14 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -990,21 +990,9 @@ impl Expr {
to_sql_inner(inner, buffer)?;
write!(buffer, ") IS NOT NULL")
}
Expr::IfElse {
if_true,
if_false,
predicate,
} => {
write!(buffer, "CASE WHEN ")?;
to_sql_inner(predicate, buffer)?;
write!(buffer, " THEN ")?;
to_sql_inner(if_true, buffer)?;
write!(buffer, " ELSE ")?;
to_sql_inner(if_false, buffer)?;
write!(buffer, " END")
}
// TODO: Implement SQL translations for these expressions if possible
Expr::Agg(..)
Expr::IfElse { .. }
| Expr::Agg(..)
| Expr::Cast(..)
| Expr::IsIn(..)
| Expr::Between(..)
Expand Down
1 change: 1 addition & 0 deletions tests/integration/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"trino://user@localhost:8080/memory/default",
"postgresql://username:password@localhost:5432/postgres",
"mysql+pymysql://username:password@localhost:3306/mysql",
"mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes",
]
TEST_TABLE_NAME = "example"
EMPTY_TEST_TABLE_NAME = "empty_table"
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/sql/docker-compose/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ services:
volumes:
- mysql_data:/var/lib/mysql

azuresqledge:
image: mcr.microsoft.com/azure-sql-edge
container_name: azuresqledge
environment:
ACCEPT_EULA: "Y"
MSSQL_SA_PASSWORD: "StrongPassword!"
ports:
- 1433:1433
volumes:
- azuresqledge_data:/var/opt/mssql

volumes:
postgres_data:
mysql_data:
azuresqledge_data:
8 changes: 7 additions & 1 deletion tests/integration/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None:
)
@pytest.mark.parametrize("num_partitions", [1, 2])
def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None:
# Skip invalid comparisons for bool_col
if column == "bool_col" and operator not in ("=", "!="):
pytest.skip(f"Operator {operator} not valid for bool_col")

df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
Expand Down Expand Up @@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) -

@pytest.mark.integration()
@pytest.mark.parametrize("num_partitions", [1, 2])
def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None:
def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None:
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col="id",
num_partitions=num_partitions,
)

# If_else is not supported as a pushdown to read_sql, but it should still work
df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50))

pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)]
Expand Down

0 comments on commit 459ba82

Please sign in to comment.