# Support for running dbt tasks in AWS EKS (astronomer#944)

## Description

We use MWAA in combination with EKS, so all of our Airflow DAGs run inside our EKS cluster. We would like to use the same setup with Cosmos.

### What changes?

- New AwsEksOperator classes (inheriting from the KubernetesOperators)
- Based on the original [EksOperator](https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/operators/eks.py#L995)
- Tests
- Adjusted documentation

## Related Issue(s)

-

## Breaking Change?

No, only an additional feature.

## Checklist

- [x] I have made corresponding changes to the documentation (if required)
- [x] I have added tests that prove my fix is effective or that my feature works

---------

Co-authored-by: Pankaj Koti <pankajkoti699@gmail.com>
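Not part of the commit itself, but to illustrate the description above, here is a minimal sketch of how the new operators might be declared in a DAG. The cluster name, image, connection IDs, region, and project path are placeholder values borrowed from the tests in this commit, and the DAG arguments assume Airflow 2.4+ (`schedule` rather than `schedule_interval`); a real MWAA/EKS setup will differ.

```python
# Hypothetical usage sketch: run dbt seed and dbt run as pods in an EKS cluster.
from datetime import datetime

from airflow import DAG

from cosmos.operators.aws_eks import DbtRunAwsEksOperator, DbtSeedAwsEksOperator

with DAG(dag_id="dbt_on_eks_example", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    seed = DbtSeedAwsEksOperator(
        task_id="dbt_seed",
        cluster_name="my-cluster",           # EKS cluster the pod runs in (placeholder)
        aws_conn_id="aws_default",           # AWS connection used by EksHook
        region="eu-central-1",               # optional; otherwise taken from the connection
        namespace="default",                 # Kubernetes namespace for the pod
        image="my-dbt-image:latest",         # image containing dbt and the project
        project_dir="my/dir",                # dbt project path inside the image
        conn_id="my_airflow_connection",     # connection used by Cosmos for the dbt profile
    )

    run = DbtRunAwsEksOperator(
        task_id="dbt_run",
        cluster_name="my-cluster",
        aws_conn_id="aws_default",
        namespace="default",
        image="my-dbt-image:latest",
        project_dir="my/dir",
        conn_id="my_airflow_connection",
    )

    seed >> run
```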
1 parent: 007325a. Commit: cb2a27a. Showing 6 changed files with 273 additions and 2 deletions.
`cosmos/operators/aws_eks.py` (new file, @@ -0,0 +1,131 @@):

```python
from __future__ import annotations

from typing import Any, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.utils.context import Context

from cosmos.operators.kubernetes import (
    DbtBuildKubernetesOperator,
    DbtKubernetesBaseOperator,
    DbtLSKubernetesOperator,
    DbtRunKubernetesOperator,
    DbtRunOperationKubernetesOperator,
    DbtSeedKubernetesOperator,
    DbtSnapshotKubernetesOperator,
    DbtTestKubernetesOperator,
)

DEFAULT_CONN_ID = "aws_default"
DEFAULT_NAMESPACE = "default"


class DbtAwsEksBaseOperator(DbtKubernetesBaseOperator):
    template_fields: Sequence[str] = tuple(
        {
            "cluster_name",
            "in_cluster",
            "namespace",
            "pod_name",
            "aws_conn_id",
            "region",
        }
        | set(DbtKubernetesBaseOperator.template_fields)
    )

    def __init__(
        self,
        cluster_name: str,
        pod_name: str | None = None,
        namespace: str | None = DEFAULT_NAMESPACE,
        aws_conn_id: str = DEFAULT_CONN_ID,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        self.cluster_name = cluster_name
        self.pod_name = pod_name
        self.namespace = namespace
        self.aws_conn_id = aws_conn_id
        self.region = region
        super().__init__(
            name=self.pod_name,
            namespace=self.namespace,
            **kwargs,
        )
        # There is no need to manage the kube_config file, as it will be generated automatically.
        # All Kubernetes parameters (except config_file) are also valid for the EksPodOperator.
        if self.config_file:
            raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")

    def execute(self, context: Context) -> Any | None:  # type: ignore
        eks_hook = EksHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region,
        )
        with eks_hook.generate_config_file(
            eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
        ) as self.config_file:
            return super().execute(context)


class DbtBuildAwsEksOperator(DbtAwsEksBaseOperator, DbtBuildKubernetesOperator):
    """
    Executes a dbt core build command.
    """

    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtBuildKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtLSAwsEksOperator(DbtAwsEksBaseOperator, DbtLSKubernetesOperator):
    """
    Executes a dbt core ls command.
    """


class DbtSeedAwsEksOperator(DbtAwsEksBaseOperator, DbtSeedKubernetesOperator):
    """
    Executes a dbt core seed command.
    """

    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtSeedKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtSnapshotAwsEksOperator(DbtAwsEksBaseOperator, DbtSnapshotKubernetesOperator):
    """
    Executes a dbt core snapshot command.
    """


class DbtRunAwsEksOperator(DbtAwsEksBaseOperator, DbtRunKubernetesOperator):
    """
    Executes a dbt core run command.
    """

    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtRunKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtTestAwsEksOperator(DbtAwsEksBaseOperator, DbtTestKubernetesOperator):
    """
    Executes a dbt core test command.
    """

    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtTestKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtRunOperationAwsEksOperator(DbtAwsEksBaseOperator, DbtRunOperationKubernetesOperator):
    """
    Executes a dbt core run-operation command.
    """

    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtRunOperationKubernetesOperator.template_fields  # type: ignore[operator]
    )
```
Tests for the new operators (new file, @@ -0,0 +1,97 @@):

```python
from unittest.mock import MagicMock, patch

import pytest
from airflow.exceptions import AirflowException

from cosmos.operators.aws_eks import (
    DbtBuildAwsEksOperator,
    DbtLSAwsEksOperator,
    DbtRunAwsEksOperator,
    DbtSeedAwsEksOperator,
    DbtTestAwsEksOperator,
)


@pytest.fixture()
def mock_kubernetes_execute():
    with patch("cosmos.operators.kubernetes.KubernetesPodOperator.execute") as mock_execute:
        yield mock_execute


base_kwargs = {
    "conn_id": "my_airflow_connection",
    "cluster_name": "my-cluster",
    "task_id": "my-task",
    "image": "my_image",
    "project_dir": "my/dir",
    "vars": {
        "start_time": "{{ data_interval_start.strftime('%Y%m%d%H%M%S') }}",
        "end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}",
    },
    "no_version_check": True,
}


def test_dbt_kubernetes_build_command():
    """
    Since we know that the KubernetesOperator is tested, we can just test that the
    command is built correctly and added to the "arguments" parameter.
    """

    result_map = {
        "ls": DbtLSAwsEksOperator(**base_kwargs),
        "run": DbtRunAwsEksOperator(**base_kwargs),
        "test": DbtTestAwsEksOperator(**base_kwargs),
        "build": DbtBuildAwsEksOperator(**base_kwargs),
        "seed": DbtSeedAwsEksOperator(**base_kwargs),
    }

    for command_name, command_operator in result_map.items():
        command_operator.build_kube_args(context=MagicMock(), cmd_flags=MagicMock())
        assert command_operator.arguments == [
            "dbt",
            command_name,
            "--vars",
            "end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n"
            "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n",
            "--no-version-check",
            "--project-dir",
            "my/dir",
        ]


@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_kube_args")
@patch("cosmos.operators.aws_eks.EksHook.generate_config_file")
def test_dbt_kubernetes_operator_execute(mock_generate_config_file, mock_build_kube_args, mock_kubernetes_execute):
    """Tests that the execute method call results in both the build_kube_args method and the kubernetes execute method being called."""
    operator = DbtLSAwsEksOperator(
        conn_id="my_airflow_connection",
        cluster_name="my-cluster",
        task_id="my-task",
        image="my_image",
        project_dir="my/dir",
    )
    operator.execute(context={})

    # Assert that the build_kube_args method was called in the execution
    mock_build_kube_args.assert_called_once()

    # Assert that the generate_config_file method was called in the execution to create the kubeconfig for EKS
    mock_generate_config_file.assert_called_once_with(eks_cluster_name="my-cluster", pod_namespace="default")

    # Assert that the kubernetes execute method was called in the execution
    mock_kubernetes_execute.assert_called_once()
    assert mock_kubernetes_execute.call_args.args[-1] == {}


def test_provided_config_file_fails():
    """Tests that the constructor fails if it is called with a config_file."""
    with pytest.raises(AirflowException) as err_context:
        DbtLSAwsEksOperator(
            conn_id="my_airflow_connection",
            cluster_name="my-cluster",
            task_id="my-task",
            image="my_image",
            project_dir="my/dir",
            config_file="my/config",
        )
    assert "The config_file is not an allowed parameter for the EksPodOperator." in str(err_context.value)
```