diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index c13c622937ffc..fb2f5de47849f 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1382,30 +1382,30 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str self.persist_links(context) - if self.deferrable: - self.defer( - trigger=EmrServerlessStartJobTrigger( - application_id=self.application_id, - job_id=self.job_id, - waiter_delay=self.waiter_delay, - waiter_max_attempts=self.waiter_max_attempts, - aws_conn_id=self.aws_conn_id, - ), - method_name="execute_complete", - timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), - ) - if self.wait_for_completion: - waiter = self.hook.get_waiter("serverless_job_completed") - wait( - waiter=waiter, - waiter_max_attempts=self.waiter_max_attempts, - waiter_delay=self.waiter_delay, - args={"applicationId": self.application_id, "jobRunId": self.job_id}, - failure_message="Serverless Job failed", - status_message="Serverless Job status is", - status_args=["jobRun.state", "jobRun.stateDetails"], - ) + if self.deferrable: + self.defer( + trigger=EmrServerlessStartJobTrigger( + application_id=self.application_id, + job_id=self.job_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + ) + else: + waiter = self.hook.get_waiter("serverless_job_completed") + wait( + waiter=waiter, + waiter_max_attempts=self.waiter_max_attempts, + waiter_delay=self.waiter_delay, + args={"applicationId": self.application_id, "jobRunId": self.job_id}, + failure_message="Serverless Job failed", + status_message="Serverless Job status is", + status_args=["jobRun.state", "jobRun.stateDetails"], + ) return self.job_id diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index 4804f2286993c..12c5cc938018e 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -836,6 +836,27 @@ def test_start_job_deferrable(self, mock_conn): with pytest.raises(TaskDeferred): operator.execute(self.mock_context) + @mock.patch.object(EmrServerlessHook, "conn") + def test_start_job_deferrable_without_wait_for_completion(self, mock_conn): + mock_conn.get_application.return_value = {"application": {"state": "STARTED"}} + mock_conn.start_job_run.return_value = { + "jobRunId": job_run_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + operator = EmrServerlessStartJobOperator( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=job_driver, + configuration_overrides=configuration_overrides, + deferrable=True, + wait_for_completion=False, + ) + + result = operator.execute(self.mock_context) + + assert result == job_run_id + @mock.patch.object(EmrServerlessHook, "get_waiter") @mock.patch.object(EmrServerlessHook, "conn") def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter):