Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pin of sqlparse, minor refactoring, add tests #7993

Merged
merged 1 commit into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230621-185452.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Remove limitation on use of sqlparse 0.4.4
time: 2023-06-21T18:54:52.246578-04:00
custom:
Author: gshank
Issue: "7515"
134 changes: 72 additions & 62 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def link_graph(self, manifest: Manifest):
self.add_node(source.unique_id)
for semantic_model in manifest.semantic_models.values():
self.add_node(semantic_model.unique_id)

for node in manifest.nodes.values():
self.link_node(node, manifest)
for exposure in manifest.exposures.values():
Expand Down Expand Up @@ -301,66 +300,6 @@ def add_ephemeral_prefix(self, name: str):
relation_cls = adapter.Relation
return relation_cls.add_ephemeral_prefix(name)

def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:
    """
    Splice the SQL of every InjectedCTE in `ctes` into the first statement of `sql`.

    Each entry of `ctes` carries a ready-made CTE definition in its `sql`
    attribute, e.g.:

        InjectedCTE(
            id="cte_id_1",
            sql="__dbt__cte__ephemeral as (select * from table)",
        )

    The definitions are joined with ", " and inserted behind the `with`
    (or `with recursive`) keyword found among the statement's tokens,
    followed by a comma. When no `with` keyword is found, one is inserted
    in front of the first token and the joined definitions go behind it.

    Given the entry above and `sql` like:

        "with internal_cte as (select * from sessions)
         select * from internal_cte"

    the result is:

        "with __dbt__cte__ephemeral as (select * from table),
         internal_cte as (select * from sessions)
         select * from internal_cte"

    (Whitespace enhanced for readability.)

    An empty `ctes` list returns `sql` untouched.
    """
    if len(ctes) == 0:
        return sql

    # Only the first statement produced by sqlparse is rewritten and returned.
    statement = sqlparse.parse(sql)[0]

    # Find the token the injected CTEs must follow: WITH, or RECURSIVE when it
    # comes after WITH. Scanning stops at the first non-whitespace token that
    # follows WITH.
    anchor = None
    for tok in statement.tokens:
        if tok.is_keyword and tok.normalized == "WITH":
            anchor = tok
            continue
        if anchor is None:
            continue
        if tok.is_keyword and tok.normalized == "RECURSIVE":
            anchor = tok
            break
        if not tok.is_whitespace:
            break

    if anchor is not None:
        # A WITH clause already exists: insert the separating comma first, so
        # that the CTEs inserted below end up in front of it.
        separator = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ",")
        statement.insert_after(anchor, separator)
    else:
        # No WITH clause yet: create one in front of the first token.
        anchor = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
        statement.insert_before(statement.token_first(), anchor)

    # All injected CTEs travel as one Keyword-typed token.
    joined_ctes = ", ".join(c.sql for c in ctes)
    statement.insert_after(anchor, sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes))

    return str(statement)

def _recursively_prepend_ctes(
self,
model: ManifestSQLNode,
Expand Down Expand Up @@ -435,7 +374,7 @@ def _recursively_prepend_ctes(

_add_prepended_cte(prepended_ctes, InjectedCTE(id=cte.id, sql=sql))

injected_sql = self._inject_ctes_into_sql(
injected_sql = inject_ctes_into_sql(
model.compiled_code,
prepended_ctes,
)
Expand Down Expand Up @@ -586,3 +525,74 @@ def compile_node(
if write:
self._write_node(node)
return node


def inject_ctes_into_sql(sql: str, ctes: List[InjectedCTE]) -> str:
    """
    Splice the SQL of every InjectedCTE in `ctes` into the first statement of `sql`.

    `ctes` looks like:

        [
            InjectedCTE(
                id="cte_id_1",
                sql="__dbt__cte__ephemeral as (select * from table)",
            ),
            InjectedCTE(
                id="cte_id_2",
                sql="__dbt__cte__events as (select id, type from events)",
            ),
        ]

    With `sql` such as:

        "with internal_cte as (select * from sessions)
         select * from internal_cte"

    the returned string is:

        "with __dbt__cte__ephemeral as (select * from table),
         __dbt__cte__events as (select id, type from events),
         internal_cte as (select * from sessions)
         select * from internal_cte"

    (Whitespace enhanced for readability.)

    When `sql` has no `with` keyword, one is inserted in front of the first
    token, followed by the joined CTEs and a single space. An empty `ctes`
    list returns `sql` untouched.
    """
    if len(ctes) == 0:
        return sql

    # Only the first statement produced by sqlparse is rewritten and returned.
    statement = sqlparse.parse(sql)[0]

    # Find the token the injected CTEs must follow: WITH, or RECURSIVE when it
    # comes after WITH. Scanning stops at the first non-whitespace token that
    # follows WITH.
    anchor = None
    for tok in statement.tokens:
        if tok.is_keyword and tok.normalized == "WITH":
            anchor = tok
            continue
        if anchor is None:
            continue
        if tok.is_keyword and tok.normalized == "RECURSIVE":
            anchor = tok
            break
        if not tok.is_whitespace:
            break

    joined_ctes = ", ".join(c.sql for c in ctes)

    if anchor is not None:
        # [with][original_sql]
        cte_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes)
        statement.insert_after(anchor, cte_token)
        # [with][joined_ctes][original_sql]
        # The original CTEs follow, so a comma has to separate them from ours.
        separator = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ", ")
        statement.insert_after(cte_token, separator)
        # [with][joined_ctes][, ][original_sql]
    else:
        # [original_sql]
        with_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, "with")
        statement.insert_before(statement.token_first(), with_token)
        # [with][original_sql]
        # The trailing space keeps the last CTE apart from the original SQL.
        cte_token = sqlparse.sql.Token(sqlparse.tokens.Keyword, joined_ctes + " ")
        statement.insert_after(with_token, cte_token)
        # [with][joined_ctes][original_sql]

    return str(statement)
5 changes: 2 additions & 3 deletions core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@
"pathspec>=0.9,<0.12",
"isodate>=0.6,<0.7",
# ----
# There is a difficult-to-reproduce bug in sqlparse==0.4.4 for ephemeral model compilation
# For context: dbt-core#7396 + dbt-core#7515
"sqlparse>=0.2.3,<0.4.4",
# There was a pin to below 0.4.4 for a while due to a bug in Ubuntu/sqlparse 0.4.4
"sqlparse>=0.2.3",
# ----
# These are major-version-0 packages also maintained by dbt-labs. Accept patches.
"dbt-extractor~=0.4.1",
Expand Down
44 changes: 25 additions & 19 deletions tests/functional/compile/test_compile.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import pathlib
import pytest
import re

from dbt.cli.main import dbtRunner
from dbt.exceptions import DbtRuntimeError, TargetNotFoundError
from dbt.tests.util import run_dbt, run_dbt_and_capture
from dbt.tests.util import run_dbt, run_dbt_and_capture, read_file
from tests.functional.compile.fixtures import (
first_model_sql,
second_model_sql,
Expand All @@ -17,9 +18,13 @@
)


def get_lines(model_name):
from dbt.tests.util import read_file
# Matches any run of whitespace; compiled once at import time rather than on
# every call.
_RE_COMBINE_WHITESPACE = re.compile(r"\s+")


def norm_whitespace(string):
    """Collapse each run of whitespace in `string` to one space and strip both ends.

    Used to compare SQL text without being sensitive to indentation or
    line breaks.
    """
    return _RE_COMBINE_WHITESPACE.sub(" ", string).strip()


def get_lines(model_name):
    """Return the non-empty lines of the compiled SQL file for `model_name`.

    The content is obtained from `read_file`, called with the path components
    "target", "compiled", "test", "models" and `model_name` + ".sql".
    """
    compiled_sql = read_file("target", "compiled", "test", "models", model_name + ".sql")
    # Empty strings are falsy, so filter(None, ...) drops blank lines.
    return list(filter(None, compiled_sql.splitlines()))

Expand Down Expand Up @@ -90,29 +95,30 @@ def test_last_selector(self, project):
def test_no_selector(self, project):
run_dbt(["compile"])

assert get_lines("first_ephemeral_model") == ["select 1 as fun"]
assert get_lines("second_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
")select * from __dbt__cte__first_ephemeral_model",
]
assert get_lines("third_ephemeral_model") == [
"with __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
"), __dbt__cte__second_ephemeral_model as (",
"select * from __dbt__cte__first_ephemeral_model",
")select * from __dbt__cte__second_ephemeral_model",
"union all",
"select 2 as fun",
]
sql = read_file("target", "compiled", "test", "models", "first_ephemeral_model.sql")
assert norm_whitespace(sql) == norm_whitespace("select 1 as fun")
sql = read_file("target", "compiled", "test", "models", "second_ephemeral_model.sql")
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
) select * from __dbt__cte__first_ephemeral_model"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)
sql = read_file("target", "compiled", "test", "models", "third_ephemeral_model.sql")
expected_sql = """with __dbt__cte__first_ephemeral_model as (
select 1 as fun
), __dbt__cte__second_ephemeral_model as (
select * from __dbt__cte__first_ephemeral_model
) select * from __dbt__cte__second_ephemeral_model
union all
select 2 as fun"""
assert norm_whitespace(sql) == norm_whitespace(expected_sql)

def test_with_recursive_cte(self, project):
run_dbt(["compile"])

assert get_lines("with_recursive_model") == [
"with recursive __dbt__cte__first_ephemeral_model as (",
"select 1 as fun",
"),t(n) as (",
"), t(n) as (",
" select * from __dbt__cte__first_ephemeral_model",
" union all",
" select n+1 from t where n < 100",
Expand Down
Loading