Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respect column quote config in model contracts #7537

Merged
merged 7 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230506-191813.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Respect column 'quote' config in model contracts
time: 2023-05-06T19:18:13.351819+02:00
custom:
Author: jtcohen6
Issue: "7370"
3 changes: 2 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,8 @@ def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]])
rendered_column_constraints = []

for v in raw_columns.values():
rendered_column_constraint = [f"{v['name']} {v['data_type']}"]
col_name = cls.quote(v["name"]) if v.get("quote") else v["name"]
rendered_column_constraint = [f"{col_name} {v['data_type']}"]
for con in v.get("constraints", None):
constraint = cls._parse_column_constraint(con)
c = cls.process_parsed_constraint(constraint, cls.render_column_constraint)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def list_relations_without_caching(
)
return relations

@classmethod
def quote(self, identifier):
return '"{}"'.format(identifier)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
{%- if col['data_type'] is not defined -%}
{{ col_err.append(col['name']) }}
{%- endif -%}
cast(null as {{ col['data_type'] }}) as {{ col['name'] }}{{ ", " if not loop.last }}
{% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %}
cast(null as {{ col['data_type'] }}) as {{ col_name }}{{ ", " if not loop.last }}
{%- endfor -%}
{%- if (col_err | length) > 0 -%}
{{ exceptions.column_type_missing(column_names=col_err) }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,24 @@
);
{%- endmacro %}


{% macro default__get_column_names() %}
{#- loop through user_provided_columns to get column names -#}
{%- set user_provided_columns = model['columns'] -%}
{%- for i in user_provided_columns %}
{%- set col = user_provided_columns[i] -%}
{%- set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] -%}
{{ col_name }}{{ ", " if not loop.last }}
{%- endfor -%}
{% endmacro %}


{% macro get_select_subquery(sql) %}
{{ return(adapter.dispatch('get_select_subquery', 'dbt')(sql)) }}
{% endmacro %}

{% macro default__get_select_subquery(sql) %}
select
{% for column in model['columns'] %}
{{ column }}{{ ", " if not loop.last }}
{% endfor %}
select {{ adapter.dispatch('get_column_names', 'dbt')() }}
from (
{{ sql }}
) as model_subq
Expand Down
4 changes: 3 additions & 1 deletion plugins/postgres/dbt/include/postgres/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
{% if contract_config.enforced %}
{{ get_assert_columns_equivalent(sql) }}
{{ get_table_columns_and_constraints() }} ;
insert into {{ relation }} {{ get_column_names() }}
insert into {{ relation }} (
{{ adapter.dispatch('get_column_names', 'dbt')() }}
)
{%- set sql = get_select_subquery(sql) %}
{% else %}
as
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
{% macro get_column_names() %}
{# loop through user_provided_columns to get column names #}
{%- set user_provided_columns = model['columns'] -%}
(
{% for i in user_provided_columns %}
{% set col = user_provided_columns[i] %}
{{ col['name'] }} {{ "," if not loop.last }}
{% endfor %}
)
{% endmacro %}
44 changes: 41 additions & 3 deletions tests/adapter/dbt/tests/adapter/constraints/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@
'2019-01-01' as date_day
"""


# 'from' is a reserved word, so it must be quoted
my_model_with_quoted_column_name_sql = """
select
'blue' as {{ adapter.quote('from') }},
1 as id,
'2019-01-01' as date_day
"""


model_schema_yml = """
version: 2
models:
Expand All @@ -260,7 +270,6 @@
enforced: true
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -344,7 +353,6 @@
enforced: true
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -454,7 +462,6 @@
expression: {schema}.foreign_key_model (id)
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -490,6 +497,37 @@
data_type: {data_type}
"""


model_quoted_column_schema_yml = """
version: 2
models:
- name: my_model
config:
contract:
enforced: true
materialized: table
constraints:
- type: check
# this one is the on the user
expression: ("from" = 'blue')
columns: [ '"from"' ]
columns:
- name: id
data_type: integer
description: hello
constraints:
- type: not_null
tests:
- unique
- name: from # reserved word
quote: true
data_type: text
constraints:
- type: not_null
- name: date_day
data_type: text
"""

model_contract_header_schema_yml = """
version: 2
models:
Expand Down
61 changes: 47 additions & 14 deletions tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
my_model_incremental_wrong_name_sql,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you haven't already - let's open PRs in adapter repos to ensure existing tests pass + we inherit the new ones.

my_model_with_nulls_sql,
my_model_incremental_with_nulls_sql,
my_model_with_quoted_column_name_sql,
model_schema_yml,
model_fk_constraint_schema_yml,
constrained_model_schema_yml,
model_quoted_column_schema_yml,
foreign_key_model_sql,
my_model_wrong_order_depends_on_fk_sql,
my_model_incremental_wrong_order_depends_on_fk_sql,
Expand Down Expand Up @@ -165,11 +167,12 @@ def test__constraints_correct_column_data_types(self, project, data_types):


def _normalize_whitespace(input: str) -> str:
return re.sub(r"\s+", " ", input).lower().strip()
subbed = re.sub(r"\s+", " ", input)
return re.sub(r"\s?([\(\),])\s?", r"\1", subbed).lower().strip()


def _find_and_replace(sql, find, replace):
sql_tokens = sql.split(" ")
sql_tokens = sql.split()
for idx in [n for n, x in enumerate(sql_tokens) if find in x]:
sql_tokens[idx] = replace
return " ".join(sql_tokens)
Expand Down Expand Up @@ -235,17 +238,12 @@ def test__constraints_ddl(self, project, expected_sql):
# the name is not what we're testing here anyways and varies based on materialization
# TODO: consider refactoring this to introspect logs instead
generated_sql = read_file("target", "run", "test", "models", "my_model.sql")
generated_sql_modified = _normalize_whitespace(generated_sql)
generated_sql_generic = _find_and_replace(
generated_sql_modified, "my_model", "<model_identifier>"
)
generated_sql_generic = _find_and_replace(generated_sql, "my_model", "<model_identifier>")
generated_sql_generic = _find_and_replace(
generated_sql_generic, "foreign_key_model", "<foreign_key_model_identifier>"
)

expected_sql_check = _normalize_whitespace(expected_sql)

assert expected_sql_check == generated_sql_generic
assert _normalize_whitespace(expected_sql) == _normalize_whitespace(generated_sql_generic)


class BaseConstraintsRollback:
Expand Down Expand Up @@ -485,15 +483,50 @@ def test__model_constraints_ddl(self, project, expected_sql):
# assert at least my_model was run - additional upstreams may or may not be provided to the test setup via models fixture
assert len(results) >= 1
generated_sql = read_file("target", "run", "test", "models", "my_model.sql")
generated_sql_modified = _normalize_whitespace(generated_sql)
generated_sql_generic = _find_and_replace(
generated_sql_modified, "my_model", "<model_identifier>"
)

generated_sql_generic = _find_and_replace(generated_sql, "my_model", "<model_identifier>")
generated_sql_generic = _find_and_replace(
generated_sql_generic, "foreign_key_model", "<foreign_key_model_identifier>"
)
assert _normalize_whitespace(expected_sql) == generated_sql_generic

assert _normalize_whitespace(expected_sql) == _normalize_whitespace(generated_sql_generic)


class TestModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement):
pass


class BaseConstraintQuotedColumn(BaseConstraintsRuntimeDdlEnforcement):
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_with_quoted_column_name_sql,
"constraints_schema.yml": model_quoted_column_schema_yml,
}

@pytest.fixture(scope="class")
def expected_sql(self):
return """
create table <model_identifier> (
id integer not null,
"from" text not null,
date_day text,
check (("from" = 'blue'))
) ;
insert into <model_identifier> (
id, "from", date_day
)
(
select id, "from", date_day
from (
select
'blue' as "from",
1 as id,
'2019-01-01' as date_day
) as model_subq
);
"""


class TestConstraintQuotedColumn(BaseConstraintQuotedColumn):
pass