Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EMR Serverless Fix for Jobs marked as success even on failure #26218

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ class EmrServerlessHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

JOB_INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
JOB_FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
JOB_SUCCESS_STATES = {'SUCCESS'}
JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)

APPLICATION_INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
APPLICATION_FAILURE_STATES = {'STOPPED', 'TERMINATED'}
APPLICATION_SUCCESS_STATES = {'CREATED', 'STARTED'}

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)
Expand Down
17 changes: 8 additions & 9 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -552,7 +551,7 @@ def execute(self, context: Context):
get_state_args={'applicationId': application_id},
parse_response=['application', 'state'],
desired_state={'CREATED'},
failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='created',
)
Expand All @@ -567,7 +566,7 @@ def execute(self, context: Context):
get_state_args={'applicationId': application_id},
parse_response=['application', 'state'],
desired_state={'STARTED'},
failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='started',
)
Expand Down Expand Up @@ -633,15 +632,15 @@ def execute(self, context: Context) -> dict:
self.log.info('Starting job on Application: %s', self.application_id)

app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state']
if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES:
if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
self.hook.conn.start_application(applicationId=self.application_id)

self.hook.waiter(
get_state_callable=self.hook.conn.get_application,
get_state_args={'applicationId': self.application_id},
parse_response=['application', 'state'],
desired_state={'STARTED'},
failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
object_type='application',
action='started',
)
Expand All @@ -668,8 +667,8 @@ def execute(self, context: Context) -> dict:
'jobRunId': response['jobRunId'],
},
parse_response=['jobRun', 'state'],
desired_state=EmrServerlessJobSensor.TERMINAL_STATES,
failure_states=EmrServerlessJobSensor.FAILURE_STATES,
desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
object_type='job',
action='run',
)
Expand Down Expand Up @@ -719,7 +718,7 @@ def execute(self, context: Context) -> None:
'applicationId': self.application_id,
},
parse_response=['application', 'state'],
desired_state=EmrServerlessApplicationSensor.FAILURE_STATES,
desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
failure_states=set(),
object_type='application',
action='stopped',
Expand All @@ -738,7 +737,7 @@ def execute(self, context: Context) -> None:
get_state_args={'applicationId': self.application_id},
parse_response=['application', 'state'],
desired_state={'TERMINATED'},
failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='deleted',
)
Expand Down
17 changes: 4 additions & 13 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
"""

INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
SUCCESS_STATES = {'SUCCESS'}
TERMINAL_STATES = SUCCESS_STATES.union(FAILURE_STATES)

template_fields: Sequence[str] = (
'application_id',
'job_run_id',
Expand All @@ -144,7 +139,7 @@ def __init__(
*,
application_id: str,
job_run_id: str,
target_states: set | frozenset = frozenset(SUCCESS_STATES),
target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
aws_conn_id: str = 'aws_default',
**kwargs: Any,
) -> None:
Expand All @@ -159,7 +154,7 @@ def poke(self, context: Context) -> bool:

state = response['jobRun']['state']

if state in self.FAILURE_STATES:
if state in EmrServerlessHook.JOB_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
raise AirflowException(failure_message)

Expand Down Expand Up @@ -198,15 +193,11 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):

template_fields: Sequence[str] = ('application_id',)

INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
FAILURE_STATES = {'STOPPED', 'TERMINATED'}
SUCCESS_STATES = {'CREATED', 'STARTED'}

def __init__(
self,
*,
application_id: str,
target_states: set | frozenset = frozenset(SUCCESS_STATES),
target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
aws_conn_id: str = 'aws_default',
**kwargs: Any,
) -> None:
Expand All @@ -220,7 +211,7 @@ def poke(self, context: Context) -> bool:

state = response['application']['state']

if state in self.FAILURE_STATES:
if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
raise AirflowException(failure_message)

Expand Down
32 changes: 32 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,38 @@ def test_job_run_app_started(self, mock_conn, mock_waiter):
configurationOverrides=configuration_overrides,
)

@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_job_run_job_failed(self, mock_conn):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
'jobRunId': job_run_id,
'ResponseMetadata': {'HTTPStatusCode': 200},
}

mock_conn.get_job_run.return_value = {'jobRun': {'state': 'FAILED'}}

operator = EmrServerlessStartJobOperator(
task_id=task_id,
client_request_token=client_request_token,
application_id=application_id,
execution_role_arn=execution_role_arn,
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
with pytest.raises(AirflowException) as ex_message:
id = operator.execute(None)
assert id == job_run_id
assert "Job reached failure state FAILED." in str(ex_message.value)
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
applicationId=application_id,
executionRoleArn=execution_role_arn,
jobDriver=job_driver,
configurationOverrides=configuration_overrides,
)

@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_job_run_app_not_started(self, mock_conn, mock_waiter):
Expand Down