
Commit

Merge branch 'main' into monkeypatch-bq-adapter
pankajkoti authored Jan 27, 2025
2 parents 379d997 + bdc8746 commit 7a85d27
Showing 13 changed files with 390 additions and 88 deletions.
Empty file.
65 changes: 65 additions & 0 deletions cosmos/operators/_asynchronous/base.py
@@ -0,0 +1,65 @@
import importlib
import logging
from abc import ABCMeta
from typing import Any, Sequence

from airflow.utils.context import Context

from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
from cosmos.operators.local import DbtRunLocalOperator

log = logging.getLogger(__name__)


def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
    """
    Dynamically constructs and returns an asynchronous operator class for the given profile type and dbt class name.

    The function builds a class path string for an asynchronous operator based on the provided ``profile_type`` and
    ``dbt_class``, then attempts to import that class dynamically and return it. If the class cannot be found, it
    falls back to returning the ``DbtRunLocalOperator`` class.

    :param profile_type: The dbt profile type.
    :param dbt_class: The dbt class name. Example: ``DbtRun``, ``DbtTest``.
    """
    execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
    class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
    try:
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        operator_class = getattr(module, class_name)
        return operator_class
    except (ModuleNotFoundError, AttributeError):
        log.info("Failed to load class %s; falling back to DbtRunLocalOperator", class_path)
        return DbtRunLocalOperator
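
For illustration, the resolution above can be traced for a BigQuery profile. Assuming ExecutionMode.AIRFLOW_ASYNC.value is "airflow_async" (so _snake_case_to_camelcase yields "AirflowAsync"), the constructed path resolves to the BigQuery operator added later in this commit:

    # Sketch of the class-path construction, under the assumption stated above:
    #   profile_type = "bigquery", dbt_class = "DbtRun"
    #   class_path = "cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator"
    operator_class = _create_async_operator_class("bigquery", "DbtRun")
    assert operator_class.__name__ == "DbtRunAirflowAsyncBigqueryOperator"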


class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta):  # type: ignore[misc]

    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",)  # type: ignore[operator]

    def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
        self.project_dir = project_dir
        self.profile_config = profile_config

        async_operator_class = self.create_async_operator()

        # Dynamically modify the base classes.
        # This is necessary because the async operator class is only known at runtime.
        # When using composition instead of inheritance to initialize the async class and run its execute method,
        # Airflow throws a `DuplicateTaskIdFound` error.
        DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
        super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs)

    def create_async_operator(self) -> Any:
        profile_type = self.profile_config.get_profile_type()
        return _create_async_operator_class(profile_type, "DbtRun")

    def execute(self, context: Context) -> None:
        super().execute(context)
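
Since the runtime base-class swap is the non-obvious part of this factory, here is a minimal, self-contained sketch of the same pattern with hypothetical class names (not Cosmos code). It shows how reassigning __bases__ before delegating makes super().__init__() resolve against the class selected at runtime:

    class DefaultImpl:
        def __init__(self) -> None:
            print("DefaultImpl.__init__")


    class FastImpl:
        def __init__(self) -> None:
            print("FastImpl.__init__")


    class Factory(DefaultImpl):
        def __init__(self, use_fast: bool) -> None:
            # Rebind the base class before delegating; super() then follows the new MRO.
            Factory.__bases__ = (FastImpl if use_fast else DefaultImpl,)
            super().__init__()


    Factory(use_fast=True)   # prints "FastImpl.__init__"
    Factory(use_fast=False)  # prints "DefaultImpl.__init__"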
108 changes: 108 additions & 0 deletions cosmos/operators/_asynchronous/bigquery.py
@@ -0,0 +1,108 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator):  # type: ignore[misc]

    template_fields: Sequence[str] = (
        "full_refresh",
        "gcp_project",
        "dataset",
        "location",
    )

    def __init__(
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        extra_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.project_dir = project_dir
        self.profile_config = profile_config
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile  # type: ignore
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]
        self.extra_context = extra_context or {}
        self.full_refresh = None
        if "full_refresh" in kwargs:
            self.full_refresh = kwargs.pop("full_refresh")
        self.configuration: dict[str, Any] = {}
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            deferrable=True,
            **kwargs,
        )

    def get_remote_sql(self) -> str:
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError("Cosmos async support is only available starting in Airflow 2.8 or later.")
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:  # pragma: no cover
            assert self.project_dir is not None

        project_dir_parent = str(Path(self.project_dir).parent)
        relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore
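
To make the path arithmetic concrete, here is a trace with hypothetical values (none of these paths appear in the commit):

    # project_dir                   = "/usr/local/airflow/dbt/jaffle_shop"
    # project_dir_parent            = "/usr/local/airflow/dbt"
    # file_path                     = "/usr/local/airflow/dbt/jaffle_shop/models/stg_orders.sql"
    # relative_file_path            = "jaffle_shop/models/stg_orders.sql"
    # remote_target_path            = "gs://my-bucket/target/"  ->  stripped to "gs://my-bucket/target"
    # dbt_dag_task_group_identifier = "my_dag"
    # remote_model_path             = "gs://my-bucket/target/my_dag/compiled/jaffle_shop/models/stg_orders.sql"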

    def drop_table_sql(self) -> None:
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:
        if not self.full_refresh:
            raise CosmosValueError("Cosmos async execution is only supported with full_refresh=True.")
        # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before
        # populating it:
        # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
        # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
        # We're emulating this behaviour here. The compiled SQL still has several limitations, which are addressed
        # in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474
        self.drop_table_sql()
        sql = self.get_remote_sql()
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        # Prefix an explicit CREATE TABLE statement so the compiled SELECT materializes the model.
        sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        return super().execute(context)
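
As a concrete illustration of the two statements this method ends up issuing, assume a hypothetical model stg_orders in project my_project and dataset my_dataset whose compiled SQL is "SELECT * FROM raw.orders" (illustrative values, not real Cosmos output):

    gcp_project, dataset, model_name = "my_project", "my_dataset", "stg_orders"
    compiled_sql = "SELECT * FROM raw.orders"

    # drop_table_sql() first submits:
    #   DROP TABLE IF EXISTS my_project.my_dataset.stg_orders;
    create_sql = f"CREATE TABLE {gcp_project}.{dataset}.{model_name} AS {compiled_sql}"
    configuration = {"query": {"query": create_sql, "useLegacySql": False}}
    # execute() then hands this configuration to BigQueryInsertJobOperator in deferrable mode.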
14 changes: 14 additions & 0 deletions cosmos/operators/_asynchronous/databricks.py
@@ -0,0 +1,14 @@
# TODO: Implement it

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncDatabricksOperator(BaseOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        raise NotImplementedError()
87 changes: 15 additions & 72 deletions cosmos/operators/airflow_async.py
@@ -1,14 +1,9 @@
 from __future__ import annotations
 
 import inspect
 from typing import Any, Sequence
 
-from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
-from airflow.utils.context import Context
-
 from cosmos.config import ProfileConfig
-from cosmos.constants import BIGQUERY_PROFILE_TYPE
-from cosmos.exceptions import CosmosValueError
+from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator
 from cosmos.operators.base import AbstractDbtBaseOperator
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
@@ -33,8 +28,8 @@
 
 class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
     def __init__(self, **kwargs) -> None:  # type: ignore
-        self.location = kwargs.pop("location")
-        self.configuration = kwargs.pop("configuration", {})
+        if "location" in kwargs:
+            kwargs.pop("location")
         super().__init__(**kwargs)
 
 
@@ -58,84 +53,32 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator):
     pass
 
 
-class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator, DbtRunLocalOperator):  # type: ignore
-
-    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + (  # type: ignore[operator]
-        "full_refresh",
-        "project_dir",
-        "gcp_project",
-        "dataset",
-        "location",
-    )
+class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator):  # type: ignore
 
     def __init__(  # type: ignore
         self,
         project_dir: str,
         profile_config: ProfileConfig,
-        location: str,  # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
-        full_refresh: bool = False,
         extra_context: dict[str, object] | None = None,
-        configuration: dict[str, object] | None = None,
         **kwargs,
     ) -> None:
-        # dbt task param
-        self.project_dir = project_dir
-        self.full_refresh = full_refresh
-        self.profile_config = profile_config
-        if not self.profile_config or not self.profile_config.profile_mapping:
-            raise CosmosValueError(f"Cosmos async support is only available when using ProfileMapping")
-
-        self.profile_type: str = profile_config.get_profile_type()  # type: ignore
-        if self.profile_type not in _SUPPORTED_DATABASES:
-            raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")
-
-        # airflow task param
-        self.location = location
-        self.configuration = configuration or {}
-        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
-        profile = self.profile_config.profile_mapping.profile
-        self.gcp_project = profile["project"]
-        self.dataset = profile["dataset"]
-
-        # Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.
+        # Cosmos attempts to pass many kwargs that the async operator simply does not accept.
         # We need to pop them.
-        async_op_kwargs = {}
-        cosmos_op_kwargs = {}
+        clean_kwargs = {}
         non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
        non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())
 
         non_async_args -= {"task_id"}
         for arg_key, arg_value in kwargs.items():
-            if arg_key == "task_id":
-                async_op_kwargs[arg_key] = arg_value
-                cosmos_op_kwargs[arg_key] = arg_value
-            elif arg_key not in non_async_args:
-                async_op_kwargs[arg_key] = arg_value
-            else:
-                cosmos_op_kwargs[arg_key] = arg_value
-
-        # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
-        BigQueryInsertJobOperator.__init__(
-            self,
-            gcp_conn_id=self.gcp_conn_id,
-            configuration=self.configuration,
-            location=self.location,
-            deferrable=True,
-            **async_op_kwargs,
-        )
+            if arg_key not in non_async_args:
+                clean_kwargs[arg_key] = arg_value
 
-        DbtRunLocalOperator.__init__(
-            self,
-            project_dir=self.project_dir,
-            profile_config=self.profile_config,
-            **cosmos_op_kwargs,
-        )
-        self.async_context = extra_context or {}
-        self.async_context["profile_type"] = self.profile_type
-        self.async_context["async_operator"] = BigQueryInsertJobOperator
-
-    def execute(self, context: Context) -> Any | None:
-        return self.build_and_run_cmd(context, run_as_async=True, async_context=self.async_context)
+        super().__init__(
+            project_dir=project_dir,
+            profile_config=profile_config,
+            extra_context=extra_context,
+            **clean_kwargs,
+        )
 
 
 class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator):  # type: ignore
     pass
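
The signature-based kwargs filtering in the new __init__ above can be sketched in isolation; the class and parameter names below are illustrative, not Cosmos code:

    import inspect


    class _Base:
        def __init__(self, task_id: str, project_dir: str, dbt_cmd_flags: list | None = None):
            ...


    # Parameters accepted by the dbt-side base class should not be forwarded,
    # except task_id, which the async operator still needs.
    non_async_args = set(inspect.signature(_Base.__init__).parameters.keys())
    non_async_args -= {"task_id"}

    kwargs = {"task_id": "run_model", "project_dir": "/tmp/proj", "gcp_conn_id": "google"}
    clean_kwargs = {k: v for k, v in kwargs.items() if k not in non_async_args}
    print(clean_kwargs)  # {'task_id': 'run_model', 'gcp_conn_id': 'google'}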
6 changes: 5 additions & 1 deletion cosmos/operators/local.py
@@ -46,7 +46,11 @@
 if TYPE_CHECKING:
     from airflow.datasets import Dataset  # noqa: F811
     from dbt.cli.main import dbtRunner, dbtRunnerResult
-    from openlineage.client.run import RunEvent
+
+    try:  # pragma: no cover
+        from openlineage.client.event_v2 import RunEvent  # pragma: no cover
+    except ImportError:  # pragma: no cover
+        from openlineage.client.run import RunEvent  # pragma: no cover
 
 
 from sqlalchemy.orm import Session
1 change: 0 additions & 1 deletion tests/airflow/test_graph.py
@@ -312,7 +312,6 @@ def test_build_airflow_graph_with_dbt_compile_task():
         "project_dir": SAMPLE_PROJ_PATH,
         "conn_id": "fake_conn",
         "profile_config": bigquery_profile_config,
-        "location": "",
     }
     render_config = RenderConfig(
         select=["tag:some"],
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,24 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn():  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "key_path": "my_key_path.json",
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn
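
A hypothetical test using this fixture might look as follows (the test name and assertions are illustrative, not part of the commit):

    def test_get_connection_returns_mocked_bigquery_conn(mock_bigquery_conn):
        from airflow.hooks.base import BaseHook

        conn = BaseHook.get_connection("my_bigquery_connection")
        assert conn.conn_type == "google_cloud_platform"
        assert conn.extra_dejson["project"] == "my_project"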
Empty file.
