Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rendering dbt tests with multiple parents #1433

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
73 changes: 72 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Union

from airflow.models import BaseOperator
Expand Down Expand Up @@ -74,6 +75,26 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
return leaves


def exclude_detached_tests_if_needed(
    node: DbtNode,
    task_args: dict[str, str],
    detached_from_parent: dict[str, DbtNode] | None = None,
) -> None:
    """
    Add exclude statements if there are tests associated to the model that should be run detached from the model/tests.

    Changes ``task_args`` in-place: when ``node`` has detached child tests, their
    names are appended to ``task_args["exclude"]`` so the bundled test task skips them.

    :param node: The dbt node whose detached tests should be excluded
    :param task_args: Arguments for the task being built; may already hold an "exclude" list
    :param detached_from_parent: Maps a node id to the child test nodes rendered detached from it
    """
    detached_tests = (detached_from_parent or {}).get(node.unique_id, [])
    excluded: list[str] = task_args.get("exclude", [])  # type: ignore
    # dbt identifies a test by the first segment of its resource name.
    excluded.extend(test_node.resource_name.split(".")[0] for test_node in detached_tests)
    if excluded:
        task_args["exclude"] = excluded  # type: ignore


def create_test_task_metadata(
test_task_name: str,
execution_mode: ExecutionMode,
Expand All @@ -82,6 +103,7 @@ def create_test_task_metadata(
on_warning_callback: Callable[..., Any] | None = None,
node: DbtNode | None = None,
render_config: RenderConfig | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
) -> TaskMetadata:
"""
Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
Expand All @@ -92,11 +114,13 @@ def create_test_task_metadata(
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List.
:param node: If the test relates to a specific node, the node reference
:param detached_from_parent: Dictionary that maps node ids and their children tests that should be run detached
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
task_args = dict(task_args)
task_args["on_warning_callback"] = on_warning_callback
extra_context = {}
detached_from_parent = detached_from_parent or {}

task_owner = ""
airflow_task_config = {}
Expand All @@ -119,6 +143,9 @@ def create_test_task_metadata(
task_args["selector"] = render_config.selector
task_args["exclude"] = render_config.exclude

if node:
exclude_detached_tests_if_needed(node, task_args, detached_from_parent)

return TaskMetadata(
id=test_task_name,
owner=task_owner,
Expand Down Expand Up @@ -192,6 +219,7 @@ def create_task_metadata(
normalize_task_id: Callable[..., Any] | None = None,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
on_warning_callback: Callable[..., Any] | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -205,6 +233,7 @@ def create_task_metadata(
If it is False, then use the name as a prefix for the task id, otherwise do not.
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List. This is param available for dbt test and dbt source freshness command.
:param detached_from_parent: Dictionary that maps node ids and their children tests that should be run detached
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = create_dbt_resource_to_class(test_behavior)
Expand All @@ -218,6 +247,7 @@ def create_task_metadata(
}

if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
exclude_detached_tests_if_needed(node, args, detached_from_parent)
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, "build", include_resource_type=True
)
Expand Down Expand Up @@ -268,6 +298,17 @@ def create_task_metadata(
return None


def is_detached_test(node: DbtNode) -> bool:
    """
    Identify if node should be rendered detached from the parent. Conditions that should be met:
    * is a test
    * has multiple parents

    :param node: The dbt node being evaluated
    :returns: True when the node is a test with more than one parent, False otherwise
    """
    # Idiomatic single boolean expression instead of `if cond: return True / return False`.
    return node.resource_type == DbtResourceType.TEST and len(node.depends_on) > 1


def generate_task_or_group(
dag: DAG,
task_group: TaskGroup | None,
Expand All @@ -279,9 +320,11 @@ def generate_task_or_group(
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
normalize_task_id: Callable[..., Any] | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None
detached_from_parent = detached_from_parent or {}

use_task_group = (
node.resource_type in TESTABLE_DBT_RESOURCES
Expand All @@ -299,12 +342,13 @@ def generate_task_or_group(
normalize_task_id=normalize_task_id,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
detached_from_parent=detached_from_parent,
)

# In most cases, we'll map one DBT node to one Airflow task
# The exception are the test nodes, since it would be too slow to run test tasks individually.
# If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
if task_meta and node.resource_type != DbtResourceType.TEST:
if task_meta and not node.resource_type == DbtResourceType.TEST:
if use_task_group:
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
Expand All @@ -315,12 +359,14 @@ def generate_task_or_group(
task_args=task_args,
node=node,
on_warning_callback=on_warning_callback,
detached_from_parent=detached_from_parent,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
task >> test_task
task_or_group = model_task_group
else:
task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)

return task_or_group


Expand Down Expand Up @@ -405,6 +451,16 @@ def build_airflow_graph(
tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
task_or_group: TaskGroup | BaseOperator

# Identify test nodes that should be run detached from the associated dbt resource nodes because they
# have multiple parents
detached_from_parent = defaultdict(list)
detached_nodes = {}
for node_id, node in nodes.items():
if is_detached_test(node):
detached_nodes[node_id] = node
for parent_id in node.depends_on:
detached_from_parent[parent_id].append(node)

for node_id, node in nodes.items():
conversion_function = node_converters.get(node.resource_type, generate_task_or_group)
if conversion_function != generate_task_or_group:
Expand All @@ -425,11 +481,26 @@ def build_airflow_graph(
on_warning_callback=on_warning_callback,
normalize_task_id=normalize_task_id,
node=node,
detached_from_parent=detached_from_parent,
)
if task_or_group is not None:
logger.debug(f"Conversion of <{node.unique_id}> was successful!")
tasks_map[node_id] = task_or_group

# Handle detached test nodes
for node_id, node in detached_nodes.items():
test_meta = create_test_task_metadata(
f"{node.resource_name.split('.')[0]}_test",
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
node=node,
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
tasks_map[node_id] = test_task

# If test_behaviour=="after_all", there will be one test task, run by the end of the DAG
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if test_behavior == TestBehavior.AFTER_ALL:
Expand Down
15 changes: 8 additions & 7 deletions dev/dags/dbt/jaffle_shop/models/schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ models:
- not_null
description: This is a unique identifier for an order

- name: customer_id
description: Foreign key to the customers table
tests:
- not_null
- relationships:
to: ref('customers')
field: customer_id
# Comment so we don't have a standalone test relationships_orders_customer_id__customer_id__ref_customers__test
#- name: customer_id
# description: Foreign key to the customers table
# tests:
# - not_null
# - relationships:
# to: ref('customers')
# field: customer_id

- name: order_date
description: Date (UTC) that the order was placed
Expand Down
12 changes: 12 additions & 0 deletions dev/dags/dbt/multiple_parents_test/dbt_project.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: 'my_dbt_project'
version: '1.0.0'
config-version: 2

profile: 'default'

model-paths: ["models"]
test-paths: ["tests"]

models:
my_dbt_project:
materialized: view
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{% test custom_test_combined_model(model) %}
-- Generic dbt test: fails when any id present in model_a is missing from the
-- tested model. Rows returned here are reported as test failures by dbt.
-- NOTE(review): model_a is hard-coded as the reference side, so this test is
-- only meaningful on models derived from model_a (e.g. combined_model).
WITH source_data AS (
SELECT id FROM {{ ref('model_a') }}
),
combined_data AS (
SELECT id FROM {{ model }}
)
-- Anti-join: keep ids from model_a that have no match in the tested model.
SELECT
s.id
FROM
source_data s
LEFT JOIN
combined_data c
ON s.id = c.id
WHERE
c.id IS NULL
{% endtest %}
16 changes: 16 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/combined_model.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Combine data from model_a and model_b
-- Inner join on id: only ids present in BOTH parents appear in the result,
-- giving this model two upstream dependencies (the "multiple parents" case
-- this example project exists to exercise).
WITH model_a AS (
SELECT * FROM {{ ref('model_a') }}
),
model_b AS (
SELECT * FROM {{ ref('model_b') }}
)
SELECT
a.id,
a.name,
b.created_at
FROM
model_a AS a
JOIN
model_b AS b
ON a.id = b.id
4 changes: 4 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/model_a.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Create a simple table
-- Static seed-style model: two hard-coded (id, name) rows via UNION ALL.
SELECT 1 AS id, 'Alice' AS name
UNION ALL
SELECT 2 AS id, 'Bob' AS name
4 changes: 4 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/model_b.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Create another simple table
-- Static seed-style model: two hard-coded (id, created_at) rows whose ids
-- line up with model_a so combined_model's join matches both rows.
SELECT 1 AS id, '2024-12-25'::date AS created_at
UNION ALL
SELECT 2 AS id, '2024-12-26'::date AS created_at
32 changes: 32 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/schema.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
version: 2

models:
- name: model_a
description: "A simple model with user data"
tests:
- unique:
column_name: id

- name: model_b
description: "A simple model with date data"
tests:
- unique:
column_name: id

- name: combined_model
description: "Combines data from model_a and model_b"
columns:
- name: id
tests:
- not_null

- name: name
tests:
- not_null

- name: created_at
tests:
- not_null

tests:
- custom_test_combined_model: {}
12 changes: 12 additions & 0 deletions dev/dags/dbt/multiple_parents_test/profiles.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
default:
target: dev
outputs:
dev:
type: postgres
host: "{{ env_var('POSTGRES_HOST') }}"
user: "{{ env_var('POSTGRES_USER') }}"
password: "{{ env_var('POSTGRES_PASSWORD') }}"
port: "{{ env_var('POSTGRES_PORT') | int }}"
dbname: "{{ env_var('POSTGRES_DB') }}"
schema: "{{ env_var('POSTGRES_SCHEMA') }}"
threads: 4
34 changes: 34 additions & 0 deletions dev/dags/example_tests_multiple_parents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Example DAG showing how Cosmos renders a dbt project — including tests with
multiple parents — into an Airflow DAG.
"""

import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ProfileConfig, ProjectConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

# Location of the bundled dbt projects; the DBT_ROOT_PATH env var overrides it.
DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

# Map the "example_conn" Airflow Postgres connection onto the dbt profile target.
_profile_mapping = PostgresUserPasswordProfileMapping(
    conn_id="example_conn",
    profile_args={"schema": "public"},
    disable_event_tracking=True,
)

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=_profile_mapping,
)

# Render the multiple_parents_test dbt project as a DAG.
example_multiple_parents_test = DbtDag(
    project_config=ProjectConfig(
        DBT_ROOT_PATH / "multiple_parents_test",
    ),
    profile_config=profile_config,
    start_date=datetime(2023, 1, 1),
    dag_id="example_multiple_parents_test",
)
2 changes: 1 addition & 1 deletion tests/dbt/parser/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_LegacyDbtProject__handle_config_file():

dbt_project._handle_config_file(SAMPLE_YML_PATH)

assert len(dbt_project.tests) == 12
assert len(dbt_project.tests) == 10
assert "not_null_customer_id_customers" in dbt_project.tests
sample_test = dbt_project.tests["not_null_customer_id_customers"]
assert sample_test.type == DbtModelType.DBT_TEST
Expand Down
4 changes: 2 additions & 2 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,9 +1609,9 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir
hash_dir, hash_args = version.split(",")
assert hash_args == "d41d8cd98f00b204e9800998ecf8427e"
if sys.platform == "darwin":
assert hash_dir == "2b0b0c3d243f9bfdda0f60b56ab65836"
assert hash_dir == "fa5edac64de49909d4b8cbc4dc8abd4f"
else:
assert hash_dir == "cd0535d9a4acb972d74e49eaab85fb6f"
assert hash_dir == "9c9f712b6f6f1ace880dfc7f5f4ff051"


@pytest.mark.integration
Expand Down
Loading
Loading