[BUG] Add over clause in read_sql percentile reads #3094

Merged 6 commits on Oct 23, 2024
6 changes: 6 additions & 0 deletions .github/workflows/python-package.yml
@@ -582,6 +582,12 @@ jobs:
run: |
uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall
rm -rf daft
- name: Install ODBC Driver 18 for SQL Server
run: |
curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -
sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18
- name: Spin up services
run: |
pushd ./tests/integration/sql/docker-compose/
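As a side note, a quick way to confirm that the driver installed by the step above is visible to pyodbc (a standalone sketch, not part of the workflow; assumes pyodbc is already installed):

```python
# Sketch: check that ODBC Driver 18 for SQL Server is registered with the ODBC manager.
import pyodbc

drivers = pyodbc.drivers()  # list of installed ODBC driver names
print(drivers)
assert "ODBC Driver 18 for SQL Server" in drivers, "driver not found"
```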
8 changes: 6 additions & 2 deletions daft/sql/sql_scan.py
@@ -161,14 +161,18 @@ def _get_num_rows(self) -> int:

def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
# Try to get percentiles using percentile_cont
# Try to get percentiles using percentile_disc.
# Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
# Use the OVER clause for SQL Server
over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
percentile_sql = self.conn.construct_sql_query(
self.sql,
projection=[
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
limit=1,
)
pa_table = self.conn.execute_sql_query(percentile_sql)
return pa_table, PartitionBoundStrategy.PERCENTILE
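To make the change above concrete, here is a standalone sketch (not Daft's actual query builder) of the projection it generates. The OVER () clause is added only for the mssql/tsql dialects, where T-SQL treats PERCENTILE_DISC as a window function and requires an OVER clause, while other engines accept it as an ordered-set aggregate:

```python
# Sketch of the generated percentile projection, assuming a partition column named "id".
num_scan_tasks = 4
partition_col = "id"
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]

for dialect in ("postgres", "mssql"):
    over_clause = "OVER ()" if dialect in ("mssql", "tsql") else ""
    projection = [
        f"percentile_disc({p}) WITHIN GROUP (ORDER BY {partition_col}) {over_clause} AS bound_{i}"
        for i, p in enumerate(percentiles)
    ]
    # postgres: percentile_disc(0.0) WITHIN GROUP (ORDER BY id)  AS bound_0
    # mssql:    percentile_disc(0.0) WITHIN GROUP (ORDER BY id) OVER () AS bound_0
    print(dialect, projection[0])
```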
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -66,6 +66,7 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8'
PyMySQL==1.1.0; python_version >= '3.8'
psycopg2-binary==2.9.9; python_version >= '3.8'
sqlglot==23.3.0; python_version >= '3.8'
pyodbc==5.1.0; python_version >= '3.8'

# AWS
s3fs==2023.12.0; python_version >= '3.8'
16 changes: 2 additions & 14 deletions src/daft-dsl/src/expr/mod.rs
@@ -990,21 +990,9 @@ impl Expr {
to_sql_inner(inner, buffer)?;
write!(buffer, ") IS NOT NULL")
}
Expr::IfElse {
Comment from the PR author:
SQL Server only accepts values in CASE WHEN ... THEN clauses, so it was failing on the if_else pushdowns test, because we test with conditional expressions inside the if_else branches.

I think it's better to just remove it; the benefit of having if_else pushdowns in read_sql is probably small, and I'd rather just avoid this being a bug in the future.

if_true,
if_false,
predicate,
} => {
write!(buffer, "CASE WHEN ")?;
to_sql_inner(predicate, buffer)?;
write!(buffer, " THEN ")?;
to_sql_inner(if_true, buffer)?;
write!(buffer, " ELSE ")?;
to_sql_inner(if_false, buffer)?;
write!(buffer, " END")
}
// TODO: Implement SQL translations for these expressions if possible
Expr::Agg(..)
Expr::IfElse { .. }
| Expr::Agg(..)
| Expr::Cast(..)
| Expr::IsIn(..)
| Expr::Between(..)
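As a rough illustration of the comment above (column names taken from the integration tests, not real Daft output): the removed translation rewrote if_else into a CASE expression, and when the branches are themselves comparisons the resulting SQL is rejected by SQL Server, which has no boolean value type outside of predicates:

```python
# Hypothetical SQL that the removed CASE WHEN translation would have produced for
# df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)).
predicate, if_true, if_false = "id > 100", "float_col > 150", "float_col < 50"
case_expr = f"CASE WHEN {predicate} THEN {if_true} ELSE {if_false} END"
print(case_expr)
# CASE WHEN id > 100 THEN float_col > 150 ELSE float_col < 50 END
# Engines with a first-class boolean type (e.g. Postgres) accept this; SQL Server
# raises a syntax error because THEN/ELSE must produce values, not predicates.
```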
1 change: 1 addition & 0 deletions tests/integration/sql/conftest.py
@@ -26,6 +26,7 @@
"trino://user@localhost:8080/memory/default",
"postgresql://username:password@localhost:5432/postgres",
"mysql+pymysql://username:password@localhost:3306/mysql",
"mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes",
]
TEST_TABLE_NAME = "example"
EMPTY_TEST_TABLE_NAME = "empty_table"
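For reference, a minimal read against the new SQL Server connection (a sketch, assuming the Azure SQL Edge container defined below is running and the example table has been seeded by the test fixtures):

```python
import daft

# Connection string matching the entry added to conftest.py above.
uri = (
    "mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master"
    "?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
)

df = daft.read_sql(
    "SELECT * FROM example",  # TEST_TABLE_NAME in the integration tests
    uri,
    partition_col="id",       # exercises the percentile-based bounds read
    num_partitions=2,
)
print(df.collect())
```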
12 changes: 12 additions & 0 deletions tests/integration/sql/docker-compose/docker-compose.yml
@@ -31,6 +31,18 @@ services:
volumes:
- mysql_data:/var/lib/mysql

azuresqledge:
image: mcr.microsoft.com/azure-sql-edge
container_name: azuresqledge
environment:
ACCEPT_EULA: "Y"
MSSQL_SA_PASSWORD: "StrongPassword!"
Reviewer comment: lol

ports:
- 1433:1433
volumes:
- azuresqledge_data:/var/opt/mssql

volumes:
postgres_data:
mysql_data:
azuresqledge_data:
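A quick smoke test for the new service (a sketch, assuming the container above is up and pyodbc plus ODBC Driver 18 are installed on the host):

```python
import pyodbc

# Connect directly through the ODBC driver, mirroring the compose settings above.
conn = pyodbc.connect(
    "DRIVER={ODBC Driver 18 for SQL Server};"
    "SERVER=127.0.0.1,1433;DATABASE=master;"
    "UID=SA;PWD=StrongPassword!;TrustServerCertificate=yes"
)
print(conn.execute("SELECT @@VERSION").fetchone()[0])
conn.close()
```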
8 changes: 7 additions & 1 deletion tests/integration/sql/test_sql.py
@@ -141,6 +141,10 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None:
)
@pytest.mark.parametrize("num_partitions", [1, 2])
def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None:
# Skip invalid comparisons for bool_col
if column == "bool_col" and operator not in ("=", "!="):
pytest.skip(f"Operator {operator} not valid for bool_col")
Comment from the PR author on lines +144 to +146:
SQL Server only accepts equality-based comparisons on bools, e.g. bool_col = True or bool_col != True.

In fact, we shouldn't even need to test comparisons on bools other than = and != anyway.

Reviewer comment:
Nit for the test file, not necessarily for this PR. But not-equal is typically <> in SQL, so maybe we should test that too.


df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
@@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) -> None:

@pytest.mark.integration()
@pytest.mark.parametrize("num_partitions", [1, 2])
def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None:
def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None:
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col="id",
num_partitions=num_partitions,
)

# If_else is not supported as a pushdown to read_sql, but it should still work
df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50))

pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)]