Skip to content

Commit

Permalink
Respect column quote config in model contracts (#7537)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtcohen6 authored Jun 13, 2023
1 parent d46e885 commit 83d163a
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 34 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230506-191813.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Respect column 'quote' config in model contracts
time: 2023-05-06T19:18:13.351819+02:00
custom:
Author: jtcohen6
Issue: "7370"
3 changes: 2 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,8 @@ def render_raw_columns_constraints(cls, raw_columns: Dict[str, Dict[str, Any]])
rendered_column_constraints = []

for v in raw_columns.values():
rendered_column_constraint = [f"{v['name']} {v['data_type']}"]
col_name = cls.quote(v["name"]) if v.get("quote") else v["name"]
rendered_column_constraint = [f"{col_name} {v['data_type']}"]
for con in v.get("constraints", None):
constraint = cls._parse_column_constraint(con)
c = cls.process_parsed_constraint(constraint, cls.render_column_constraint)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def list_relations_without_caching(
)
return relations

@classmethod
def quote(self, identifier):
return '"{}"'.format(identifier)

Expand Down
3 changes: 2 additions & 1 deletion core/dbt/include/global_project/macros/adapters/columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
{%- if col['data_type'] is not defined -%}
{{ col_err.append(col['name']) }}
{%- endif -%}
cast(null as {{ col['data_type'] }}) as {{ col['name'] }}{{ ", " if not loop.last }}
{% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %}
cast(null as {{ col['data_type'] }}) as {{ col_name }}{{ ", " if not loop.last }}
{%- endfor -%}
{%- if (col_err | length) > 0 -%}
{{ exceptions.column_type_missing(column_names=col_err) }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,24 @@
);
{%- endmacro %}


{% macro default__get_column_names() %}
{#- loop through user_provided_columns to get column names -#}
{%- set user_provided_columns = model['columns'] -%}
{%- for i in user_provided_columns %}
{%- set col = user_provided_columns[i] -%}
{%- set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] -%}
{{ col_name }}{{ ", " if not loop.last }}
{%- endfor -%}
{% endmacro %}


{% macro get_select_subquery(sql) %}
{{ return(adapter.dispatch('get_select_subquery', 'dbt')(sql)) }}
{% endmacro %}

{% macro default__get_select_subquery(sql) %}
select
{% for column in model['columns'] %}
{{ column }}{{ ", " if not loop.last }}
{% endfor %}
select {{ adapter.dispatch('get_column_names', 'dbt')() }}
from (
{{ sql }}
) as model_subq
Expand Down
4 changes: 3 additions & 1 deletion plugins/postgres/dbt/include/postgres/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
{% if contract_config.enforced %}
{{ get_assert_columns_equivalent(sql) }}
{{ get_table_columns_and_constraints() }} ;
insert into {{ relation }} {{ get_column_names() }}
insert into {{ relation }} (
{{ adapter.dispatch('get_column_names', 'dbt')() }}
)
{%- set sql = get_select_subquery(sql) %}
{% else %}
as
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
{% macro get_column_names() %}
{# loop through user_provided_columns to get column names #}
{%- set user_provided_columns = model['columns'] -%}
(
{% for i in user_provided_columns %}
{% set col = user_provided_columns[i] %}
{{ col['name'] }} {{ "," if not loop.last }}
{% endfor %}
)
{% endmacro %}
44 changes: 41 additions & 3 deletions tests/adapter/dbt/tests/adapter/constraints/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@
'2019-01-01' as date_day
"""


# 'from' is a reserved word, so it must be quoted
my_model_with_quoted_column_name_sql = """
select
'blue' as {{ adapter.quote('from') }},
1 as id,
'2019-01-01' as date_day
"""


model_schema_yml = """
version: 2
models:
Expand All @@ -260,7 +270,6 @@
enforced: true
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -344,7 +353,6 @@
enforced: true
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -454,7 +462,6 @@
expression: {schema}.foreign_key_model (id)
columns:
- name: id
quote: true
data_type: integer
description: hello
constraints:
Expand Down Expand Up @@ -490,6 +497,37 @@
data_type: {data_type}
"""


model_quoted_column_schema_yml = """
version: 2
models:
- name: my_model
config:
contract:
enforced: true
materialized: table
constraints:
- type: check
# this one is the on the user
expression: ("from" = 'blue')
columns: [ '"from"' ]
columns:
- name: id
data_type: integer
description: hello
constraints:
- type: not_null
tests:
- unique
- name: from # reserved word
quote: true
data_type: text
constraints:
- type: not_null
- name: date_day
data_type: text
"""

model_contract_header_schema_yml = """
version: 2
models:
Expand Down
61 changes: 47 additions & 14 deletions tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
my_model_incremental_wrong_name_sql,
my_model_with_nulls_sql,
my_model_incremental_with_nulls_sql,
my_model_with_quoted_column_name_sql,
model_schema_yml,
model_fk_constraint_schema_yml,
constrained_model_schema_yml,
model_quoted_column_schema_yml,
foreign_key_model_sql,
my_model_wrong_order_depends_on_fk_sql,
my_model_incremental_wrong_order_depends_on_fk_sql,
Expand Down Expand Up @@ -165,11 +167,12 @@ def test__constraints_correct_column_data_types(self, project, data_types):


def _normalize_whitespace(input: str) -> str:
return re.sub(r"\s+", " ", input).lower().strip()
subbed = re.sub(r"\s+", " ", input)
return re.sub(r"\s?([\(\),])\s?", r"\1", subbed).lower().strip()


def _find_and_replace(sql, find, replace):
sql_tokens = sql.split(" ")
sql_tokens = sql.split()
for idx in [n for n, x in enumerate(sql_tokens) if find in x]:
sql_tokens[idx] = replace
return " ".join(sql_tokens)
Expand Down Expand Up @@ -235,17 +238,12 @@ def test__constraints_ddl(self, project, expected_sql):
# the name is not what we're testing here anyways and varies based on materialization
# TODO: consider refactoring this to introspect logs instead
generated_sql = read_file("target", "run", "test", "models", "my_model.sql")
generated_sql_modified = _normalize_whitespace(generated_sql)
generated_sql_generic = _find_and_replace(
generated_sql_modified, "my_model", "<model_identifier>"
)
generated_sql_generic = _find_and_replace(generated_sql, "my_model", "<model_identifier>")
generated_sql_generic = _find_and_replace(
generated_sql_generic, "foreign_key_model", "<foreign_key_model_identifier>"
)

expected_sql_check = _normalize_whitespace(expected_sql)

assert expected_sql_check == generated_sql_generic
assert _normalize_whitespace(expected_sql) == _normalize_whitespace(generated_sql_generic)


class BaseConstraintsRollback:
Expand Down Expand Up @@ -485,15 +483,50 @@ def test__model_constraints_ddl(self, project, expected_sql):
# assert at least my_model was run - additional upstreams may or may not be provided to the test setup via models fixture
assert len(results) >= 1
generated_sql = read_file("target", "run", "test", "models", "my_model.sql")
generated_sql_modified = _normalize_whitespace(generated_sql)
generated_sql_generic = _find_and_replace(
generated_sql_modified, "my_model", "<model_identifier>"
)

generated_sql_generic = _find_and_replace(generated_sql, "my_model", "<model_identifier>")
generated_sql_generic = _find_and_replace(
generated_sql_generic, "foreign_key_model", "<foreign_key_model_identifier>"
)
assert _normalize_whitespace(expected_sql) == generated_sql_generic

assert _normalize_whitespace(expected_sql) == _normalize_whitespace(generated_sql_generic)


class TestModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement):
pass


class BaseConstraintQuotedColumn(BaseConstraintsRuntimeDdlEnforcement):
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_with_quoted_column_name_sql,
"constraints_schema.yml": model_quoted_column_schema_yml,
}

@pytest.fixture(scope="class")
def expected_sql(self):
return """
create table <model_identifier> (
id integer not null,
"from" text not null,
date_day text,
check (("from" = 'blue'))
) ;
insert into <model_identifier> (
id, "from", date_day
)
(
select id, "from", date_day
from (
select
'blue' as "from",
1 as id,
'2019-01-01' as date_day
) as model_subq
);
"""


class TestConstraintQuotedColumn(BaseConstraintQuotedColumn):
pass

0 comments on commit 83d163a

Please sign in to comment.