Skip to content

Commit

Permalink
Support customizing how dbt nodes are converted to Airflow (#503)
Browse files Browse the repository at this point in the history
This change aims to solve the following current limitations of Cosmos 1.0:
* If you want to pass arguments to just some tasks (e.g.,
on_warning_callback to test nodes), Cosmos has to add it to the main
interface and explicitly pass it down to just those tasks
* If a user wants to subclass one of the Cosmos operators and use that
instead, they can't
* If a user wants more granular customization over how each task is
rendered, they can't use Cosmos

It does this by introducing the parameter `node_converters` to
`RenderConfig`, which allows users to define a custom function to
convert a DbtNode into nothing or an Airflow resource (Operator or
TaskGroup instances).

## Example of the feature

```
import os
from datetime import datetime
from pathlib import Path

from airflow.operators.dummy import DummyOperator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos import DbtDag, ProjectConfig, ProfileConfig, RenderConfig
from cosmos.constants import DbtResourceType
from cosmos.dbt.graph import DbtNode

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

os.environ["DBT_SQLITE_PATH"] = str(DEFAULT_DBT_ROOT_PATH / "simple")


profile_config = ProfileConfig(
    profile_name="simple",
    target_name="dev",
    profiles_yml_filepath=(DBT_ROOT_PATH / "simple/profiles.yml"),
)


# [START custom_dbt_nodes]
def convert_source(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
    return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source")


def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
    return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_exposure")


render_config = RenderConfig(
    node_converters={DbtResourceType.SOURCE: convert_source, DbtResourceType("exposure"): convert_exposure}
)


example_cosmos_sources = DbtDag(
    # dbt/cosmos-specific parameters
    project_config=ProjectConfig(
        DBT_ROOT_PATH / "simple",
    ),
    profile_config=profile_config,
    render_config=render_config,
    operator_args={"append_env": True},
    # normal dag parameters
    schedule_interval="@daily",
    start_date=datetime(2023, 1, 1),
    catchup=False,
    dag_id="example_cosmos_sources",
)
# [END custom_dbt_nodes]
```

It is now rendered in Airflow:
<img width="1173" alt="Screenshot 2023-10-10 at 12 04 31"
src="https://github.com/astronomer/astronomer-cosmos/assets/272048/96f96261-4a8e-4418-ae67-63404fc59f77">

Before this change, there was no way for users to describe how to render
source or other unsupported dbt resources.

## Related Issue(s)

Closes: #427
Closes: #477
  • Loading branch information
tatiana authored Oct 13, 2023
1 parent 9460312 commit e80abd7
Show file tree
Hide file tree
Showing 34 changed files with 7,611 additions and 12,853 deletions.
62 changes: 62 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,68 @@ jobs:
DATABRICKS_WAREHOUSE_ID: ${{ secrets.DATABRICKS_WAREHOUSE_ID }}
DATABRICKS_CLUSTER_ID: ${{ secrets.DATABRICKS_CLUSTER_ID }}

Run-Integration-Tests-Sqlite:
needs: Authorize
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
airflow-version: ["2.7"]

steps:
- uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha || github.ref }}
- uses: actions/cache@v3
with:
path: |
~/.cache/pip
.nox
key: integration-sqlite-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.airflow-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('cosmos/__init__.py') }}

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install packages and dependencies
run: |
python -m pip install hatch
hatch -e tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }} run pip freeze
- name: Test Cosmos against Airflow ${{ matrix.airflow-version }} and Python ${{ matrix.python-version }}
run: |
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test-integration-sqlite-setup
hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test-integration-sqlite
env:
AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/
AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres
PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH
AIRFLOW_CONN_DATABRICKS_DEFAULT: ${{ secrets.AIRFLOW_CONN_DATABRICKS_DEFAULT }}
DATABRICKS_HOST: ${{ secrets.DATABRICKS_HOST }}
DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }}
DATABRICKS_WAREHOUSE_ID: ${{ secrets.DATABRICKS_WAREHOUSE_ID }}
DATABRICKS_CLUSTER_ID: ${{ secrets.DATABRICKS_CLUSTER_ID }}
COSMOS_CONN_POSTGRES_PASSWORD: ${{ secrets.COSMOS_CONN_POSTGRES_PASSWORD }}
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
POSTGRES_SCHEMA: public
POSTGRES_PORT: 5432

- name: Upload coverage to Github
uses: actions/upload-artifact@v2
with:
name: coverage-integration-sqlite-test-${{ matrix.python-version }}-${{ matrix.airflow-version }}
path: .coverage

env:
AIRFLOW_HOME: /home/runner/work/astronomer-cosmos/astronomer-cosmos/
AIRFLOW_CONN_AIRFLOW_DB: postgres://postgres:postgres@0.0.0.0:5432/postgres
PYTHONPATH: /home/runner/work/astronomer-cosmos/astronomer-cosmos/:$PYTHONPATH


Code-Coverage:
if: github.event.action != 'labeled'
needs:
Expand Down
105 changes: 73 additions & 32 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, TESTABLE_DBT_RESOURCES, DEFAULT_DBT_RESOURCES
from cosmos.core.airflow import get_airflow_task as create_airflow_task
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger


logger = get_logger(__name__)


Expand Down Expand Up @@ -104,10 +105,9 @@ def create_task_metadata(
}
args = {**args, **{"models": node.name}}

if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class:
if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"

if use_task_group is True:
task_id = "run"
else:
Expand All @@ -122,10 +122,58 @@ def create_task_metadata(
)
return task_metadata
else:
logger.error(f"Unsupported resource type {node.resource_type} (node {node.unique_id}).")
msg = (
f"Unavailable conversion function for <{node.resource_type}> (node <{node.unique_id}>). "
"Define a converter function using render_config.node_converters."
)
logger.warning(msg)
return None


def generate_task_or_group(
dag: DAG,
task_group: TaskGroup | None,
node: DbtNode,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
test_behavior: TestBehavior,
on_warning_callback: Callable[..., Any] | None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None

use_task_group = (
node.resource_type in TESTABLE_DBT_RESOURCES
and test_behavior == TestBehavior.AFTER_EACH
and node.has_test is True
)

task_meta = create_task_metadata(
node=node, execution_mode=execution_mode, args=task_args, use_task_group=use_task_group
)

# In most cases, we'll map one DBT node to one Airflow task
# The exception are the test nodes, since it would be too slow to run test tasks individually.
# If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
if task_meta and node.resource_type != DbtResourceType.TEST:
if use_task_group:
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
test_meta = create_test_task_metadata(
"test",
execution_mode,
task_args=task_args,
model_name=node.name,
on_warning_callback=on_warning_callback,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
task >> test_task
task_or_group = model_task_group
else:
task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)
return task_or_group


def build_airflow_graph(
nodes: dict[str, DbtNode],
dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
Expand All @@ -135,6 +183,7 @@ def build_airflow_graph(
dbt_project_name: str, # DBT / Cosmos - used to name test task if mode is after_all,
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None,
) -> None:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
Expand All @@ -160,41 +209,33 @@ def build_airflow_graph(
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List.
"""
node_converters = node_converters or {}
tasks_map = {}
task_or_group: TaskGroup | BaseOperator

# In most cases, we'll map one DBT node to one Airflow task
# The exception are the test nodes, since it would be too slow to run test tasks individually.
# If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
for node_id, node in nodes.items():
use_task_group = (
node.resource_type == DbtResourceType.MODEL
and test_behavior == TestBehavior.AFTER_EACH
and node.has_test is True
)
task_meta = create_task_metadata(
node=node, execution_mode=execution_mode, args=task_args, use_task_group=use_task_group
conversion_function = node_converters.get(node.resource_type, generate_task_or_group)
if conversion_function != generate_task_or_group:
logger.warning(
"The `node_converters` attribute is an experimental feature. "
"Its syntax and behavior can be changed before a major release."
)
logger.debug(f"Converting <{node.unique_id}> using <{conversion_function.__name__}>")
task_or_group = conversion_function( # type: ignore
dag=dag,
task_group=task_group,
dbt_project_name=dbt_project_name,
execution_mode=execution_mode,
task_args=task_args,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
node=node,
)

if task_meta and node.resource_type != DbtResourceType.TEST:
if use_task_group is True:
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
test_meta = create_test_task_metadata(
"test",
execution_mode,
task_args=task_args,
model_name=node.name,
on_warning_callback=on_warning_callback,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
task >> test_task
task_or_group = model_task_group
else:
task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)
if task_or_group is not None:
logger.debug(f"Conversion of <{node.unique_id}> was successful!")
tasks_map[node_id] = task_or_group

# If test_behaviour=="after_all", there will be one test task, run "by the end" of the DAG
# If test_behaviour=="after_all", there will be one test task, run by the end of the DAG
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if test_behavior == TestBehavior.AFTER_ALL:
test_meta = create_test_task_metadata(
Expand Down
5 changes: 3 additions & 2 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator
from typing import Any, Iterator, Callable

from cosmos.constants import TestBehavior, ExecutionMode, LoadMode
from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode
from cosmos.dbt.executable import get_system_dbt
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
Expand Down Expand Up @@ -39,6 +39,7 @@ class RenderConfig:
select: list[str] = field(default_factory=list)
exclude: list[str] = field(default_factory=list)
dbt_deps: bool = True
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None


@dataclass
Expand Down
18 changes: 16 additions & 2 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
from pathlib import Path

import aenum


DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml")
DEFAULT_DBT_PROFILE_NAME = "cosmos_profile"
Expand Down Expand Up @@ -49,8 +51,7 @@ class ExecutionMode(Enum):
VIRTUALENV = "virtualenv"


# Rename to DbtResourceType
class DbtResourceType(Enum):
class DbtResourceType(aenum.Enum): # type: ignore
"""
Type of dbt node.
"""
Expand All @@ -60,3 +61,16 @@ class DbtResourceType(Enum):
SEED = "seed"
TEST = "test"
SOURCE = "source"

@classmethod
def _missing_value_(cls, value): # type: ignore
aenum.extend_enum(cls, value.upper(), value.lower())
return getattr(DbtResourceType, value.upper())


DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values()


TESTABLE_DBT_RESOURCES = {
DbtResourceType.MODEL
} # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED)
2 changes: 2 additions & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
load_mode = render_config.load_method
manifest_path = project_config.parsed_manifest_path
dbt_executable_path = execution_config.dbt_executable_path
node_converters = render_config.node_converters

profile_args = {}
if profile_config.profile_mapping:
Expand Down Expand Up @@ -168,4 +169,5 @@ def __init__(
test_behavior=test_behavior,
dbt_project_name=dbt_project.name,
on_warning_callback=on_warning_callback,
node_converters=node_converters,
)
5 changes: 3 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,13 @@ def load_from_dbt_manifest(self) -> None:
with open(self.project.manifest_path) as fp: # type: ignore[arg-type]
manifest = json.load(fp)

for unique_id, node_dict in manifest.get("nodes", {}).items():
resources = {**manifest.get("nodes", {}), **manifest.get("sources", {}), **manifest.get("exposures", {})}
for unique_id, node_dict in resources.items():
node = DbtNode(
name=node_dict.get("alias", node_dict["name"]),
unique_id=unique_id,
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict["depends_on"].get("nodes", []),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dir / node_dict["original_file_path"],
tags=node_dict["tags"],
config=node_dict["config"],
Expand Down
4 changes: 4 additions & 0 deletions dev/dags/dbt/jaffle_shop/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

target/
dbt_packages/
logs/
Loading

0 comments on commit e80abd7

Please sign in to comment.