EMR Serverless: keep job/application state sets on EmrServerlessHook

Defines the job and application state sets on EmrServerlessHook and has the
operators and sensors read them from the hook, so operators/emr.py no longer
imports sensors/emr.py. EmrServerlessStartJobOperator now waits for
JOB_SUCCESS_STATES with JOB_FAILURE_STATES as failure states, while every
application waiter uses APPLICATION_FAILURE_STATES. test_job_run_job_failed
covers a job run that reports FAILED.

diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index ebb2a79ef84e7..fb18c32394715 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -105,6 +105,15 @@ class EmrServerlessHook(AwsBaseHook):
         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
     """
 
+    JOB_INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
+    JOB_FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
+    JOB_SUCCESS_STATES = {'SUCCESS'}
+    JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)
+
+    APPLICATION_INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
+    APPLICATION_FAILURE_STATES = {'STOPPED', 'TERMINATED'}
+    APPLICATION_SUCCESS_STATES = {'CREATED', 'STARTED'}
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["client_type"] = "emr-serverless"
         super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 5028dfed86a05..5ccff487e95ee 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -26,7 +26,6 @@
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink
-from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -552,7 +551,7 @@ def execute(self, context: Context):
             get_state_args={'applicationId': application_id},
             parse_response=['application', 'state'],
             desired_state={'CREATED'},
-            failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+            failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
             object_type='application',
             action='created',
         )
@@ -567,7 +566,7 @@ def execute(self, context: Context):
                 get_state_args={'applicationId': application_id},
                 parse_response=['application', 'state'],
                 desired_state={'STARTED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='started',
             )
@@ -633,7 +632,7 @@ def execute(self, context: Context) -> dict:
         self.log.info('Starting job on Application: %s', self.application_id)
 
         app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state']
-        if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES:
+        if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
             self.hook.conn.start_application(applicationId=self.application_id)
 
             self.hook.waiter(
@@ -641,7 +640,7 @@ def execute(self, context: Context) -> dict:
                 get_state_args={'applicationId': self.application_id},
                 parse_response=['application', 'state'],
                 desired_state={'STARTED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='started',
             )
@@ -668,8 +667,8 @@ def execute(self, context: Context) -> dict:
                     'jobRunId': response['jobRunId'],
                 },
                 parse_response=['jobRun', 'state'],
-                desired_state=EmrServerlessJobSensor.TERMINAL_STATES,
-                failure_states=EmrServerlessJobSensor.FAILURE_STATES,
+                desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
+                failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
                 object_type='job',
                 action='run',
             )
@@ -719,7 +718,7 @@ def execute(self, context: Context) -> None:
                 'applicationId': self.application_id,
             },
             parse_response=['application', 'state'],
-            desired_state=EmrServerlessApplicationSensor.FAILURE_STATES,
+            desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
             failure_states=set(),
             object_type='application',
             action='stopped',
@@ -738,7 +737,7 @@ def execute(self, context: Context) -> None:
                 get_state_args={'applicationId': self.application_id},
                 parse_response=['application', 'state'],
                 desired_state={'TERMINATED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='deleted',
             )
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 0b1c5c686b6dd..4759b3d8388a6 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -129,11 +129,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     """
 
-    INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
-    FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
-    SUCCESS_STATES = {'SUCCESS'}
-    TERMINAL_STATES = SUCCESS_STATES.union(FAILURE_STATES)
-
     template_fields: Sequence[str] = (
         'application_id',
         'job_run_id',
@@ -144,7 +139,7 @@ def __init__(
         *,
         application_id: str,
         job_run_id: str,
-        target_states: set | frozenset = frozenset(SUCCESS_STATES),
+        target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
         aws_conn_id: str = 'aws_default',
         **kwargs: Any,
     ) -> None:
@@ -159,7 +154,7 @@ def poke(self, context: Context) -> bool:
 
         state = response['jobRun']['state']
 
-        if state in self.FAILURE_STATES:
+        if state in EmrServerlessHook.JOB_FAILURE_STATES:
             failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
             raise AirflowException(failure_message)
 
@@ -198,15 +193,11 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
 
     template_fields: Sequence[str] = ('application_id',)
 
-    INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
-    FAILURE_STATES = {'STOPPED', 'TERMINATED'}
-    SUCCESS_STATES = {'CREATED', 'STARTED'}
-
     def __init__(
         self,
         *,
         application_id: str,
-        target_states: set | frozenset = frozenset(SUCCESS_STATES),
+        target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
         aws_conn_id: str = 'aws_default',
         **kwargs: Any,
     ) -> None:
@@ -220,7 +211,7 @@ def poke(self, context: Context) -> bool:
 
         state = response['application']['state']
 
-        if state in self.FAILURE_STATES:
+        if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
             failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
             raise AirflowException(failure_message)
 
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 688cc2711b317..a175bc8d1081e 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -220,6 +220,38 @@ def test_job_run_app_started(self, mock_conn, mock_waiter):
             configurationOverrides=configuration_overrides,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_job_run_job_failed(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+        mock_conn.start_job_run.return_value = {
+            'jobRunId': job_run_id,
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+
+        mock_conn.get_job_run.return_value = {'jobRun': {'state': 'FAILED'}}
+
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            client_request_token=client_request_token,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+        )
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)
+
+        assert "Job reached failure state FAILED." in str(ex_message.value)
+        mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+        mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
+        mock_conn.start_job_run.assert_called_once_with(
+            clientToken=client_request_token,
+            applicationId=application_id,
+            executionRoleArn=execution_role_arn,
+            jobDriver=job_driver,
+            configurationOverrides=configuration_overrides,
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
     def test_job_run_app_not_started(self, mock_conn, mock_waiter):