Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add on_kill to EMR Serverless Job Operator #31169

Merged
merged 4 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def hook(self) -> EmrServerlessHook:
"""Create and return an EmrServerlessHook."""
return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context):
def execute(self, context: Context) -> str | None:
response = self.hook.conn.create_application(
clientToken=self.client_request_token,
releaseLabel=self.release_label,
Expand Down Expand Up @@ -994,6 +994,7 @@ def __init__(
self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}")
self.waiter_countdown = waiter_countdown
self.waiter_check_interval_seconds = waiter_check_interval_seconds
self.job_id: str | None = None
super().__init__(**kwargs)

self.client_request_token = client_request_token or str(uuid4())
Expand All @@ -1003,7 +1004,7 @@ def hook(self) -> EmrServerlessHook:
"""Create and return an EmrServerlessHook."""
return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

def execute(self, context: Context) -> dict:
def execute(self, context: Context) -> str | None:
self.log.info("Starting job on Application: %s", self.application_id)

app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"]
Expand Down Expand Up @@ -1035,14 +1036,15 @@ def execute(self, context: Context) -> dict:
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"EMR serverless job failed to start: {response}")

self.log.info("EMR serverless job started: %s", response["jobRunId"])
self.job_id = response["jobRunId"]
self.log.info("EMR serverless job started: %s", self.job_id)
if self.wait_for_completion:
# This should be replaced with a boto waiter when available.
dacort marked this conversation as resolved.
Show resolved Hide resolved
waiter(
get_state_callable=self.hook.conn.get_job_run,
get_state_args={
"applicationId": self.application_id,
"jobRunId": response["jobRunId"],
"jobRunId": self.job_id,
},
parse_response=["jobRun", "state"],
desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
Expand All @@ -1052,7 +1054,41 @@ def execute(self, context: Context) -> dict:
countdown=self.waiter_countdown,
check_interval_seconds=self.waiter_check_interval_seconds,
)
return response["jobRunId"]
return self.job_id

def on_kill(self) -> None:
    """
    Cancel the submitted job run and wait for it to reach a terminal state.

    Does nothing when no job run has been started yet (``self.job_id`` is unset).
    Cancellation is best-effort: if the cancel request raises, or does not come
    back with HTTP 200, the problem is logged and the method returns without
    polling.
    """
    if not self.job_id:
        return

    self.log.info("Stopping job run with jobId - %s", self.job_id)
    http_status_code = None
    try:
        # The request itself lives inside the ``try`` so that a client-side
        # failure is logged here instead of escaping from the kill handler.
        response = self.hook.conn.cancel_job_run(
            applicationId=self.application_id,
            jobRunId=self.job_id,
        )
        http_status_code = response["ResponseMetadata"]["HTTPStatusCode"]
    except Exception as ex:
        self.log.error("Exception while cancelling job run: %s", ex)

    # ``None`` (request failed or response malformed) is also "not 200".
    if http_status_code != 200:
        self.log.error("Unable to request job run cancel on EMR Serverless. Exiting")
        return

    self.log.info(
        "Polling EMR Serverless for job run with id %s to reach final state",
        self.job_id,
    )
    # This should be replaced with a boto waiter when available.
    waiter(
        get_state_callable=self.hook.conn.get_job_run,
        get_state_args={
            "applicationId": self.application_id,
            "jobRunId": self.job_id,
        },
        parse_response=["jobRun", "state"],
        desired_state=EmrServerlessHook.JOB_TERMINAL_STATES,
        # Any terminal state is acceptable after a cancel, so nothing counts as failure.
        failure_states=set(),
        object_type="job",
        action="cancelled",
        countdown=self.waiter_countdown,
        check_interval_seconds=self.waiter_check_interval_seconds,
    )
dacort marked this conversation as resolved.
Show resolved Hide resolved


class EmrServerlessStopApplicationOperator(BaseOperator):
Expand Down
26 changes: 26 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,32 @@ def test_start_job_custom_name(self, mock_conn):
name=custom_name,
)

@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_cancel_job_run(self, mock_conn):
    """on_kill must request cancellation of the job run that execute() started."""
    mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
    mock_conn.start_job_run.return_value = {
        "jobRunId": job_run_id,
        "ResponseMetadata": {"HTTPStatusCode": 200},
    }
    mock_conn.get_job_run.return_value = {"jobRun": {"state": "RUNNING"}}

    operator = EmrServerlessStartJobOperator(
        task_id=task_id,
        client_request_token=client_request_token,
        application_id=application_id,
        execution_role_arn=execution_role_arn,
        job_driver=job_driver,
        configuration_overrides=configuration_overrides,
        # Return right after submission so the job is still "running" when killed.
        wait_for_completion=False,
    )

    started_job_id = operator.execute(None)
    # Guard against a vacuous pass: the id handed to cancel must be the real one.
    assert started_job_id == job_run_id

    operator.on_kill()
    mock_conn.cancel_job_run.assert_called_once_with(
        applicationId=application_id,
        jobRunId=started_job_id,
    )


class TestEmrServerlessDeleteOperator:
@mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
Expand Down