diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index e2cfe1b96..45df88fb5 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -1,15 +1,18 @@ import pytest -from dbt.tests.util import relation_from_name from dbt.tests.adapter.constraints.test_constraints import ( BaseTableConstraintsColumnsEqual, BaseViewConstraintsColumnsEqual, - BaseConstraintsRuntimeEnforcement + BaseIncrementalConstraintsColumnsEqual, + BaseConstraintsRuntimeDdlEnforcement, + BaseConstraintsRollback, + BaseIncrementalConstraintsRuntimeDdlEnforcement, + BaseIncrementalConstraintsRollback, ) _expected_sql_snowflake = """ -create or replace transient table {0} ( +create or replace transient table ( id integer not null primary key, color text, date_day text @@ -51,6 +54,7 @@ def data_types(self, int_type, schema_int_type, string_type): ["""TO_VARIANT(PARSE_JSON('{"key3": "value3", "key4": "value4"}'))""", 'variant', 'VARIANT'], ] + class TestSnowflakeTableConstraintsColumnsEqual(SnowflakeColumnEqualSetup, BaseTableConstraintsColumnsEqual): pass @@ -59,13 +63,29 @@ class TestSnowflakeViewConstraintsColumnsEqual(SnowflakeColumnEqualSetup, BaseVi pass -class TestSnowflakeConstraintsRuntimeEnforcement(BaseConstraintsRuntimeEnforcement): +class TestSnowflakeIncrementalConstraintsColumnsEqual(SnowflakeColumnEqualSetup, BaseIncrementalConstraintsColumnsEqual): + pass + + +class TestSnowflakeTableConstraintsDdlEnforcement(BaseConstraintsRuntimeDdlEnforcement): + @pytest.fixture(scope="class") + def expected_sql(self): + return _expected_sql_snowflake + + +class TestSnowflakeIncrementalConstraintsDdlEnforcement(BaseIncrementalConstraintsRuntimeDdlEnforcement): + @pytest.fixture(scope="class") + def expected_sql(self): + return _expected_sql_snowflake + +class TestSnowflakeTableConstraintsRollback(BaseConstraintsRollback): @pytest.fixture(scope="class") - def expected_sql(self, project): - relation = relation_from_name(project.adapter, "my_model") - return _expected_sql_snowflake.format(relation) + def expected_error_messages(self): + return ["NULL result in a non-nullable column"] + +class TestSnowflakeIncrementalConstraintsRollback(BaseIncrementalConstraintsRollback): @pytest.fixture(scope="class") def expected_error_messages(self): return ["NULL result in a non-nullable column"]