Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add emr cluster link #18691

Merged
merged 3 commits into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion airflow/providers/amazon/aws/operators/emr_create_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,36 @@
# specific language governing permissions and limitations
# under the License.
import ast
from datetime import datetime
from typing import Any, Dict, Optional, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance
from airflow.providers.amazon.aws.hooks.emr import EmrHook


class EmrClusterLink(BaseOperatorLink):
    """Operator link for EmrCreateJobFlowOperator that points at the EMR cluster in the AWS console."""

    name = 'EMR Cluster'

    def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
        """
        Build the AWS console URL for the EMR cluster associated with a task run.

        The cluster (job flow) id is read from the XCom value stored under the
        operator's own task id for the given execution date.

        :param operator: operator whose XCom value is looked up
        :param dttm: execution date used to build the task instance
        :return: console URL of the cluster, or an empty string when the XCom value is missing or falsy
        """
        task_instance = TaskInstance(task=operator, execution_date=dttm)
        job_flow_id = task_instance.xcom_pull(task_ids=operator.task_id)
        # Without a job flow id there is nothing to link to.
        if not job_flow_id:
            return ''
        return f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{job_flow_id}'


class EmrCreateJobFlowOperator(BaseOperator):
"""
Creates an EMR JobFlow, reading the config from the EMR connection.
Expand All @@ -44,6 +67,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
template_ext = ('.json',)
template_fields_renderers = {"job_flow_overrides": "json"}
ui_color = '#f9c915'
operator_extra_links = (EmrClusterLink(),)

def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ hook-class-names: # deprecated - to be removed after providers add dependency o
- airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook
- airflow.providers.amazon.aws.hooks.emr.EmrHook

extra-links:
- airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink

connection-types:
- hook-class-name: airflow.providers.amazon.aws.hooks.s3.S3Hook
connection-type: s3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ function discover_all_extra_links() {

local actual_number_of_extra_links
actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs)
if (( actual_number_of_extra_links < 6 )); then
if (( actual_number_of_extra_links < 7 )); then
echo
echo "${COLOR_RED}ERROR: Number of links registered is wrong: ${actual_number_of_extra_links} ${COLOR_RESET}"
echo
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_field_behaviours(self):
def test_extra_links(self):
provider_manager = ProvidersManager()
extra_link_class_names = list(provider_manager.extra_links_class_names)
assert len(extra_link_class_names) > 5
assert len(extra_link_class_names) > 6

def test_logging(self):
provider_manager = ProvidersManager()
Expand Down
54 changes: 51 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,23 @@
from datetime import timedelta
from unittest.mock import MagicMock, patch

import pytest
from jinja2 import StrictUndefined

from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.providers.amazon.aws.operators.emr_create_job_flow import (
EmrClusterLink,
EmrCreateJobFlowOperator,
)
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils import timezone
from tests.test_utils import AIRFLOW_MAIN_FOLDER

TASK_ID = 'test_task'

TEST_DAG_ID = 'test_dag_id'

DEFAULT_DATE = timezone.datetime(2017, 1, 1)

RUN_JOB_FLOW_SUCCESS_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'JobFlowId': 'j-8989898989'}
Expand Down Expand Up @@ -61,12 +71,12 @@ def setUp(self):
# Mock out the emr_client (moto has incorrect response)
self.emr_client_mock = MagicMock()
self.operator = EmrCreateJobFlowOperator(
task_id='test_task',
task_id=TASK_ID,
aws_conn_id='aws_default',
emr_conn_id='emr_default',
region_name='ap-southeast-2',
dag=DAG(
'test_dag_id',
TEST_DAG_ID,
default_args=args,
template_searchpath=TEMPLATE_SEARCHPATH,
template_undefined=StrictUndefined,
Expand Down Expand Up @@ -155,3 +165,41 @@ def test_execute_returns_job_id(self):

with patch('boto3.session.Session', boto3_session_mock):
assert self.operator.execute(None) == 'j-8989898989'


@pytest.mark.need_serialized_dag
def test_operator_extra_links(dag_maker, create_task_instance_of_operator):
    """Check that EmrClusterLink survives DAG serialization and resolves only after an XCom push."""
    link_name = EmrClusterLink.name

    ti = create_task_instance_of_operator(
        EmrCreateJobFlowOperator,
        dag_id=TEST_DAG_ID,
        execution_date=DEFAULT_DATE,
        task_id=TASK_ID,
    )

    # Round-trip the DAG through serialization to get the deserialized task.
    serialized_dag = dag_maker.get_serialized_data()
    deserialized_dag = SerializedDAG.from_dict(serialized_dag)
    deserialized_task = deserialized_dag.task_dict[TASK_ID]

    serialized_links = serialized_dag["dag"]["tasks"][0]["_operator_extra_links"]
    assert serialized_links == [
        {"airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink": {}}
    ], "Operator links should exist for serialized DAG"

    first_link = deserialized_task.operator_extra_links[0]
    assert isinstance(
        first_link, EmrClusterLink
    ), "Operator link type should be preserved during deserialization"

    # Before any XCom value is pushed, both task flavours resolve to an empty link.
    link_before_push = ti.task.get_extra_links(DEFAULT_DATE, link_name)
    assert link_before_push == "", "Operator link should only be added if job id is available in XCom"

    deserialized_link_before_push = deserialized_task.get_extra_links(DEFAULT_DATE, link_name)
    assert (
        deserialized_link_before_push == ""
    ), "Operator link should be empty for deserialized task with no XCom push"

    ti.xcom_push(key=XCOM_RETURN_KEY, value='j-SomeClusterId')

    expected = "https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:j-SomeClusterId"

    # After the push, both task flavours resolve to the cluster URL.
    deserialized_link_after_push = deserialized_task.get_extra_links(DEFAULT_DATE, link_name)
    assert (
        deserialized_link_after_push == expected
    ), "Operator link should be preserved in deserialized tasks after execution"

    link_after_push = ti.task.get_extra_links(DEFAULT_DATE, link_name)
    assert link_after_push == expected, "Operator link should be preserved after execution"