DataflowStartFlexTemplateOperator. Check for Dataflow job type each check cycle. (#40584)



---------

Co-authored-by: Oleksandr Tkachov <oleksandr.tkachov@medecision.com>
theopus and Oleksandr Tkachov authored Jul 4, 2024
1 parent ce12553 commit 93488d0
Showing 2 changed files with 98 additions and 24 deletions.
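
In short: _check_dataflow_job_state used to derive the expected terminal state from the job type once and cache it on self._expected_terminal_state. The Dataflow API can omit a job's "type" while the job is still QUEUED or PENDING (exactly what the new test models), so a streaming job checked early had its expectation frozen at JOB_STATE_DONE, which the streaming-vs-batch validation then rejected on a later cycle. The fix re-derives the expectation into a local current_expected_state on every check cycle and leaves the user-supplied expected_terminal_state untouched; a simplified sketch follows the hook diff below.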
50 changes: 26 additions & 24 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -419,32 +419,34 @@ def _check_dataflow_job_state(self, job) -> bool:
         current_state = job["currentState"]
         is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
 
-        if self._expected_terminal_state is None:
+        current_expected_state = self._expected_terminal_state
+
+        if current_expected_state is None:
             if is_streaming:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING
+                current_expected_state = DataflowJobStatus.JOB_STATE_RUNNING
             else:
-                self._expected_terminal_state = DataflowJobStatus.JOB_STATE_DONE
-        else:
-            terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
-            if self._expected_terminal_state not in terminal_states:
-                raise AirflowException(
-                    f"Google Cloud Dataflow job's expected terminal state "
-                    f"'{self._expected_terminal_state}' is invalid."
-                    f" The value should be any of the following: {terminal_states}"
-                )
-            elif is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DONE:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DONE while it is a streaming job"
-                )
-            elif not is_streaming and self._expected_terminal_state == DataflowJobStatus.JOB_STATE_DRAINED:
-                raise AirflowException(
-                    "Google Cloud Dataflow job's expected terminal state cannot be "
-                    "JOB_STATE_DRAINED while it is a batch job"
-                )
+                current_expected_state = DataflowJobStatus.JOB_STATE_DONE
+
+        terminal_states = DataflowJobStatus.TERMINAL_STATES | {DataflowJobStatus.JOB_STATE_RUNNING}
+        if current_expected_state not in terminal_states:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state "
+                f"'{current_expected_state}' is invalid."
+                f" The value should be any of the following: {terminal_states}"
+            )
+        elif is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DONE while it is a streaming job"
+            )
+        elif not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED:
+            raise AirflowException(
+                "Google Cloud Dataflow job's expected terminal state cannot be "
+                "JOB_STATE_DRAINED while it is a batch job"
+            )
 
-        if current_state == self._expected_terminal_state:
-            if self._expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING:
+        if current_state == current_expected_state:
+            if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING:
                 return not self._wait_until_finished
             return True
 
@@ -454,7 +456,7 @@ def _check_dataflow_job_state(self, job) -> bool:
         self.log.debug("Current job: %s", job)
         raise AirflowException(
             f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
-            f"expected terminal state: {self._expected_terminal_state}"
+            f"expected terminal state: {current_expected_state}"
         )
 
     def wait_for_done(self) -> None:
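
For orientation, here is a minimal, self-contained sketch of the new per-cycle logic. Everything below is a simplified stand-in: the constants mimic DataflowJobStatus/DataflowJobType, check_job_state mimics _DataflowJobsController._check_dataflow_job_state, and the AWAITING_STATES fallback approximates the part of the method this diff does not show (the real method also validates a user-supplied expected_terminal_state, omitted here).

from __future__ import annotations

# Simplified stand-ins for the DataflowJobStatus / DataflowJobType constants
# used by the real hook (assumption: string values match the API names).
JOB_STATE_QUEUED = "JOB_STATE_QUEUED"
JOB_STATE_PENDING = "JOB_STATE_PENDING"
JOB_STATE_RUNNING = "JOB_STATE_RUNNING"
JOB_STATE_DONE = "JOB_STATE_DONE"
JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"

# Non-terminal states a healthy job may report while the controller keeps
# polling (an assumption approximating the part of the method not shown here).
AWAITING_STATES = {JOB_STATE_QUEUED, JOB_STATE_PENDING, JOB_STATE_RUNNING}


def check_job_state(job: dict, wait_until_finished: bool | None) -> bool:
    """Return True when polling may stop; the controller calls this once per cycle."""
    current_state = job["currentState"]
    is_streaming = job.get("type") == JOB_TYPE_STREAMING

    # The crux of the fix: derive the expected state from the job type on
    # *every* cycle, because "type" may still be absent while the job is
    # QUEUED or PENDING. Streaming jobs never reach DONE, so RUNNING is
    # their healthy steady state.
    current_expected_state = JOB_STATE_RUNNING if is_streaming else JOB_STATE_DONE

    if current_state == current_expected_state:
        if current_expected_state == JOB_STATE_RUNNING:
            return not wait_until_finished
        return True
    if current_state in AWAITING_STATES:
        return wait_until_finished is False
    raise RuntimeError(f"unexpected terminal state: {current_state}")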
72 changes: 72 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataflow.py
@@ -1498,6 +1498,78 @@ def test_check_dataflow_job_state_wait_until_finished(
         result = dataflow_job._check_dataflow_job_state(job)
         assert result == expected_result
 
+    @pytest.mark.parametrize(
+        "jobs, wait_until_finished, expected_result",
+        [
+            # STREAMING
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_STREAMING, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                True,
+                False,
+            ),
+            # BATCH
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                False,
+                True,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_RUNNING),
+                ],
+                None,
+                False,
+            ),
+            (
+                [
+                    (None, DataflowJobStatus.JOB_STATE_QUEUED),
+                    (None, DataflowJobStatus.JOB_STATE_PENDING),
+                    (DataflowJobType.JOB_TYPE_BATCH, DataflowJobStatus.JOB_STATE_DONE),
+                ],
+                None,
+                True,
+            ),
+        ],
+    )
+    def test_check_dataflow_job_state_without_job_type_changed_on_terminal_state(
+        self, jobs, wait_until_finished, expected_result
+    ):
+        dataflow_job = _DataflowJobsController(
+            dataflow=self.mock_dataflow,
+            project_number=TEST_PROJECT,
+            name="name-",
+            location=TEST_LOCATION,
+            poll_sleep=0,
+            job_id=None,
+            num_retries=20,
+            multiple_jobs=True,
+            wait_until_finished=wait_until_finished,
+        )
+        result = False
+        for current_job in jobs:
+            job = {"id": "id-2", "name": "name-2", "type": current_job[0], "currentState": current_job[1]}
+            result = dataflow_job._check_dataflow_job_state(job)
+        assert result == expected_result
+
     @pytest.mark.parametrize(
         "job_state, wait_until_finished, expected_result",
         [
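
Tracing the sketch above through the streaming sequence the new test feeds in shows why per-cycle derivation matters: the job's type only surfaces once the job is RUNNING.

# Poll results as the new test models them: "type" is absent while the job
# is QUEUED/PENDING and appears once it is RUNNING.
polls = [
    {"currentState": JOB_STATE_QUEUED},
    {"currentState": JOB_STATE_PENDING},
    {"currentState": JOB_STATE_RUNNING, "type": JOB_TYPE_STREAMING},
]

done = False
for job in polls:
    done = check_job_state(job, wait_until_finished=None)
print(done)  # True: a streaming job counts as finished once it is RUNNING

With the old cached expectation, the first poll (type still absent) would have pinned the expected state to JOB_STATE_DONE, and the third poll would have raised the streaming-vs-JOB_STATE_DONE validation error instead of reporting success.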
