Skip to content

Commit

Permalink
fix: tsql union with limit
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Oct 4, 2024
1 parent d15efa2 commit 484df7d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
20 changes: 13 additions & 7 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,16 @@ def generate(self, expression: exp.Expression, copy: bool = True) -> str:

def preprocess(self, expression: exp.Expression) -> exp.Expression:
"""Apply generic preprocessing transformations to a given expression."""
expression = self._move_ctes_to_top_level(expression)

if self.ENSURE_BOOLS:
from sqlglot.transforms import ensure_bools

expression = ensure_bools(expression)

return expression

def _move_ctes_to_top_level(self, expression: E) -> E:
if (
not expression.parent
and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES
Expand All @@ -725,12 +735,6 @@ def preprocess(self, expression: exp.Expression) -> exp.Expression:
from sqlglot.transforms import move_ctes_to_top_level

expression = move_ctes_to_top_level(expression)

if self.ENSURE_BOOLS:
from sqlglot.transforms import ensure_bools

expression = ensure_bools(expression)

return expression

def unsupported(self, message: str) -> None:
Expand Down Expand Up @@ -1377,7 +1381,9 @@ def set_operations(self, expression: exp.SetOperation) -> str:
order = expression.args.get("order")

if limit or order:
select = exp.subquery(expression, "_l_0", copy=False).select("*", copy=False)
select = self._move_ctes_to_top_level(
exp.subquery(expression, "_l_0", copy=False).select("*", copy=False)
)

if limit:
select = select.limit(limit.pop(), copy=False)
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.generator import Generator


Expand Down Expand Up @@ -649,7 +650,7 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
def move_ctes_to_top_level(expression: E) -> E:
"""
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
defined at the top-level, so for example queries like:
Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ class TestTSQL(Validator):
dialect = "tsql"

def test_tsql(self):
self.validate_identity(
"with x as (select 1) select * from x union select * from x order by 1 limit 0",
"WITH x AS (SELECT 1 AS [1]) SELECT TOP 0 * FROM (SELECT * FROM x UNION SELECT * FROM x) AS _l_0 ORDER BY 1",
)

# https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
# tsql allows .. which means use the default schema
self.validate_identity("SELECT * FROM a..b")
Expand Down

0 comments on commit 484df7d

Please sign in to comment.