Skip to content

Commit

Permalink
Remove pin of sqlparse, minor refactoring, add tests (#7993)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Jun 29, 2023
1 parent 8c201e8 commit 5d93780
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 84 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230621-185452.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Remove limitation on use of sqlparse 0.4.4
time: 2023-06-21T18:54:52.246578-04:00
custom:
  Author: gshank
  Issue: "7515"
134 changes: 72 additions & 62 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def link_graph(self, manifest: Manifest):
self.add_node(source.unique_id)
for semantic_model in manifest.semantic_models.values():
self.add_node(semantic_model.unique_id)

for node in manifest.nodes.values():
self.link_node(node, manifest)
for exposure in manifest.exposures.values():
Expand Down Expand Up @@ -301,66 +300,6 @@ def add_ephemeral_prefix(self, name: str):
relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name)

def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:
"""
`ctes` is a list of InjectedCTEs like:
[
InjectedCTE(
id="cte_id_1",
sql="__dbt__cte__ephemeral as (select * from table)",
),
InjectedCTE(
id="cte_id_2",
sql="__dbt__cte__events as (select id, type from events)",
),
]
Given `sql` like:
"with internal_cte as (select * from sessions)
select * from internal_cte"
This will spit out:
"with __dbt__cte__ephemeral as (select * from table),
__dbt__cte__events as (select id, type from events),
with internal_cte as (select * from sessions)
select * from internal_cte"
(Whitespace enhanced for readability.)
"""
if len(ctes) == 0:
return sql

parsed_stmts = sqlparse.parse(sql)
parsed = parsed_stmts[0]

with_stmt = None
for token in parsed.tokens:
if token.is_keyword and token.normalized == "WITH":
with_stmt = token
elif token.is_keyword and token.normalized == "RECURSIVE" and with_stmt is not None:
with_stmt = token
break
elif not token.is_whitespace and with_stmt is not None:
break

if with_stmt is None:
# no with stmt, add one, and inject CTEs right at the beginning
first_token = parsed.token_first()
with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
parsed.insert_before(first_token, with_stmt)
else:
# stmt exists, add a comma (which will come after injected CTEs)
trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
parsed.insert_after(with_stmt, trailing_comma)

token = sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(c.sql for c in ctes))
parsed.insert_after(with_stmt, token)

return str(parsed)

def _recursively_prepend_ctes(
self,
model: ManifestSQLNode,
Expand Down Expand Up @@ -435,7 +374,7 @@ def _recursively_prepend_ctes(

_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

injected_sql = self._inject_ctes_into_sql(
injected_sql = inject_ctes_into_sql(
model.compiled_code,
prepended_ctes,
)
Expand Down Expand Up @@ -586,3 +525,74 @@ def compile_node(
if write:
self._write_node(node)
return node


def inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str:
    """Prepend the given CTEs to the first statement of `sql`.

    `ctes` is a list of InjectedCTEs like:

        [
            InjectedCTE(
                id="cte_id_1",
                sql="__dbt__cte__ephemeral as (select * from table)",
            ),
            InjectedCTE(
                id="cte_id_2",
                sql="__dbt__cte__events as (select id, type from events)",
            ),
        ]

    Given `sql` like:

        "with internal_cte as (select * from sessions)
        select * from internal_cte"

    This will spit out:

        "with __dbt__cte__ephemeral as (select * from table),
        __dbt__cte__events as (select id, type from events),
        internal_cte as (select * from sessions)
        select * from internal_cte"

    (Whitespace enhanced for readability.)

    :param sql: The SQL text to inject into. Only the first statement
        returned by ``sqlparse.parse`` is modified and returned.
    :param ctes: The CTEs to inject; their ``sql`` attributes are joined
        with ``", "``. When empty, `sql` is returned unchanged.
    :return: The string form of the modified first statement.

    NOTE(review): no separator token is created between the ``with``
    keyword and the injected text, so each ``c.sql`` presumably carries its
    own leading whitespace -- verify against the caller.
    """
    if not ctes:
        return sql

    parsed_stmts = sqlparse.parse(sql)
    parsed = parsed_stmts[0]

    # Find the token the injected CTEs must follow: the top-level WITH
    # keyword, or the RECURSIVE keyword directly after it.
    with_stmt = None
    for token in parsed.tokens:
        if token.is_keyword and token.normalized == "WITH":
            with_stmt = token
        elif token.is_keyword and token.normalized == "RECURSIVE" and with_stmt is not None:
            with_stmt = token
            break
        elif not token.is_whitespace and with_stmt is not None:
            # First real token after WITH is not RECURSIVE: stop scanning.
            break

    # Both branches inject the same comma-joined CTE text.
    joined_ctes = ", ".join(c.sql for c in ctes)

    if with_stmt is None:
        # no with stmt, add one, and inject CTEs right at the beginning
        # [original_sql]
        first_token = parsed.token_first()
        with_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
        parsed.insert_before(first_token, with_token)
        # [with][original_sql]
        # The trailing space separates the last CTE from the original sql.
        injected_ctes_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes + " ")
        parsed.insert_after(with_token, injected_ctes_token)
        # [with][joined_ctes][original_sql]
    else:
        # with stmt exists so we don't need to add one, but we do need to add a comma
        # between the injected ctes and the original sql
        # [with][original_sql]
        injected_ctes_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes)
        parsed.insert_after(with_stmt, injected_ctes_token)
        # [with][joined_ctes][original_sql]
        comma_token = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ", ")
        parsed.insert_after(injected_ctes_token, comma_token)
        # [with][joined_ctes][, ][original_sql]

    return str(parsed)
5 changes: 2 additions & 3 deletions core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@
"pathspec>=0.9,<0.12",
"isodate>=0.6,<0.7",
# ----
# There is a difficult-to-reproduce bug in sqlparse==0.4.4 for ephemeral model compilation
# For context: dbt-core#7396 + dbt-core#7515
"sqlparse>=0.2.3,<0.4.4",
# There was a pin to below 0.4.4 for a while due to a bug in Ubuntu/sqlparse 0.4.4
"sqlparse>=0.2.3",
# ----
# These are major-version-0 packages also maintained by dbt-labs. Accept patches.
"dbt-extractor~=0.4.1",
Expand Down
44 changes: 25 additions & 19 deletions tests/functional/compile/test_compile.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import pathlib
import pytest
import re

from dbt.cli.main import dbtRunner
from dbt.exceptions import DbtRuntimeError, TargetNotFoundError
from dbt.tests.util import run_dbt, run_dbt_and_capture
from dbt.tests.util import run_dbt, run_dbt_and_capture, read_file
from tests.functional.compile.fixtures import (
first_model_sql,
second_model_sql,
Expand All @@ -17,9 +18,13 @@
)


def get_lines(model_name):
from dbt.tests.util import read_file
# Compiled once at import time; matches any run of one or more whitespace characters.
_RE_COMBINE_WHITESPACE = re.compile(r"\s+")


def norm_whitespace(string):
    """Collapse each whitespace run in `string` to one space and trim both ends.

    Lets tests compare SQL text without depending on exact spacing or
    line breaks.
    """
    return _RE_COMBINE_WHITESPACE.sub(" ", string).strip()


def get_lines(model_name):
    """Return the non-empty lines of the compiled SQL for `model_name`.

    The file contents come from `read_file`, called with the path components
    "target", "compiled", "test", "models" and `model_name` + ".sql".
    """
    compiled_sql = read_file("target", "compiled", "test", "models", model_name + ".sql")
    # filter(None, ...) drops empty strings, i.e. blank lines.
    return list(filter(None, compiled_sql.splitlines()))

Expand Down Expand Up @@ -90,29 +95,30 @@ def test_last_selector(self, project):
def test_no_selector(self, project):
# Compiles the project with no selector, then checks the compiled SQL
# written for the first, second and third ephemeral models.
run_dbt(["compile"])

# NOTE(review): this span was taken from a diff view whose +/- markers are not
# visible. The get_lines-based assertions below appear to be the replaced
# version and the read_file/norm_whitespace assertions further down the new
# one -- confirm against the real file.
assert get_lines("first_ephemeral_model") == ["select 1 as fun"]
assert get_lines("second_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
")select * from __dbt__cte__first_ephemeral_model",
]
assert get_lines("third_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
"), __dbt__cte__second_ephemeral_model as (",
"select * from __dbt__cte__first_ephemeral_model",
")select * from __dbt__cte__second_ephemeral_model",
"union all",
"select 2 as fun",
]
# Whitespace-insensitive checks: both the compiled SQL and the expected SQL
# are passed through norm_whitespace before being compared.
sql = read_file("target", "compiled", "test", "models", "first_ephemeral_model.sql")
assert norm_whitespace(sql) == norm_whitespace("select 1 as fun")
sql = read_file("target", "compiled", "test", "models", "second_ephemeral_model.sql")
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
) select * from __dbt__cte__first_ephemeral_model"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)
sql = read_file("target", "compiled", "test", "models", "third_ephemeral_model.sql")
# The third model's expected SQL carries both injected CTEs (first, then
# second) ahead of its own "union all" query.
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
), __dbt__cte__second_ephemeral_model as (
select * from __dbt__cte__first_ephemeral_model
) select * from __dbt__cte__second_ephemeral_model
union all
select 2 as fun"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)

def test_with_recursive_cte(self, project):
run_dbt(["compile"])

assert get_lines("with_recursive_model") == [
"with recursive __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
"),t(n) as (",
"), t(n) as (",
" select * from __dbt__cte__first_ephemeral_model",
" union all",
" select n+1 from t where n < 100",
Expand Down
Loading

0 comments on commit 5d93780

Please sign in to comment.