diff --git a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py index 0b8aace12e442..68c8dbb5b5f54 100644 --- a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py +++ b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py @@ -16,13 +16,36 @@ # specific language governing permissions and limitations # under the License. import ast +from datetime import datetime from typing import Any, Dict, Optional, Union from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance from airflow.providers.amazon.aws.hooks.emr import EmrHook +class EmrClusterLink(BaseOperatorLink): + """Operator link for EmrCreateJobFlowOperator. It allows users to access the EMR Cluster""" + + name = 'EMR Cluster' + + def get_link(self, operator: BaseOperator, dttm: datetime) -> str: + """ + Get link to EMR cluster. + + :param operator: operator + :param dttm: datetime + :return: url link + """ + ti = TaskInstance(task=operator, execution_date=dttm) + flow_id = ti.xcom_pull(task_ids=operator.task_id) + return ( + f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}' + if flow_id + else '' + ) + + class EmrCreateJobFlowOperator(BaseOperator): """ Creates an EMR JobFlow, reading the config from the EMR connection. @@ -44,6 +67,7 @@ class EmrCreateJobFlowOperator(BaseOperator): template_ext = ('.json',) template_fields_renderers = {"job_flow_overrides": "json"} ui_color = '#f9c915' + operator_extra_links = (EmrClusterLink(),) def __init__( self, diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 839765f6199a8..c3eb4fb5d5a5c 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -443,6 +443,9 @@ hook-class-names: # deprecated - to be removed after providers add dependency o - airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook - airflow.providers.amazon.aws.hooks.emr.EmrHook +extra-links: + - airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink + connection-types: - hook-class-name: airflow.providers.amazon.aws.hooks.s3.S3Hook connection-type: s3 diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh index bd5b060f6e010..1e80dd3ded6a7 100755 --- a/scripts/in_container/run_install_and_test_provider_packages.sh +++ b/scripts/in_container/run_install_and_test_provider_packages.sh @@ -130,7 +130,7 @@ function discover_all_extra_links() { local actual_number_of_extra_links actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs) - if (( actual_number_of_extra_links < 6 )); then + if (( actual_number_of_extra_links < 7 )); then echo echo "${COLOR_RED}ERROR: Number of links registered is wrong: ${actual_number_of_extra_links} ${COLOR_RESET}" echo diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py index 07cb2af5d6444..17356f1456e9a 100644 --- a/tests/core/test_providers_manager.py +++ b/tests/core/test_providers_manager.py @@ -164,7 +164,7 @@ def test_field_behaviours(self): def test_extra_links(self): provider_manager = ProvidersManager() extra_link_class_names = list(provider_manager.extra_links_class_names) - assert len(extra_link_class_names) > 5 + assert len(extra_link_class_names) > 6 def test_logging(self): provider_manager = ProvidersManager() diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py index 73b6f22a8bbac..11a963c93e748 100644 --- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py +++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py @@ -22,13 +22,23 @@ from datetime import timedelta from unittest.mock import MagicMock, patch +import pytest from jinja2 import StrictUndefined from airflow.models import DAG, DagRun, TaskInstance -from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator +from airflow.models.xcom import XCOM_RETURN_KEY +from airflow.providers.amazon.aws.operators.emr_create_job_flow import ( + EmrClusterLink, + EmrCreateJobFlowOperator, +) +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone from tests.test_utils import AIRFLOW_MAIN_FOLDER +TASK_ID = 'test_task' + +TEST_DAG_ID = 'test_dag_id' + DEFAULT_DATE = timezone.datetime(2017, 1, 1) RUN_JOB_FLOW_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'JobFlowId': 'j-8989898989'} @@ -61,12 +71,12 @@ def setUp(self): # Mock out the emr_client (moto has incorrect response) self.emr_client_mock = MagicMock() self.operator = EmrCreateJobFlowOperator( - task_id='test_task', + task_id=TASK_ID, aws_conn_id='aws_default', emr_conn_id='emr_default', region_name='ap-southeast-2', dag=DAG( - 'test_dag_id', + TEST_DAG_ID, default_args=args, template_searchpath=TEMPLATE_SEARCHPATH, template_undefined=StrictUndefined, @@ -155,3 +165,41 @@ def test_execute_returns_job_id(self): with patch('boto3.session.Session', boto3_session_mock): assert self.operator.execute(None) == 'j-8989898989' + + +@pytest.mark.need_serialized_dag +def test_operator_extra_links(dag_maker, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + EmrCreateJobFlowOperator, dag_id=TEST_DAG_ID, execution_date=DEFAULT_DATE, task_id=TASK_ID + ) + + serialized_dag = dag_maker.get_serialized_data() + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [ + {"airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink": {}} + ], "Operator links should exist for serialized DAG" + + assert isinstance( + deserialized_task.operator_extra_links[0], EmrClusterLink + ), "Operator link type should be preserved during deserialization" + + assert ( + ti.task.get_extra_links(DEFAULT_DATE, EmrClusterLink.name) == "" + ), "Operator link should only be added if job id is available in XCom" + + assert ( + deserialized_task.get_extra_links(DEFAULT_DATE, EmrClusterLink.name) == "" + ), "Operator link should be empty for deserialized task with no XCom push" + + ti.xcom_push(key=XCOM_RETURN_KEY, value='j-SomeClusterId') + + expected = "https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:j-SomeClusterId" + assert ( + deserialized_task.get_extra_links(DEFAULT_DATE, EmrClusterLink.name) == expected + ), "Operator link should be preserved in deserialized tasks after execution" + + assert ( + ti.task.get_extra_links(DEFAULT_DATE, EmrClusterLink.name) == expected + ), "Operator link should be preserved after execution"