From 6232097722af8efd211035982bd40f58a1a05872 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 24 Oct 2023 20:39:36 -0400 Subject: [PATCH 1/5] first pass --- core/dbt/context/providers.py | 8 ++ core/dbt/contracts/graph/nodes.py | 1 + core/dbt/exceptions.py | 6 ++ core/dbt/parser/unit_tests.py | 80 ++++++++++++------- .../unit_testing/test_unit_testing.py | 74 +++++++++++++++++ 5 files changed, 139 insertions(+), 30 deletions(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 27d3b1fcd01..9cd3eaf9339 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1541,6 +1541,14 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: else: return super().env_var(var, default) + @contextproperty() + def this(self) -> Optional[str]: + if self.model.this: + # TODO: RuntimeRefResolver.set_cte also passes None as second argument. + self.model.set_cte(self.model.this.unique_id, None) # type: ignore + return self.adapter.Relation.add_ephemeral_prefix(self.model.this.name) + return None + # This is called by '_context_for', used in 'render_with_context' def generate_parser_model_context( diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 7b864c41440..e93a3db35bb 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1068,6 +1068,7 @@ class UnitTestNode(CompiledNode): attached_node: Optional[str] = None overrides: Optional[UnitTestOverrides] = None config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig) + this: Optional[ModelNode] = None @dataclass diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 1045058877a..3bae1bd6ead 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -1220,6 +1220,12 @@ def __init__( super().__init__(msg=msg) +class InvalidUnitTestGivenInput(ParsingError): + def __init__(self, input: str) -> None: + msg = f"Unit test given inputs must be either a 'ref', 'source' or 'this' call. Got: '{input}'." + super().__init__(msg=msg) + + class SameKeyNestedError(CompilationError): def __init__(self) -> None: msg = "Test cannot have the same key at the top-level and in config" diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index fa8aa6c48c3..005a6c39f4e 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -9,13 +9,12 @@ from dbt.contracts.graph.nodes import ( ModelNode, UnitTestNode, - RefArgs, UnitTestDefinition, DependsOn, UnitTestConfig, ) from dbt.contracts.graph.unparsed import UnparsedUnitTestSuite -from dbt.exceptions import ParsingError +from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput from dbt.graph import UniqueId from dbt.node_types import NodeType from dbt.parser.schemas import ( @@ -28,7 +27,7 @@ ParseResult, ) from dbt.utils import get_pseudo_test_path -from dbt_extractor import py_extract_from_source # type: ignore +from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore class UnitTestManifestLoader: @@ -53,6 +52,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): actual_node = self.manifest.ref_lookup.perform_lookup( f"model.{package_name}.{test_case.model}", self.manifest ) + assert isinstance(actual_node, ModelNode) # Create UnitTestNode based on model being tested. Since selection has # already been done, we don't have to care about fields that are necessary @@ -106,7 +106,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # input models substituting for the same input ref'd model. for given in test_case.given: # extract the original_input_node from the ref in the "input" key of the given list - original_input_node = self._get_original_input_node(given.input) + original_input_node = self._get_original_input_node(given.input, actual_node) original_input_node_columns = None if ( @@ -117,11 +117,13 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): column.name: column.data_type for column in original_input_node.columns } - # TODO: package_name? - input_name = f"{test_case.model}__{test_case.name}__{original_input_node.name}" + # TODO: include package_name? + input_name = f"{unit_test_node.name}__{original_input_node.name}" input_unique_id = f"model.{package_name}.{input_name}" input_node = ModelNode( - raw_code=self._build_raw_code(given.get_rows(), original_input_node_columns), + raw_code=self._build_fixture_raw_code( + given.get_rows(), original_input_node_columns + ), resource_type=NodeType.Model, package_name=package_name, path=original_input_node.path, @@ -136,37 +138,55 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): checksum=FileHash.empty(), ) self.unit_test_manifest.nodes[input_node.unique_id] = input_node + + # Store input_node on unit_test_node for ease of access in UnitTestContext.this + if original_input_node == actual_node: + unit_test_node.this = input_node + # Add unique ids of input_nodes to depends_on unit_test_node.depends_on.nodes.append(input_node.unique_id) - def _build_raw_code(self, rows, column_name_to_data_types) -> str: + def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str: return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format( rows=rows, column_name_to_data_types=column_name_to_data_types ) - def _get_original_input_node(self, input: str): - """input: ref('my_model_a')""" - # Exract the ref or sources - statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") - if statically_parsed["refs"]: - # set refs and sources on the node object - refs: List[RefArgs] = [] - for ref in statically_parsed["refs"]: - name = ref.get("name") - package = ref.get("package") - version = ref.get("version") - refs.append(RefArgs(name, package, version)) - # TODO: disabled lookup, versioned lookup, public models - original_input_node = self.manifest.ref_lookup.find( - name, package, version, self.manifest - ) - elif statically_parsed["sources"]: - input_package_name, input_source_name = statically_parsed["sources"][0] - original_input_node = self.manifest.source_lookup.find( - input_source_name, input_package_name, self.manifest - ) + def _get_original_input_node(self, input: str, tested_node: ModelNode): + """ + Returns the original input node as defined in the project given an input reference + and the node being tested. + + input: str representing how input node is referenced in tested model sql + * examples: + - "ref('my_model_a')" + - "source('my_source_schema', 'my_source_name')" + - "this" + tested_node: ModelNode of representing node being tested + """ + if input.strip() == "this": + original_input_node = tested_node else: - raise ParsingError("given input must be ref or source") + try: + statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") + except ExtractionError: + raise InvalidUnitTestGivenInput(input=input) + + if statically_parsed["refs"]: + for ref in statically_parsed["refs"]: + name = ref.get("name") + package = ref.get("package") + version = ref.get("version") + # TODO: disabled lookup, versioned lookup, public models + original_input_node = self.manifest.ref_lookup.find( + name, package, version, self.manifest + ) + elif statically_parsed["sources"]: + input_package_name, input_source_name = statically_parsed["sources"][0] + original_input_node = self.manifest.source_lookup.find( + input_source_name, input_package_name, self.manifest + ) + else: + raise InvalidUnitTestGivenInput(input=input) return original_input_node diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 4cb426a5343..d520ebc1543 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -331,3 +331,77 @@ def test_basic(self, project): ) with pytest.raises(ParsingError): results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + + +event_sql = """ +select DATE '2020-01-01' as event_time, 1 as event +union all +select DATE '2020-01-02' as event_time, 2 as event +union all +select DATE '2020-01-03' as event_time, 3 as event +""" + +my_incremental_model_sql = """ +{{ + config( + materialized='incremental' + ) +}} + +select * from {{ ref('events') }} +{% if is_incremental() %} +where event_time > (select max(event_time) from {{ this }}) +{% endif %} +""" + +test_my_model_incremental_yml = """ +unit: + - model: my_incremental_model + tests: + - name: incremental_false + overrides: + macros: + is_incremental: false + given: + - input: ref('events') + rows: + - {event_time: "2020-01-01", event: 1} + expect: + rows: + - {event_time: "2020-01-01", event: 1} + - name: incremental_true + overrides: + macros: + is_incremental: true + given: + - input: ref('events') + rows: + - {event_time: "2020-01-01", event: 1} + - {event_time: "2020-01-02", event: 2} + - {event_time: "2020-01-03", event: 3} + - input: this + rows: + - {event_time: "2020-01-01", event: 1} + expect: + rows: + - {event_time: "2020-01-02", event: 2} + - {event_time: "2020-01-03", event: 3} +""" + + +class TestUnitTestIncrementalModel: + @pytest.fixture(scope="class") + def models(self): + return { + "my_incremental_model.sql": my_incremental_model_sql, + "events.sql": event_sql, + "test_my_incremental_model.yml": test_my_model_incremental_yml, + } + + def test_basic(self, project): + results = run_dbt(["run"]) + assert len(results) == 2 + + # Select by model name + results = run_dbt(["unit-test", "--select", "my_incremental_model"], expect_pass=True) + assert len(results) == 2 From 96bbbfd5a4f4a33d3535de226912bfa7582a14cb Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 1 Nov 2023 10:18:52 -0400 Subject: [PATCH 2/5] changelog entry --- .changes/unreleased/Features-20231101-101845.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20231101-101845.yaml diff --git a/.changes/unreleased/Features-20231101-101845.yaml b/.changes/unreleased/Features-20231101-101845.yaml new file mode 100644 index 00000000000..603990ce2e7 --- /dev/null +++ b/.changes/unreleased/Features-20231101-101845.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support unit testing incremental models +time: 2023-11-01T10:18:45.341781-04:00 +custom: + Author: michelleark + Issue: "8422" From adf9f2ae37194c23fae0b5f64ea1d6accf6e2b77 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 1 Nov 2023 16:56:04 -0400 Subject: [PATCH 3/5] remove this from UnitTestNode, rename attached_node to tested_node --- core/dbt/context/providers.py | 8 ++++---- core/dbt/contracts/graph/nodes.py | 7 +++++-- core/dbt/parser/unit_tests.py | 20 ++++++++------------ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 9cd3eaf9339..e5e11c5c9a1 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1543,10 +1543,10 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: @contextproperty() def this(self) -> Optional[str]: - if self.model.this: - # TODO: RuntimeRefResolver.set_cte also passes None as second argument. - self.model.set_cte(self.model.this.unique_id, None) # type: ignore - return self.adapter.Relation.add_ephemeral_prefix(self.model.this.name) + if self.model.this_model_fixture_unique_id in self.manifest.nodes: + this_node = self.manifest.expect(self.model.this_model_fixture_unique_id) + self.model.set_cte(self.model.this_model_fixture_unique_id, None) # type: ignore + return self.adapter.Relation.add_ephemeral_prefix(this_node.name) return None diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index e93a3db35bb..e3a142462cc 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1065,10 +1065,13 @@ def test_node_type(self): @dataclass class UnitTestNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]}) - attached_node: Optional[str] = None + tested_node: Optional[ModelNode] = None overrides: Optional[UnitTestOverrides] = None config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig) - this: Optional[ModelNode] = None + + @property + def this_model_fixture_unique_id(self): + return f"model.{self.tested_node.package_name}.{self.name}__{self.tested_node.name}" @dataclass diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 005a6c39f4e..55f11107261 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -48,11 +48,11 @@ def load(self) -> Manifest: def parse_unit_test_case(self, test_case: UnitTestDefinition): package_name = self.root_project.project_name - # Create unit test node based on the "actual" tested node - actual_node = self.manifest.ref_lookup.perform_lookup( + # Create unit test node based on the node being tested + tested_node = self.manifest.ref_lookup.perform_lookup( f"model.{package_name}.{test_case.model}", self.manifest ) - assert isinstance(actual_node, ModelNode) + assert isinstance(tested_node, ModelNode) # Create UnitTestNode based on model being tested. Since selection has # already been done, we don't have to care about fields that are necessary @@ -69,13 +69,13 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): config=UnitTestNodeConfig( materialized="unit", expected_rows=test_case.expect.get_rows() ), - raw_code=actual_node.raw_code, - database=actual_node.database, - schema=actual_node.schema, + raw_code=tested_node.raw_code, + database=tested_node.database, + schema=tested_node.schema, alias=name, fqn=test_case.unique_id.split("."), checksum=FileHash.empty(), - attached_node=actual_node.unique_id, + tested_node=tested_node, overrides=test_case.overrides, ) @@ -106,7 +106,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # input models substituting for the same input ref'd model. for given in test_case.given: # extract the original_input_node from the ref in the "input" key of the given list - original_input_node = self._get_original_input_node(given.input, actual_node) + original_input_node = self._get_original_input_node(given.input, tested_node) original_input_node_columns = None if ( @@ -139,10 +139,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): ) self.unit_test_manifest.nodes[input_node.unique_id] = input_node - # Store input_node on unit_test_node for ease of access in UnitTestContext.this - if original_input_node == actual_node: - unit_test_node.this = input_node - # Add unique ids of input_nodes to depends_on unit_test_node.depends_on.nodes.append(input_node.unique_id) From 470f977da03737986a972223f00438263b4775af Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 1 Nov 2023 17:08:04 -0400 Subject: [PATCH 4/5] store this_input_node_unique_id explicitly on UnitTestNode --- core/dbt/context/providers.py | 6 +++--- core/dbt/contracts/graph/nodes.py | 7 ++----- core/dbt/parser/unit_tests.py | 6 +++++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index e5e11c5c9a1..4316a28fc53 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1543,9 +1543,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: @contextproperty() def this(self) -> Optional[str]: - if self.model.this_model_fixture_unique_id in self.manifest.nodes: - this_node = self.manifest.expect(self.model.this_model_fixture_unique_id) - self.model.set_cte(self.model.this_model_fixture_unique_id, None) # type: ignore + if self.model.this_input_node_unique_id: + this_node = self.manifest.expect(self.model.this_input_node_unique_id) + self.model.set_cte(self.model.this_input_node_unique_id, None) # type: ignore return self.adapter.Relation.add_ephemeral_prefix(this_node.name) return None diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index e3a142462cc..184dce50a3d 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1065,14 +1065,11 @@ def test_node_type(self): @dataclass class UnitTestNode(CompiledNode): resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]}) - tested_node: Optional[ModelNode] = None + tested_node_unique_id: Optional[str] = None + this_input_node_unique_id: Optional[str] = None overrides: Optional[UnitTestOverrides] = None config: UnitTestNodeConfig = field(default_factory=UnitTestNodeConfig) - @property - def this_model_fixture_unique_id(self): - return f"model.{self.tested_node.package_name}.{self.name}__{self.tested_node.name}" - @dataclass class UnitTestDefinition(GraphNode): diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 55f11107261..98cd08bb4b6 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -75,7 +75,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): alias=name, fqn=test_case.unique_id.split("."), checksum=FileHash.empty(), - tested_node=tested_node, + tested_node_unique_id=tested_node.unique_id, overrides=test_case.overrides, ) @@ -139,6 +139,10 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): ) self.unit_test_manifest.nodes[input_node.unique_id] = input_node + # Populate this_input_node_unique_id if input fixture represents node being tested + if original_input_node == tested_node: + unit_test_node.this_input_node_unique_id = input_node.unique_id + # Add unique ids of input_nodes to depends_on unit_test_node.depends_on.nodes.append(input_node.unique_id) From 8d028cbb3f5c08b4f5778f66b4dc2fdfe010e120 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 1 Nov 2023 17:09:32 -0400 Subject: [PATCH 5/5] improve readability in UnitTestContext.this --- core/dbt/context/providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 4316a28fc53..ef102068e8f 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1545,7 +1545,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: def this(self) -> Optional[str]: if self.model.this_input_node_unique_id: this_node = self.manifest.expect(self.model.this_input_node_unique_id) - self.model.set_cte(self.model.this_input_node_unique_id, None) # type: ignore + self.model.set_cte(this_node.unique_id, None) # type: ignore return self.adapter.Relation.add_ephemeral_prefix(this_node.name) return None