add tests for enforcing contracts for incremental materializations (#685)

* add tests for enforcing contracts for incremental materializations

* remove changelog

* modify test

* add new test value

* add another error msg

* use the right models

* fix model definition

* reorganize tests

* persist constraints for incremental mats

* fix expected color fixture

* move constraints

* move do persist_constraints

* reset dev reqs

* stringify relation (#698)
emmyoop authored Mar 28, 2023
1 parent 28e4493 commit 3b3b2a0
Showing 3 changed files with 135 additions and 36 deletions.
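
For context: these tests exercise dbt's model contracts, where a model declares column names, data types, and constraints in its schema file and dbt enforces them when the model is built. Below is a minimal sketch of the kind of contracted incremental model the new tests target, written as Python string fixtures in the style of dbt's adapter test suite; the names and values are illustrative stand-ins, not the actual fixtures imported in the test file.

# Illustrative stand-ins for the fixtures that the test file imports
# from dbt.tests.adapter.constraints.fixtures.
example_incremental_model_sql = """
{{ config(materialized='incremental') }}
select 1 as id, 'blue' as color, '2019-01-01' as date_day
"""

example_constraints_yml = """
version: 2
models:
  - name: my_model
    config:
      contract:
        enforced: true
    columns:
      - name: id
        data_type: int
        constraints:
          - type: not_null
          - type: check
            expression: (id > 0)
      - name: color
        data_type: string
      - name: date_day
        data_type: string
"""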
2 changes: 1 addition & 1 deletion dbt/adapters/spark/impl.py
@@ -394,7 +394,7 @@ def _get_one_catalog(

        columns: List[Dict[str, Any]] = []
        for relation in self.list_relations(database, schema):
            logger.debug("Getting table schema for relation {}", relation)
            logger.debug("Getting table schema for relation {}", str(relation))
            columns.extend(self._get_columns_for_catalog(relation))
        return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)

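The one-line change above is the "stringify relation" fix (#698): the relation is converted to a string before it reaches the logger. A rough sketch of why eager conversion can matter, using a hypothetical stand-in for a brace-formatting adapter logger whose arguments may also be captured by a structured logging backend; this is an assumption for illustration, not dbt's actual logger implementation.

# Hypothetical brace-formatting logger; dbt's real AdapterLogger routes
# arguments through a structured event pipeline instead of print().
class BraceLogger:
    def debug(self, msg: str, *args: object) -> None:
        # A structured backend may capture `args` as-is; a non-string
        # object here can serialize unpredictably, so callers convert
        # eagerly with str() to guarantee plain text everywhere.
        record = {"formatted": msg.format(*args), "args": args}
        print(record)


class Relation:
    def __str__(self) -> str:
        return "`my_schema`.`my_table`"


logger = BraceLogger()
logger.debug("Getting table schema for relation {}", str(Relation()))
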
@@ -39,6 +39,7 @@
    {%- call statement('main', language=language) -%}
      {{ create_table_as(False, target_relation, compiled_code, language) }}
    {%- endcall -%}
    {% do persist_constraints(target_relation, model) %}
  {%- elif existing_relation.is_view or should_full_refresh() -%}
    {#-- Relation must be dropped & recreated --#}
    {% set is_delta = (file_format == 'delta' and existing_relation.is_delta) %}
@@ -48,6 +49,7 @@
    {%- call statement('main', language=language) -%}
      {{ create_table_as(False, target_relation, compiled_code, language) }}
    {%- endcall -%}
    {% do persist_constraints(target_relation, model) %}
  {%- else -%}
    {#-- Relation must be merged --#}
    {%- call statement('create_tmp_relation', language=language) -%}
@@ -63,7 +65,7 @@
      See note in dbt-spark/dbt/include/spark/macros/adapters.sql
      re: python models and temporary views.

      Also, why doesn't either drop_relation or adapter.drop_relation work here?!
      Also, why do neither drop_relation nor adapter.drop_relation work here?!
    --#}
    {% call statement('drop_relation') -%}
        drop table if exists {{ tmp_relation }}
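The `persist_constraints` calls added above run immediately after `create_table_as` in both the create and full-refresh branches. On Delta, persisting the contract amounts to issuing ALTER statements once the table exists, which is also why the expected DDL in the tests below contains no constraint clauses. A minimal sketch of the shape of those statements, with made-up relation and constraint names; this is not the macro's literal output.

# Illustrative only: the post-create ALTER statements that constraint
# persistence boils down to on a Delta table. All names are made up.
relation = "my_schema.my_model"
check_constraints = {"id_greater_than_zero": "id > 0"}
not_null_columns = ["id"]

statements = [
    f"alter table {relation} add constraint {name} check ({expression})"
    for name, expression in check_constraints.items()
] + [
    f"alter table {relation} change column {column} set not null"
    for column in not_null_columns
]

for statement in statements:
    print(statement)
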
165 changes: 131 additions & 34 deletions tests/functional/adapter/test_constraints.py
@@ -1,20 +1,28 @@
import pytest
from dbt.tests.util import relation_from_name
from dbt.tests.adapter.constraints.test_constraints import (
    BaseTableConstraintsColumnsEqual,
    BaseViewConstraintsColumnsEqual,
    BaseConstraintsRuntimeEnforcement,
    BaseIncrementalConstraintsColumnsEqual,
    BaseConstraintsRuntimeDdlEnforcement,
    BaseConstraintsRollback,
    BaseIncrementalConstraintsRuntimeDdlEnforcement,
    BaseIncrementalConstraintsRollback,
)
from dbt.tests.adapter.constraints.fixtures import (
    my_model_sql,
    my_model_wrong_order_sql,
    my_model_wrong_name_sql,
    model_schema_yml,
    my_model_view_wrong_order_sql,
    my_model_view_wrong_name_sql,
    my_model_incremental_wrong_order_sql,
    my_model_incremental_wrong_name_sql,
    my_incremental_model_sql,
)

# constraints are enforced via 'alter' statements that run after table creation
_expected_sql_spark = """
create or replace table {0}
create or replace table <model_identifier>
using delta
as
select
@@ -35,14 +43,6 @@


class PyodbcSetup:
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
@@ -81,14 +81,6 @@ def data_types(self, int_type, schema_int_type, string_type):


class DatabricksHTTPSetup:
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }

    @pytest.fixture
    def string_type(self):
        return "STRING_TYPE"
@@ -120,12 +112,37 @@ def data_types(self, int_type, schema_int_type, string_type):

@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster")
class TestSparkTableConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseTableConstraintsColumnsEqual):
    pass
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster")
class TestSparkViewConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseViewConstraintsColumnsEqual):
    pass
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_view_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_view_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile("spark_session", "apache_spark", "databricks_http_cluster")
class TestSparkIncrementalConstraintsColumnsEqualPyodbc(
    PyodbcSetup, BaseIncrementalConstraintsColumnsEqual
):
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_incremental_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_incremental_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile(
@@ -134,7 +151,13 @@ class TestSparkViewConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseViewConstraint
class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(
    DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual
):
    pass
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile(
@@ -143,18 +166,31 @@ class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(
class TestSparkViewConstraintsColumnsEqualDatabricksHTTP(
    DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual
):
    pass
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model_wrong_order.sql": my_model_view_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_view_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkConstraintsRuntimeEnforcement(BaseConstraintsRuntimeEnforcement):
@pytest.mark.skip_profile(
    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
)
class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP(
    DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual
):
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model.sql": my_model_sql,
            "my_model_wrong_order.sql": my_model_incremental_wrong_order_sql,
            "my_model_wrong_name.sql": my_model_incremental_wrong_name_sql,
            "constraints_schema.yml": constraints_yml,
        }


class BaseSparkConstraintsDdlEnforcementSetup:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
@@ -164,27 +200,88 @@ def project_config_update(self):
        }

    @pytest.fixture(scope="class")
    def expected_sql(self, project):
        relation = relation_from_name(project.adapter, "my_model")
        return _expected_sql_spark.format(relation)
    def expected_sql(self):
        return _expected_sql_spark

# On Spark/Databricks, constraints are applied *after* the table is replaced.
# We don't have any way to "rollback" the table to its previous happy state.
# So the 'color' column will be updated to 'red', instead of 'blue'.

@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkTableConstraintsDdlEnforcement(
    BaseSparkConstraintsDdlEnforcementSetup, BaseConstraintsRuntimeDdlEnforcement
):
    @pytest.fixture(scope="class")
    def expected_color(self):
        return "red"
    def models(self):
        return {
            "my_model.sql": my_model_wrong_order_sql,
            "constraints_schema.yml": constraints_yml,
        }


@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkIncrementalConstraintsDdlEnforcement(
    BaseSparkConstraintsDdlEnforcementSetup, BaseIncrementalConstraintsRuntimeDdlEnforcement
):
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model.sql": my_model_incremental_wrong_order_sql,
            "constraints_schema.yml": constraints_yml,
        }


class BaseSparkConstraintsRollbackSetup:
    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {
            "models": {
                "+file_format": "delta",
            }
        }

    @pytest.fixture(scope="class")
    def expected_error_messages(self):
        return [
            "violate the new CHECK constraint",
            "DELTA_NEW_CHECK_CONSTRAINT_VIOLATION",
            "violate the new NOT NULL constraint",
            "(id > 0) violated by row with values:",  # incremental mats
            "DELTA_VIOLATE_CONSTRAINT_WITH_VALUES",  # incremental mats
        ]

    def assert_expected_error_messages(self, error_message, expected_error_messages):
        # This needs to be ANY instead of ALL
        # The CHECK constraint is added before the NOT NULL constraint
        # and different connection types display/truncate the error message in different ways...
        assert any(msg in error_message for msg in expected_error_messages)


@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkTableConstraintsRollback(
    BaseSparkConstraintsRollbackSetup, BaseConstraintsRollback
):
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model.sql": my_model_sql,
            "constraints_schema.yml": constraints_yml,
        }

    # On Spark/Databricks, constraints are applied *after* the table is replaced.
    # We don't have any way to "rollback" the table to its previous happy state.
    # So the 'color' column will be updated to 'red', instead of 'blue'.
    @pytest.fixture(scope="class")
    def expected_color(self):
        return "red"


@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkIncrementalConstraintsRollback(
    BaseSparkConstraintsRollbackSetup, BaseIncrementalConstraintsRollback
):
    # color stays blue for incremental models since it's a new row that just
    # doesn't get inserted
    @pytest.fixture(scope="class")
    def models(self):
        return {
            "my_model.sql": my_incremental_model_sql,
            "constraints_schema.yml": constraints_yml,
        }
