Skip to content

Commit

Permalink
Fix EmrServerlessStartJobOperator (#41103)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Jul 29, 2024
1 parent 36b9234 commit 97c4fdc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 23 deletions.
46 changes: 23 additions & 23 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,30 +1382,30 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str

self.persist_links(context)

if self.deferrable:
self.defer(
trigger=EmrServerlessStartJobTrigger(
application_id=self.application_id,
job_id=self.job_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
)

if self.wait_for_completion:
waiter = self.hook.get_waiter("serverless_job_completed")
wait(
waiter=waiter,
waiter_max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
args={"applicationId": self.application_id, "jobRunId": self.job_id},
failure_message="Serverless Job failed",
status_message="Serverless Job status is",
status_args=["jobRun.state", "jobRun.stateDetails"],
)
if self.deferrable:
self.defer(
trigger=EmrServerlessStartJobTrigger(
application_id=self.application_id,
job_id=self.job_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
)
else:
waiter = self.hook.get_waiter("serverless_job_completed")
wait(
waiter=waiter,
waiter_max_attempts=self.waiter_max_attempts,
waiter_delay=self.waiter_delay,
args={"applicationId": self.application_id, "jobRunId": self.job_id},
failure_message="Serverless Job failed",
status_message="Serverless Job status is",
status_args=["jobRun.state", "jobRun.stateDetails"],
)

return self.job_id

Expand Down
21 changes: 21 additions & 0 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,27 @@ def test_start_job_deferrable(self, mock_conn):
with pytest.raises(TaskDeferred):
operator.execute(self.mock_context)

@mock.patch.object(EmrServerlessHook, "conn")
def test_start_job_deferrable_without_wait_for_completion(self, mock_conn):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
operator = EmrServerlessStartJobOperator(
task_id=task_id,
application_id=application_id,
execution_role_arn=execution_role_arn,
job_driver=job_driver,
configuration_overrides=configuration_overrides,
deferrable=True,
wait_for_completion=False,
)

result = operator.execute(self.mock_context)

assert result == job_run_id

@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter):
Expand Down

0 comments on commit 97c4fdc

Please sign in to comment.