Add extras links to some more EMR Operators and Sensors #31032

Merged: 5 commits, merged May 3, 2023
Changes from 2 commits
27 changes: 27 additions & 0 deletions airflow/providers/amazon/aws/links/emr.py
@@ -16,7 +16,14 @@
# under the License.
from __future__ import annotations

from typing import Any

import boto3

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
from airflow.utils.helpers import exactly_one


class EmrClusterLink(BaseAwsLink):
@@ -33,3 +40,23 @@ class EmrLogsLink(BaseAwsLink):
name = "EMR Cluster Logs"
key = "emr_logs"
format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"


def get_log_uri(
*, cluster: dict[str, Any] | None = None, emr_client: boto3.client = None, job_flow_id: str | None = None
) -> str:
"""
Retrieves the S3 URI to the EMR Job logs. Requires either the output of a
describe_cluster call or both an EMR Client and a job_flow_id to look it up.
"""
if not exactly_one(cluster, emr_client and job_flow_id):
raise AirflowException(
"Requires either the output of a describe_cluster call or both an EMR Client and a job_flow_id."
)

if cluster:
log_uri = S3Hook.parse_s3_url(cluster["Cluster"]["LogUri"])
else:
response = emr_client.describe_cluster(ClusterId=job_flow_id)
log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
return "/".join(log_uri)
52 changes: 47 additions & 5 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -25,7 +25,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.utils.waiter import waiter
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
@@ -68,7 +68,10 @@ class EmrAddStepsOperator(BaseOperator):
template_ext: Sequence[str] = (".json",)
template_fields_renderers = {"steps": "json"}
ui_color = "#f9c915"
operator_extra_links = (EmrClusterLink(),)
operator_extra_links = (
EmrClusterLink(),
EmrLogsLink(),
)

def __init__(
self,
@@ -119,6 +122,14 @@ def execute(self, context: Context) -> list[str]:
aws_partition=emr_hook.conn_partition,
job_flow_id=job_flow_id,
)
EmrLogsLink.persist(
context=context,
operator=self,
region_name=emr_hook.conn_region_name,
aws_partition=emr_hook.conn_partition,
job_flow_id=self.job_flow_id,
log_uri=get_log_uri(emr_client=emr_hook.conn, job_flow_id=self.job_flow_id),
)

self.log.info("Adding steps to %s", job_flow_id)

@@ -597,7 +608,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
template_ext: Sequence[str] = (".json",)
template_fields_renderers = {"job_flow_overrides": "json"}
ui_color = "#f9c915"
operator_extra_links = (EmrClusterLink(),)
operator_extra_links = (EmrLogsLink(),)

def __init__(
self,
@@ -671,6 +682,15 @@ def execute(self, context: Context) -> str | None:
aws_partition=self._emr_hook.conn_partition,
job_flow_id=self._job_flow_id,
)
if self._job_flow_id:
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self._emr_hook.conn_region_name,
aws_partition=self._emr_hook.conn_partition,
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
)

if self.wait_for_completion:
self._emr_hook.get_waiter("job_flow_waiting").wait(
@@ -712,7 +732,10 @@ class EmrModifyClusterOperator(BaseOperator):
template_fields: Sequence[str] = ("cluster_id", "step_concurrency_level")
template_ext: Sequence[str] = ()
ui_color = "#f9c915"
operator_extra_links = (EmrClusterLink(),)
operator_extra_links = (
EmrClusterLink(),
EmrLogsLink(),
)

def __init__(
self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = "aws_default", **kwargs
@@ -736,6 +759,14 @@ def execute(self, context: Context) -> int:
aws_partition=emr_hook.conn_partition,
job_flow_id=self.cluster_id,
)
EmrLogsLink.persist(
context=context,
operator=self,
region_name=emr_hook.conn_region_name,
aws_partition=emr_hook.conn_partition,
job_flow_id=self.cluster_id,
log_uri=get_log_uri(emr_client=emr_hook.conn, job_flow_id=self.cluster_id),
)

self.log.info("Modifying cluster %s", self.cluster_id)
response = emr.modify_cluster(
@@ -764,7 +795,10 @@ class EmrTerminateJobFlowOperator(BaseOperator):
template_fields: Sequence[str] = ("job_flow_id",)
template_ext: Sequence[str] = ()
ui_color = "#f9c915"
operator_extra_links = (EmrClusterLink(),)
operator_extra_links = (
EmrClusterLink(),
EmrLogsLink(),
)

def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs):
super().__init__(**kwargs)
@@ -782,6 +816,14 @@ def execute(self, context: Context) -> None:
aws_partition=emr_hook.conn_partition,
job_flow_id=self.job_flow_id,
)
EmrLogsLink.persist(
context=context,
operator=self,
region_name=emr_hook.conn_region_name,
aws_partition=emr_hook.conn_partition,
job_flow_id=self.job_flow_id,
log_uri=get_log_uri(emr_client=emr, job_flow_id=self.job_flow_id),
)

self.log.info("Terminating JobFlow %s", self.job_flow_id)
response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
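
To illustrate what this change surfaces in the UI (a hedged sketch, not part of the diff): once a task from one of these operators runs, the persisted "EMR Cluster" and "EMR Cluster Logs" buttons appear on its task instance. A minimal DAG, assuming a working aws_default connection and an illustrative job flow config:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import (
    EmrCreateJobFlowOperator,
    EmrTerminateJobFlowOperator,
)

with DAG(dag_id="emr_links_demo", start_date=datetime(2023, 5, 1), schedule=None):
    create = EmrCreateJobFlowOperator(
        task_id="create_job_flow",
        # LogUri is what EmrLogsLink ultimately points at.
        job_flow_overrides={"Name": "demo", "LogUri": "s3://my-bucket/logs/"},
    )
    # Passing the XComArg wires the dependency; at this point in the PR,
    # EmrTerminateJobFlowOperator exposes both the cluster and logs links.
    terminate = EmrTerminateJobFlowOperator(
        task_id="terminate_job_flow",
        job_flow_id=create.output,
    )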
43 changes: 37 additions & 6 deletions airflow/providers/amazon/aws/sensors/emr.py
@@ -23,8 +23,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.emr import EmrLogsLink
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
@@ -385,7 +384,10 @@ class EmrJobFlowSensor(EmrBaseSensor):

template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
operator_extra_links = (EmrLogsLink(),)
operator_extra_links = (
EmrClusterLink(),
EmrLogsLink(),
)

def __init__(
self,
@@ -412,14 +414,21 @@ def get_emr_response(self, context: Context) -> dict[str, Any]:
emr_client = self.hook.conn
self.log.info("Poking cluster %s", self.job_flow_id)
response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])

EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self.job_flow_id,
)
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self.job_flow_id,
log_uri="/".join(log_uri),
log_uri=get_log_uri(cluster=response),
)
return response

@@ -472,6 +481,10 @@ class EmrStepSensor(EmrBaseSensor):

template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
operator_extra_links = (
EmrClusterLink(),
EmrLogsLink(),
)

def __init__(
self,
@@ -500,7 +513,25 @@ def get_emr_response(self, context: Context) -> dict[str, Any]:
emr_client = self.hook.conn

self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id)
return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
response = emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)

EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self.job_flow_id,
)
EmrLogsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self.job_flow_id,
log_uri=get_log_uri(emr_client=emr_client, job_flow_id=self.job_flow_id),
)

return response

@staticmethod
def state_from_response(response: dict[str, Any]) -> str:
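
On the sensor side, each poke now re-persists both links, so the buttons are available while the sensor is still waiting. A hedged sketch with illustrative IDs:

from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor

watch_step = EmrStepSensor(
    task_id="watch_step",
    job_flow_id="j-8989898989",
    step_id="s-1234567890ABC",  # illustrative step ID
    poke_interval=60,
)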
12 changes: 9 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_add_steps.py
@@ -27,6 +27,8 @@

from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator
from airflow.utils import timezone
from tests.test_utils import AIRFLOW_MAIN_FOLDER
@@ -117,7 +119,8 @@ def test_render_template(self):

assert self.operator.steps == expected_args

def test_render_template_from_file(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_render_template_from_file(self, _):
dag = DAG(
dag_id="test_file",
default_args=self.args,
@@ -161,7 +164,8 @@ def test_render_template_from_file(self):
JobFlowId="j-8989898989", Steps=file_steps
)

def test_execute_returns_step_id(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_returns_step_id(self, _):
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN

with patch("boto3.session.Session", self.boto3_session_mock), patch(
@@ -217,8 +221,10 @@ def test_init_with_nonexistent_cluster_name(self):
operator.execute(self.mock_context)
assert str(ctx.value) == f"No cluster found for name: {cluster_name}"

@patch.object(EmrHook, "conn")
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
@patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps")
def test_wait_for_completion(self, mock_add_job_flow_steps):
def test_wait_for_completion(self, mock_add_job_flow_steps, *_):
job_flow_id = "j-8989898989"
operator = EmrAddStepsOperator(
task_id="test_task",
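
A note on the test changes: execute() now runs get_log_uri, which goes through S3Hook.parse_s3_url, so the tests stub that method to stay offline. Any string-like return works because get_log_uri only feeds the result through "/".join(...). A minimal sketch of the pattern (the tuple return value below is an assumption for readability; the real tests return the plain string "valid_uri"):

from unittest.mock import patch

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

@patch.object(S3Hook, "parse_s3_url", return_value=("my-bucket", "logs/"))
def test_something(mock_parse_s3_url):
    # Inside the test, S3Hook.parse_s3_url returns the stub, so
    # get_log_uri yields "my-bucket/logs/" without touching AWS.
    ...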
10 changes: 7 additions & 3 deletions tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
@@ -26,6 +26,7 @@
from jinja2 import StrictUndefined

from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.utils import timezone
from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type
@@ -114,7 +115,8 @@ def test_render_template(self):

assert self.operator.job_flow_overrides == expected_args

def test_render_template_from_file(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_render_template_from_file(self, _):
self.operator.job_flow_overrides = "job.j2.json"
self.operator.params = {"releaseLabel": "5.11.0"}

@@ -156,7 +158,8 @@ def test_render_template_from_file(self):

assert self.operator.job_flow_overrides == expected_args

def test_execute_returns_job_id(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_returns_job_id(self, _):
self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Mock out the emr_client creator
@@ -170,9 +173,10 @@ def test_execute_returns_job_id(self):
mock_isinstance.return_value = True
assert self.operator.execute(self.mock_context) == JOB_FLOW_ID

@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
def test_execute_with_wait(self, mock_waiter, _):
def test_execute_with_wait(self, mock_waiter, *_):
self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

# Mock out the emr_client creator
tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
@@ -23,6 +23,7 @@

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator
from airflow.utils import timezone

@@ -60,7 +61,8 @@ def test_init(self):
assert self.operator.step_concurrency_level == 1
assert self.operator.aws_conn_id == "aws_default"

def test_execute_returns_step_concurrency(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_returns_step_concurrency(self, _):
self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_SUCCESS_RETURN

with patch("boto3.session.Session", self.boto3_session_mock), patch(
Expand All @@ -69,7 +71,8 @@ def test_execute_returns_step_concurrency(self):
mock_isinstance.return_value = True
assert self.operator.execute(self.mock_context) == 1

def test_execute_returns_error(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_returns_error(self, _):
self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_ERROR_RETURN

with patch("boto3.session.Session", self.boto3_session_mock), patch(
tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -19,6 +19,7 @@

from unittest.mock import MagicMock, patch

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator

TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}}
@@ -36,7 +37,8 @@ def setup_method(self):
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)

def test_execute_terminates_the_job_flow_and_does_not_error(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_terminates_the_job_flow_and_does_not_error(self, _):
with patch("boto3.session.Session", self.boto3_session_mock), patch(
"airflow.providers.amazon.aws.hooks.base_aws.isinstance"
) as mock_isinstance:
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -25,6 +25,7 @@
from dateutil.tz import tzlocal

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor

DESCRIBE_CLUSTER_STARTING_RETURN = {
@@ -202,7 +203,8 @@ def setup_method(self):
# Mock context used in execute function
self.mock_ctx = MagicMock()

def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
@patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self, _):
self.mock_emr_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN,
DESCRIBE_CLUSTER_RUNNING_RETURN,
@@ -218,7 +220,6 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self

operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 3

# make sure it was called with the job_flow_id
@@ -270,7 +271,6 @@ def test_different_target_states(self):

operator.execute(self.mock_ctx)

# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 3

# make sure it was called with the job_flow_id