diff --git a/.changes/unreleased/Features-20240606-112334.yaml b/.changes/unreleased/Features-20240606-112334.yaml new file mode 100644 index 00000000000..4a325d6811f --- /dev/null +++ b/.changes/unreleased/Features-20240606-112334.yaml @@ -0,0 +1,6 @@ +kind: Features +body: add pre_model and post_model hook calls to data and unit tests to be able to provide extra config options +time: 2024-06-06T11:23:34.758675-05:00 +custom: + Author: McKnight-42 + Issue: "10198" diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index 2ae65dc3ebe..546bd43a943 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -126,6 +126,8 @@ def before_execute(self): def execute_data_test(self, data_test: TestNode, manifest: Manifest) -> TestResultData: context = generate_runtime_model_context(data_test, self.config, manifest) + hook_ctx = self.adapter.pre_model_hook(context) + materialization_macro = manifest.find_materialization_macro_by_name( self.config.project_name, data_test.get_materialization(), self.adapter.type() ) @@ -142,8 +144,12 @@ def execute_data_test(self, data_test: TestNode, manifest: Manifest) -> TestResu # generate materialization macro macro_func = MacroGenerator(materialization_macro, context) - # execute materialization macro - macro_func() + try: + # execute materialization macro + macro_func() + finally: + self.adapter.post_model_hook(context, hook_ctx) + # load results from context # could eventually be returned directly by materialization result = context["load_result"]("main") @@ -198,6 +204,8 @@ def execute_unit_test( # materialization, not compile the node.compiled_code context = generate_runtime_model_context(unit_test_node, self.config, unit_test_manifest) + hook_ctx = self.adapter.pre_model_hook(context) + materialization_macro = unit_test_manifest.find_materialization_macro_by_name( self.config.project_name, unit_test_node.get_materialization(), self.adapter.type() ) @@ -215,14 +223,16 @@ def execute_unit_test( # generate materialization macro macro_func = MacroGenerator(materialization_macro, context) - # execute materialization macro try: + # execute materialization macro macro_func() except DbtBaseException as e: raise DbtRuntimeError( f"An error occurred during execution of unit test '{unit_test_def.name}'. " f"There may be an error in the unit test definition: check the data types.\n {e}" ) + finally: + self.adapter.post_model_hook(context, hook_ctx) # load results from context # could eventually be returned directly by materialization diff --git a/tests/functional/data_tests/test_hooks.py b/tests/functional/data_tests/test_hooks.py new file mode 100644 index 00000000000..60eee2f543f --- /dev/null +++ b/tests/functional/data_tests/test_hooks.py @@ -0,0 +1,111 @@ +from unittest import mock + +import pytest + +from dbt.tests.util import run_dbt, run_dbt_and_capture +from dbt_common.exceptions import CompilationError + +orders_csv = """order_id,order_date,customer_id +1,2024-06-01,1001 +2,2024-06-02,1002 +3,2024-06-03,1003 +4,2024-06-04,1004 +""" + + +orders_model_sql = """ +with source as ( + select + order_id, + order_date, + customer_id + from {{ ref('seed_orders') }} +), +final as ( + select + order_id, + order_date, + customer_id + from source +) +select * from final +""" + + +orders_test_sql = """ +select * +from {{ ref('orders') }} +where order_id is null +""" + + +class BaseSingularTestHooks: + @pytest.fixture(scope="class") + def seeds(self): + return {"seed_orders.csv": orders_csv} + + @pytest.fixture(scope="class") + def models(self): + return {"orders.sql": orders_model_sql} + + @pytest.fixture(scope="class") + def tests(self): + return {"orders_test.sql": orders_test_sql} + + +class TestSingularTestPreHook(BaseSingularTestHooks): + def test_data_test_runs_adapter_pre_hook_pass(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + + results = run_dbt(["run"]) + assert len(results) == 1 + + mock_pre_model_hook = mock.Mock() + with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook): + results = run_dbt(["test"], expect_pass=True) + assert len(results) == 1 + mock_pre_model_hook.assert_called_once() + + def test_data_test_runs_adapter_pre_hook_fails(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + + results = run_dbt(["run"]) + assert len(results) == 1 + + mock_pre_model_hook = mock.Mock() + mock_pre_model_hook.side_effect = CompilationError("exception from adapter.pre_model_hook") + with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook): + (_, log_output) = run_dbt_and_capture(["test"], expect_pass=False) + assert "exception from adapter.pre_model_hook" in log_output + + +class TestSingularTestPostHook(BaseSingularTestHooks): + def test_data_test_runs_adapter_post_hook_pass(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + + results = run_dbt(["run"]) + assert len(results) == 1 + + mock_post_model_hook = mock.Mock() + with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook): + results = run_dbt(["test"], expect_pass=True) + assert len(results) == 1 + mock_post_model_hook.assert_called_once() + + def test_data_test_runs_adapter_post_hook_fails(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + + results = run_dbt(["run"]) + assert len(results) == 1 + + mock_post_model_hook = mock.Mock() + mock_post_model_hook.side_effect = CompilationError( + "exception from adapter.post_model_hook" + ) + with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook): + (_, log_output) = run_dbt_and_capture(["test"], expect_pass=False) + assert "exception from adapter.post_model_hook" in log_output diff --git a/tests/functional/unit_testing/fixtures.py b/tests/functional/unit_testing/fixtures.py index 3028e0bc1e6..e73351f89d8 100644 --- a/tests/functional/unit_testing/fixtures.py +++ b/tests/functional/unit_testing/fixtures.py @@ -116,6 +116,23 @@ tags: test_this """ +test_my_model_pass_yml = """ +unit_tests: + - name: test_my_model + model: my_model + given: + - input: ref('my_model_a') + rows: + - {id: 1, a: 1} + - input: ref('my_model_b') + rows: + - {id: 1, b: 2} + - {id: 2, b: 2} + expect: + rows: + - {c: 3} +""" + test_my_model_simple_fixture_yml = """ unit_tests: diff --git a/tests/functional/unit_testing/test_ut_adapter_hooks.py b/tests/functional/unit_testing/test_ut_adapter_hooks.py new file mode 100644 index 00000000000..a2f496752e2 --- /dev/null +++ b/tests/functional/unit_testing/test_ut_adapter_hooks.py @@ -0,0 +1,75 @@ +from unittest import mock + +import pytest + +from dbt.tests.util import run_dbt, run_dbt_and_capture +from dbt_common.exceptions import CompilationError +from tests.functional.unit_testing.fixtures import ( + my_model_a_sql, + my_model_b_sql, + my_model_sql, + test_my_model_pass_yml, +) + + +class BaseUnitTestAdapterHook: + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_sql, + "my_model_a.sql": my_model_a_sql, + "my_model_b.sql": my_model_b_sql, + "test_my_model.yml": test_my_model_pass_yml, + } + + +class TestUnitTestAdapterPreHook(BaseUnitTestAdapterHook): + def test_unit_test_runs_adapter_pre_hook_passes(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + mock_pre_model_hook = mock.Mock() + with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook): + results = run_dbt(["test", "--select", "test_name:test_my_model"], expect_pass=True) + + assert len(results) == 1 + mock_pre_model_hook.assert_called_once() + + def test_unit_test_runs_adapter_pre_hook_fails(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + mock_pre_model_hook = mock.Mock() + mock_pre_model_hook.side_effect = CompilationError("exception from adapter.pre_model_hook") + with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook): + (_, log_output) = run_dbt_and_capture( + ["test", "--select", "test_name:test_my_model"], expect_pass=False + ) + assert "exception from adapter.pre_model_hook" in log_output + + +class TestUnitTestAdapterPostHook(BaseUnitTestAdapterHook): + def test_unit_test_runs_adapter_post_hook_pass(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + mock_post_model_hook = mock.Mock() + with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook): + results = run_dbt(["test", "--select", "test_name:test_my_model"], expect_pass=True) + + assert len(results) == 1 + mock_post_model_hook.assert_called_once() + + def test_unit_test_runs_adapter_post_hook_fails(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + + mock_post_model_hook = mock.Mock() + mock_post_model_hook.side_effect = CompilationError( + "exception from adapter.post_model_hook" + ) + with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook): + (_, log_output) = run_dbt_and_capture( + ["test", "--select", "test_name:test_my_model"], expect_pass=False + ) + assert "exception from adapter.post_model_hook" in log_output