Merge branch 'main' into monkeypatch-bq-adapter
Showing 13 changed files with 390 additions and 88 deletions.
Empty file.
@@ -0,0 +1,65 @@
import importlib
import logging
from abc import ABCMeta
from typing import Any, Sequence

from airflow.utils.context import Context

from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
from cosmos.operators.local import DbtRunLocalOperator

log = logging.getLogger(__name__)

def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
    """
    Dynamically construct and return an asynchronous operator class for the given profile type and dbt class name.

    The function builds a class path string for an asynchronous operator based on the provided `profile_type` and
    `dbt_class`, then attempts to import the corresponding class dynamically and return it. If the class cannot be
    found, it falls back to returning the `DbtRunLocalOperator` class.

    :param profile_type: The dbt profile type.
    :param dbt_class: The dbt class name, e.g. DbtRun or DbtTest.
    """
    execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
    class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
    try:
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        operator_class = getattr(module, class_name)
        return operator_class
    except (ModuleNotFoundError, AttributeError):
        log.info("Error in loading class: %s. Falling back to DbtRunLocalOperator", class_path)
        return DbtRunLocalOperator
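
# Illustrative note (not part of the module): assuming ExecutionMode.AIRFLOW_ASYNC.value is
# "airflow_async", _snake_case_to_camelcase turns it into "AirflowAsync", so the lookup above
# is expected to resolve as follows:
#
#   _create_async_operator_class("bigquery", "DbtRun")
#       -> cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator
#   _create_async_operator_class("databricks", "DbtRun")
#       -> cosmos.operators._asynchronous.databricks.DbtRunAirflowAsyncDatabricksOperator
#   _create_async_operator_class("postgres", "DbtRun")
#       -> DbtRunLocalOperator (no async module exists for the profile type, so the fallback is used)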

class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta):  # type: ignore[misc]

    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",)  # type: ignore[operator]

    def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
        self.project_dir = project_dir
        self.profile_config = profile_config

        async_operator_class = self.create_async_operator()

        # Dynamically modify the base classes.
        # This is necessary because the async operator class is only known at runtime.
        # When using composition instead of inheritance to initialize the async class and run its execute method,
        # Airflow throws a `DuplicateTaskIdFound` error.
        DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
        super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs)

    def create_async_operator(self) -> Any:
        profile_type = self.profile_config.get_profile_type()
        async_class_operator = _create_async_operator_class(profile_type, "DbtRun")
        return async_class_operator

    def execute(self, context: Context) -> None:
        super().execute(context)
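
The comment in `__init__` above is the heart of the factory: the concrete async operator class is only known at run time, so the operator rewires its own `__bases__` instead of wrapping the async operator, which would otherwise make Airflow raise `DuplicateTaskIdFound`. A minimal, self-contained sketch of that base-class swap, using toy classes rather than the Cosmos/Airflow API:

class BaseRun:  # stand-in for a common ancestor such as Airflow's BaseOperator
    def execute(self) -> str:
        raise NotImplementedError


class LocalRun(BaseRun):
    def execute(self) -> str:
        return "ran dbt locally"


class AsyncBigqueryRun(BaseRun):
    def execute(self) -> str:
        return "submitted a deferrable BigQuery job"


class FactoryRun(LocalRun):
    def __init__(self, profile_type: str) -> None:
        # Pick the concrete parent only when the instance is created ...
        async_class = AsyncBigqueryRun if profile_type == "bigquery" else LocalRun
        # ... and rewire the MRO so super() dispatches to it from now on.
        FactoryRun.__bases__ = (async_class,)
        super().__init__()


print(FactoryRun("bigquery").execute())  # -> submitted a deferrable BigQuery job
print(FactoryRun("postgres").execute())  # -> ran dbt locally

Note that mutating `__bases__` changes the class globally rather than per instance, which is worth keeping in mind if a deployment mixes profile types.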
@@ -0,0 +1,108 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id

class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator):  # type: ignore[misc]

    template_fields: Sequence[str] = (
        "full_refresh",
        "gcp_project",
        "dataset",
        "location",
    )

    def __init__(
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        extra_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.project_dir = project_dir
        self.profile_config = profile_config
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile  # type: ignore
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]
        self.extra_context = extra_context or {}
        self.full_refresh = None
        if "full_refresh" in kwargs:
            self.full_refresh = kwargs.pop("full_refresh")
        self.configuration: dict[str, Any] = {}
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            deferrable=True,
            **kwargs,
        )

    def get_remote_sql(self) -> str:
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError("Cosmos async support is only available starting in Airflow 2.8 or later.")
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:  # pragma: no cover
            assert self.project_dir is not None

        project_dir_parent = str(Path(self.project_dir).parent)
        relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore

    def drop_table_sql(self) -> None:
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:
        if not self.full_refresh:
            raise CosmosValueError("Async execution is only supported for full_refresh models.")
        else:
            # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it:
            # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
            # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
            # We're emulating this behaviour here.
            # The compiled SQL has several limitations here, but these will be addressed in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474.
            self.drop_table_sql()
            sql = self.get_remote_sql()
            model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
            # Prefix an explicit CREATE TABLE statement so the compiled SELECT materializes the model.
            sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
            self.configuration = {
                "query": {
                    "query": sql,
                    "useLegacySql": False,
                }
            }
            return super().execute(context)
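
To make the control flow above concrete, here is an illustrative walk-through with made-up paths and identifiers (none of these values come from the commit): the compiled SQL is read from the remote target path, the existing table is dropped, and a `CREATE TABLE ... AS` statement is submitted as a deferrable BigQuery job.

from pathlib import Path

# Hypothetical inputs for the example only.
project_dir = "/usr/local/airflow/dbt/jaffle_shop"
file_path = "/usr/local/airflow/dbt/jaffle_shop/models/staging/stg_orders.sql"
remote_target_path_str = "gs://my-bucket/target"  # str(remote_target_path).rstrip("/")
dag_task_group_id = "jaffle_shop_dag"

# get_remote_sql(): where the compiled SELECT is fetched from.
relative_file_path = str(file_path).replace(str(Path(project_dir).parent), "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dag_task_group_id}/compiled/{relative_file_path}"
print(remote_model_path)
# gs://my-bucket/target/jaffle_shop_dag/compiled/jaffle_shop/models/staging/stg_orders.sql

# execute() with full_refresh: two statements end up running against BigQuery.
drop_sql = "DROP TABLE IF EXISTS my_project.my_dataset.stg_orders;"
create_sql = "CREATE TABLE my_project.my_dataset.stg_orders AS <compiled SELECT read from the path above>"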
@@ -0,0 +1,14 @@
# TODO: Implement it

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncDatabricksOperator(BaseOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        raise NotImplementedError()
@@ -0,0 +1,24 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn():  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "key_path": "my_key_path.json",
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn
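
A sketch of how a test might consume this fixture; the profile-mapping class, arguments, and assertion below are illustrative and assume Cosmos's existing BigQuery profile mapping rather than anything added in this commit.

from cosmos.config import ProfileConfig
from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping


def test_profile_type_resolves_to_bigquery(mock_bigquery_conn):
    # The patched BaseHook.get_connection makes the mapping resolve to the mocked connection.
    profile_config = ProfileConfig(
        profile_name="default",
        target_name="dev",
        profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
            conn_id=mock_bigquery_conn.conn_id,
            profile_args={"dataset": "my_dataset"},
        ),
    )
    assert profile_config.get_profile_type() == "bigquery"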
Empty file.