diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py index 1e2364e9197b0..2ac4ed172d252 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/hooks/beam.py @@ -38,7 +38,6 @@ from airflow.exceptions import AirflowConfigException, AirflowException from airflow.hooks.base import BaseHook from airflow.providers.common.compat.standard.utils import prepare_virtualenv -from airflow.providers.google.go_module_utils import init_module, install_dependencies if TYPE_CHECKING: import logging @@ -377,6 +376,16 @@ def start_go_pipeline( "'https://airflow.apache.org/docs/docker-stack/recipes.html'." ) + try: + from airflow.providers.google.go_module_utils import init_module, install_dependencies + except ImportError: + from airflow.exceptions import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException( + "Failed to import apache-airflow-google-provider. To start a go pipeline, please install the" + " google provider." + ) + if "labels" in variables: variables["labels"] = json.dumps(variables["labels"], separators=(",", ":")) diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py index adf76a0ab6610..fccb4aec3c994 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py @@ -30,23 +30,14 @@ from functools import partial from typing import TYPE_CHECKING, Any, Callable +from packaging.version import parse as parse_version + from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException from airflow.models import BaseOperator from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger -from airflow.providers.google.cloud.hooks.dataflow import ( - DEFAULT_DATAFLOW_LOCATION, - DataflowHook, - DataflowJobStatus, - process_line_and_extract_dataflow_job_id_callback, -) -from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url -from airflow.providers.google.cloud.links.dataflow import DataflowJobLink -from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration -from airflow.providers.google.cloud.triggers.dataflow import ( - DataflowJobStatusTrigger, -) +from airflow.providers_manager import ProvidersManager from airflow.utils.helpers import convert_camel_to_snake, exactly_one from airflow.version import version @@ -54,6 +45,26 @@ from airflow.utils.context import Context +try: + from airflow.providers.google.cloud.hooks.dataflow import ( + DEFAULT_DATAFLOW_LOCATION, + DataflowHook, + process_line_and_extract_dataflow_job_id_callback, + ) + from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url + from airflow.providers.google.cloud.links.dataflow import DataflowJobLink + from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration + from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowJobStateCompleteTrigger, + DataflowJobStatus, + DataflowJobStatusTrigger, + ) + + GOOGLE_PROVIDER_VERSION = 
ProvidersManager().providers["apache-airflow-providers-google"].version +except ImportError: + GOOGLE_PROVIDER_VERSION = "" + + class BeamDataflowMixin(metaclass=ABCMeta): """ Helper class to store common, Dataflow specific logic for both. @@ -68,6 +79,13 @@ class BeamDataflowMixin(metaclass=ABCMeta): gcp_conn_id: str dataflow_support_impersonation: bool = True + def __init__(self): + if not GOOGLE_PROVIDER_VERSION: + raise AirflowOptionalProviderFeatureException( + "Failed to import apache-airflow-google-provider. To use the dataflow service please install " + "the appropriate version of the google provider." + ) + def _set_dataflow( self, pipeline_options: dict, @@ -319,7 +337,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator): "dataflow_config", ) template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} - operator_extra_links = (DataflowJobLink(),) + operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else () def __init__( self, @@ -423,22 +441,37 @@ def execute_on_dataflow(self, context: Context): process_line_callback=self.process_line_callback, is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback, ) + + location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION DataflowJobLink.persist( self, context, self.dataflow_config.project_id, - self.dataflow_config.location, + location, self.dataflow_job_id, ) + if self.deferrable: - self.defer( - trigger=DataflowJobStatusTrigger( - job_id=self.dataflow_job_id, + trigger_args = { + "job_id": self.dataflow_job_id, + "project_id": self.dataflow_config.project_id, + "location": location, + "gcp_conn_id": self.gcp_conn_id, + } + trigger: DataflowJobStatusTrigger | DataflowJobStateCompleteTrigger + if parse_version(GOOGLE_PROVIDER_VERSION) < parse_version("16.0.0"): + trigger = DataflowJobStatusTrigger( expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, - project_id=self.dataflow_config.project_id, - location=self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION, - gcp_conn_id=self.gcp_conn_id, - ), + **trigger_args, + ) + else: + trigger = DataflowJobStateCompleteTrigger( + wait_until_finished=self.dataflow_config.wait_until_finished, + **trigger_args, + ) + + self.defer( + trigger=trigger, method_name="execute_complete", ) self.dataflow_hook.wait_for_done( @@ -498,7 +531,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator): template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} ui_color = "#0273d4" - operator_extra_links = (DataflowJobLink(),) + operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else () def __init__( self, @@ -601,16 +634,29 @@ def execute_on_dataflow(self, context: Context): self.dataflow_job_id, ) if self.deferrable: - self.defer( - trigger=DataflowJobStatusTrigger( - job_id=self.dataflow_job_id, + trigger_args = { + "job_id": self.dataflow_job_id, + "project_id": self.dataflow_config.project_id, + "location": self.dataflow_config.location, + "gcp_conn_id": self.gcp_conn_id, + } + trigger: DataflowJobStatusTrigger | DataflowJobStateCompleteTrigger + if parse_version(GOOGLE_PROVIDER_VERSION) < parse_version("16.0.0"): + trigger = DataflowJobStatusTrigger( expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, - project_id=self.dataflow_config.project_id, - location=self.dataflow_config.location, - gcp_conn_id=self.gcp_conn_id, - ), + **trigger_args, + ) + else: + trigger = DataflowJobStateCompleteTrigger( + wait_until_finished=self.dataflow_config.wait_until_finished, + 
**trigger_args, + ) + + self.defer( + trigger=trigger, method_name="execute_complete", ) + multiple_jobs = self.dataflow_config.multiple_jobs or False self.dataflow_hook.wait_for_done( job_name=self.dataflow_job_name, @@ -676,7 +722,7 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator): "dataflow_config", ] template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"} - operator_extra_links = (DataflowJobLink(),) + operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else () def __init__( self, diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py index d4a6d7c9ac304..18ed42578c42a 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py @@ -22,7 +22,6 @@ from typing import IO, Any from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook -from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -33,6 +32,37 @@ class BeamPipelineBaseTrigger(BaseTrigger): def _get_async_hook(*args, **kwargs) -> BeamAsyncHook: return BeamAsyncHook(*args, **kwargs) + @staticmethod + def file_has_gcs_path(file_path: str): + return file_path.lower().startswith("gs://") + + @staticmethod + async def provide_gcs_tempfile(gcs_file, gcp_conn_id): + try: + from airflow.providers.google.cloud.hooks.gcs import GCSHook + except ImportError: + from airflow.exceptions import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException( + "Failed to import GCSHook. To use the GCSHook functionality, please install the " + "apache-airflow-google-provider." + ) + + gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id) + loop = asyncio.get_running_loop() + + # Running synchronous `enter_context()` method in a separate + # thread using the default executor `None`. The `run_in_executor()` function returns the + # file object, which is created using gcs function `provide_file()`, asynchronously. + # This means we can perform asynchronous operations with this file. + create_tmp_file_call = gcs_hook.provide_file(object_url=gcs_file) + tmp_gcs_file: IO[str] = await loop.run_in_executor( + None, + contextlib.ExitStack().enter_context, # type: ignore[arg-type] + create_tmp_file_call, + ) + return tmp_gcs_file + class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger): """ @@ -101,20 +131,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] hook = self._get_async_hook(runner=self.runner) try: - # Get the current running event loop to manage I/O operations asynchronously - loop = asyncio.get_running_loop() - if self.py_file.lower().startswith("gs://"): - gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) - # Running synchronous `enter_context()` method in a separate - # thread using the default executor `None`. The `run_in_executor()` function returns the - # file object, which is created using gcs function `provide_file()`, asynchronously. - # This means we can perform asynchronous operations with this file. 
- create_tmp_file_call = gcs_hook.provide_file(object_url=self.py_file) - tmp_gcs_file: IO[str] = await loop.run_in_executor( - None, - contextlib.ExitStack().enter_context, # type: ignore[arg-type] - create_tmp_file_call, - ) + if self.file_has_gcs_path(self.py_file): + tmp_gcs_file = await self.provide_gcs_tempfile(self.py_file, self.gcp_conn_id) self.py_file = tmp_gcs_file.name return_code = await hook.start_python_pipeline_async( @@ -188,20 +206,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] hook = self._get_async_hook(runner=self.runner) return_code = 0 try: - # Get the current running event loop to manage I/O operations asynchronously - loop = asyncio.get_running_loop() - if self.jar.lower().startswith("gs://"): - gcs_hook = GCSHook(self.gcp_conn_id) - # Running synchronous `enter_context()` method in a separate - # thread using the default executor `None`. The `run_in_executor()` function returns the - # file object, which is created using gcs function `provide_file()`, asynchronously. - # This means we can perform asynchronous operations with this file. - create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar) - tmp_gcs_file: IO[str] = await loop.run_in_executor( - None, - contextlib.ExitStack().enter_context, # type: ignore[arg-type] - create_tmp_file_call, - ) + if self.file_has_gcs_path(self.jar): + tmp_gcs_file = await self.provide_gcs_tempfile(self.jar, self.gcp_conn_id) self.jar = tmp_gcs_file.name return_code = await hook.start_java_pipeline_async( diff --git a/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py index 5803f01dbd14e..aed634c223be7 100644 --- a/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py @@ -134,16 +134,17 @@ async def test_beam_trigger_exception_should_execute_successfully( assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook") - async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, python_trigger): + async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, python_trigger): """ Test that BeamPythonPipelineTrigger downloads GCS provide file correct. 
""" - gcs_provide_file = gcs_hook.return_value.provide_file python_trigger.py_file = TEST_GCS_PY_FILE - generator = python_trigger.run() - await generator.asend(None) - gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE) + with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook: + mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file" + generator = python_trigger.run() + await generator.asend(None) + mock_gcs_hook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id) + mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE) class TestBeamJavaPipelineTrigger: @@ -210,13 +211,15 @@ async def test_beam_trigger_exception_should_execute_successfully( assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook") - async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger): + async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger): """ Test that BeamJavaPipelineTrigger downloads GCS provide file correct. """ - gcs_provide_file = gcs_hook.return_value.provide_file java_trigger.jar = TEST_GCS_JAR_FILE - generator = java_trigger.run() - await generator.asend(None) - gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE) + + with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook: + mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file" + generator = java_trigger.run() + await generator.asend(None) + mock_gcs_hook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id) + mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE) diff --git a/providers/google/docs/operators/cloud/dataflow.rst b/providers/google/docs/operators/cloud/dataflow.rst index cd94b2ed2aa15..c8046eab49b21 100644 --- a/providers/google/docs/operators/cloud/dataflow.rst +++ b/providers/google/docs/operators/cloud/dataflow.rst @@ -146,6 +146,15 @@ Here is an example of creating and running a streaming pipeline in Java with jar :start-after: [START howto_operator_start_java_streaming] :end-before: [END howto_operator_start_java_streaming] +Here is an Java dataflow streaming pipeline example in deferrable_mode : + +.. exampleinclude:: /../../google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_start_java_streaming_deferrable] + :end-before: [END howto_operator_start_java_streaming_deferrable] + + .. _howto/operator:PythonSDKPipelines: Python SDK pipelines @@ -232,6 +241,15 @@ source, such as Pub/Sub, in your pipeline (for Java). :start-after: [START howto_operator_start_streaming_python_job] :end-before: [END howto_operator_start_streaming_python_job] +Deferrable mode: + +.. exampleinclude:: /../../google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_start_streaming_python_job_deferrable] + :end-before: [END howto_operator_start_streaming_python_job_deferrable] + + Setting argument ``drain_pipeline`` to ``True`` allows to stop streaming job by draining it instead of canceling during killing task instance. 
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py index 2493bb91d9b0a..a9b7e07977a7b 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py @@ -185,7 +185,67 @@ class DataflowJobType: JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING" -class _DataflowJobsController(LoggingMixin): +class DataflowJobTerminalStateHelper(LoggingMixin): + """Helper to define and validate the dataflow job terminal state.""" + + @staticmethod + def expected_terminal_state_is_allowed(expected_terminal_state): + job_allowed_terminal_states = DataflowJobStatus.TERMINAL_STATES | { + DataflowJobStatus.JOB_STATE_RUNNING + } + if expected_terminal_state not in job_allowed_terminal_states: + raise AirflowException( + f"Google Cloud Dataflow job's expected terminal state " + f"'{expected_terminal_state}' is invalid." + f" The value should be any of the following: {job_allowed_terminal_states}" + ) + return True + + @staticmethod + def expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming: bool): + if is_streaming: + invalid_terminal_state = DataflowJobStatus.JOB_STATE_DONE + job_type = "streaming" + else: + invalid_terminal_state = DataflowJobStatus.JOB_STATE_DRAINED + job_type = "batch" + + if expected_terminal_state == invalid_terminal_state: + raise AirflowException( + f"Google Cloud Dataflow job's expected terminal state cannot be {invalid_terminal_state} while it is a {job_type} job" + ) + return True + + def job_reached_terminal_state(self, job, wait_until_finished=None, custom_terminal_state=None) -> bool: + """ + Check the job reached terminal state, if job failed raise exception. + + :return: True if job is done. + :raise: Exception + """ + current_state = job["currentState"] + is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING + expected_terminal_state = ( + DataflowJobStatus.JOB_STATE_RUNNING if is_streaming else DataflowJobStatus.JOB_STATE_DONE + ) + if custom_terminal_state is not None: + expected_terminal_state = custom_terminal_state + self.expected_terminal_state_is_allowed(expected_terminal_state) + self.expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming=is_streaming) + if current_state == expected_terminal_state: + if expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING and wait_until_finished: + return False + return True + if current_state in DataflowJobStatus.AWAITING_STATES: + return wait_until_finished is False + self.log.debug("Current job: %s", job) + raise AirflowException( + f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, " + f"expected terminal state: {expected_terminal_state}" + ) + + +class _DataflowJobsController(DataflowJobTerminalStateHelper): """ Interface for communication with Google Cloud Dataflow API. @@ -462,7 +522,10 @@ def wait_for_done(self) -> None: """Wait for result of submitted job.""" self.log.info("Start waiting for done.") self._refresh_jobs() - while self._jobs and not all(self._check_dataflow_job_state(job) for job in self._jobs): + while self._jobs and not all( + self.job_reached_terminal_state(job, self._wait_until_finished, self._expected_terminal_state) + for job in self._jobs + ): self.log.info("Waiting for done. 
Sleep %s s", self._poll_sleep) time.sleep(self._poll_sleep) self._refresh_jobs() @@ -1295,8 +1358,7 @@ def is_job_done(self, location: str, project_id: str, job_id: str) -> bool: location=location, ) job = job_controller.fetch_job_by_id(job_id) - - return job_controller._check_dataflow_job_state(job) + return job_controller.job_reached_terminal_state(job) @GoogleBaseHook.fallback_to_default_project_id def create_data_pipeline( @@ -1425,7 +1487,7 @@ def build_parent_name(project_id: str, location: str): return f"projects/{project_id}/locations/{location}" -class AsyncDataflowHook(GoogleBaseAsyncHook): +class AsyncDataflowHook(GoogleBaseAsyncHook, DataflowJobTerminalStateHelper): """Async hook class for dataflow service.""" sync_hook_class = DataflowHook diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataflow.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataflow.py index e368949ce3d51..857e9bea3f8b9 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataflow.py @@ -788,3 +788,125 @@ def async_hook(self) -> AsyncDataflowHook: poll_sleep=self.poll_sleep, impersonation_chain=self.impersonation_chain, ) + + +class DataflowJobStateCompleteTrigger(BaseTrigger): + """ + Trigger that monitors if a Dataflow job has reached any of successful terminal state meant for that job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: Optional. The location where the job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param wait_until_finished: Optional. Dataflow option to block pipeline until completion. + :param gcp_conn_id: The connection ID to use for connecting to Google Cloud. + :param poll_sleep: Time (seconds) to wait between two consecutive calls to check the job. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + wait_until_finished: bool | None = None, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.job_id = job_id + self.project_id = project_id + self.location = location + self.wait_until_finished = wait_until_finished + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger", + { + "job_id": self.job_id, + "project_id": self.project_id, + "location": self.location, + "wait_until_finished": self.wait_until_finished, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "impersonation_chain": self.impersonation_chain, + }, + ) + + async def run(self): + """ + Loop until the job reaches successful final or error state. + + Yields a TriggerEvent with success status, if the job reaches successful state for own type. + + Yields a TriggerEvent with error status, if the client returns an unexpected terminal + job status or any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.poll_sleep variable. + """ + try: + while True: + job = await self.async_hook.get_job( + project_id=self.project_id, + job_id=self.job_id, + location=self.location, + ) + job_state = job.current_state.name + job_type_name = job.type_.name + + FAILED_STATES = DataflowJobStatus.FAILED_END_STATES | {DataflowJobStatus.JOB_STATE_DRAINED} + if job_state in FAILED_STATES: + yield TriggerEvent( + { + "status": "error", + "message": ( + f"Job with id '{self.job_id}' is in failed terminal state: {job_state}" + ), + } + ) + return + + if self.async_hook.job_reached_terminal_state( + job={"id": self.job_id, "currentState": job_state, "type": job_type_name}, + wait_until_finished=self.wait_until_finished, + ): + yield TriggerEvent( + { + "status": "success", + "message": ( + f"Job with id '{self.job_id}' has reached successful final state: {job_state}" + ), + } + ) + return + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.error("Exception occurred while checking for job state!") + yield TriggerEvent( + { + "status": "error", + "message": str(e), + } + ) + + @cached_property + def async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + ) diff --git a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py index cf6c3228fc71e..45900bc03e693 100644 --- a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py +++ b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py @@ -68,8 +68,10 @@ REMOTE_JAR_FILE_PATH = f"dataflow/java/{JAR_FILE_NAME}" OUTPUT_TOPIC_ID = f"tp-{ENV_ID}-out" +OUTPUT_TOPIC_ID_2 = f"tp-2-{ENV_ID}-out" INPUT_TOPIC = "projects/pubsub-public-data/topics/taxirides-realtime" OUTPUT_TOPIC = f"projects/{PROJECT_ID}/topics/{OUTPUT_TOPIC_ID}" +OUTPUT_TOPIC_2 = 
f"projects/{PROJECT_ID}/topics/{OUTPUT_TOPIC_ID_2}" with DAG( @@ -89,8 +91,10 @@ create_output_pub_sub_topic = PubSubCreateTopicOperator( task_id="create_topic", topic=OUTPUT_TOPIC_ID, project_id=PROJECT_ID, fail_if_exists=False ) + create_output_pub_sub_topic_2 = PubSubCreateTopicOperator( + task_id="create_topic_2", topic=OUTPUT_TOPIC_ID_2, project_id=PROJECT_ID, fail_if_exists=False + ) # [START howto_operator_start_java_streaming] - start_java_streaming_job_dataflow = BeamRunJavaPipelineOperator( runner=BeamRunnerType.DataflowRunner, task_id="start_java_streaming_dataflow_job", @@ -107,15 +111,46 @@ }, ) # [END howto_operator_start_java_streaming] + + # [START howto_operator_start_java_streaming_deferrable] + start_java_streaming_job_dataflow_def = BeamRunJavaPipelineOperator( + runner=BeamRunnerType.DataflowRunner, + task_id="start_java_streaming_dataflow_job_def", + jar=LOCAL_JAR, + pipeline_options={ + "tempLocation": GCS_TMP, + "input_topic": INPUT_TOPIC, + "output_topic": OUTPUT_TOPIC_2, + "streaming": True, + }, + dataflow_config={ + "job_name": f"java-streaming-job-{ENV_ID}", + "location": LOCATION, + }, + deferrable=True, + ) + # [END howto_operator_start_java_streaming_deferrable] + stop_dataflow_job = DataflowStopJobOperator( task_id="stop_dataflow_job", location=LOCATION, job_id="{{ task_instance.xcom_pull(task_ids='start_java_streaming_dataflow_job')['dataflow_job_id'] }}", ) + stop_dataflow_job_deferrable = DataflowStopJobOperator( + task_id="stop_dataflow_job_deferrable", + location=LOCATION, + job_id="{{ task_instance.xcom_pull(task_ids='start_java_streaming_dataflow_job_def', key='dataflow_job_id') }}", + ) delete_topic = PubSubDeleteTopicOperator( task_id="delete_topic", topic=OUTPUT_TOPIC_ID, project_id=PROJECT_ID ) delete_topic.trigger_rule = TriggerRule.ALL_DONE + + delete_topic_2 = PubSubDeleteTopicOperator( + task_id="delete_topic_2", topic=OUTPUT_TOPIC_ID_2, project_id=PROJECT_ID + ) + delete_topic_2.trigger_rule = TriggerRule.ALL_DONE + delete_bucket = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE ) @@ -125,11 +160,14 @@ create_bucket >> download_file >> create_output_pub_sub_topic + >> create_output_pub_sub_topic_2 # TEST BODY >> start_java_streaming_job_dataflow + >> start_java_streaming_job_dataflow_def # TEST TEARDOWN - >> stop_dataflow_job + >> [stop_dataflow_job, stop_dataflow_job_deferrable] >> delete_topic + >> delete_topic_2 >> delete_bucket ) diff --git a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py index e87ee258bb881..3df070f19a8a7 100644 --- a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py +++ b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py @@ -48,6 +48,7 @@ GCS_PYTHON_SCRIPT = f"gs://{RESOURCE_DATA_BUCKET}/dataflow/python/streaming_wordcount.py" LOCATION = "europe-west3" TOPIC_ID = f"topic-{DAG_ID}" +TOPIC_ID_2 = f"topic-2-{DAG_ID}" default_args = { "dataflow_default_options": { @@ -60,7 +61,7 @@ DAG_ID, default_args=default_args, schedule="@once", - start_date=datetime(2021, 1, 1), + start_date=datetime(2025, 1, 1), catchup=False, tags=["example", "dataflow"], ) as dag: @@ -69,6 +70,9 @@ create_pub_sub_topic = PubSubCreateTopicOperator( task_id="create_topic", topic=TOPIC_ID, project_id=PROJECT_ID, fail_if_exists=False ) + create_pub_sub_topic_2 = 
PubSubCreateTopicOperator( + task_id="create_topic_2", topic=TOPIC_ID_2, project_id=PROJECT_ID, fail_if_exists=False + ) # [START howto_operator_start_streaming_python_job] start_streaming_python_job = BeamRunPythonPipelineOperator( @@ -82,22 +86,53 @@ "output_topic": f"projects/{PROJECT_ID}/topics/{TOPIC_ID}", "streaming": True, }, - py_requirements=["apache-beam[gcp]==2.59.0"], + py_requirements=["apache-beam[gcp]==2.63.0"], py_interpreter="python3", py_system_site_packages=False, dataflow_config={"location": LOCATION, "job_name": "start_python_job_streaming"}, ) # [END howto_operator_start_streaming_python_job] + # [START howto_operator_start_streaming_python_job_deferrable] + start_streaming_python_job_def = BeamRunPythonPipelineOperator( + runner=BeamRunnerType.DataflowRunner, + task_id="start_def_streaming_python_job", + py_file=GCS_PYTHON_SCRIPT, + py_options=[], + pipeline_options={ + "temp_location": GCS_TMP, + "input_topic": "projects/pubsub-public-data/topics/taxirides-realtime", + "output_topic": f"projects/{PROJECT_ID}/topics/{TOPIC_ID_2}", + "streaming": True, + }, + py_requirements=["apache-beam[gcp]==2.63.0"], + py_interpreter="python3", + py_system_site_packages=False, + dataflow_config={"location": LOCATION, "job_name": "start_python_job_streaming"}, + deferrable=True, + ) + # [END howto_operator_start_streaming_python_job_deferrable] + stop_dataflow_job = DataflowStopJobOperator( task_id="stop_dataflow_job", location=LOCATION, job_id="{{ task_instance.xcom_pull(task_ids='start_streaming_python_job')['dataflow_job_id'] }}", ) + stop_dataflow_job_deferrable = DataflowStopJobOperator( + task_id="stop_dataflow_job_deferrable", + location=LOCATION, + job_id="{{ task_instance.xcom_pull(task_ids='start_def_streaming_python_job', key='dataflow_job_id') }}", + ) + delete_topic = PubSubDeleteTopicOperator(task_id="delete_topic", topic=TOPIC_ID, project_id=PROJECT_ID) delete_topic.trigger_rule = TriggerRule.ALL_DONE + delete_topic_2 = PubSubDeleteTopicOperator( + task_id="delete_topic2", topic=TOPIC_ID_2, project_id=PROJECT_ID + ) + delete_topic.trigger_rule = TriggerRule.ALL_DONE + delete_bucket = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE ) @@ -106,11 +141,14 @@ # TEST SETUP create_bucket >> create_pub_sub_topic + >> create_pub_sub_topic_2 # TEST BODY >> start_streaming_python_job + >> start_streaming_python_job_def # TEST TEARDOWN - >> stop_dataflow_job + >> [stop_dataflow_job, stop_dataflow_job_deferrable] >> delete_topic + >> delete_topic_2 >> delete_bucket ) @@ -120,7 +158,6 @@ # when "teardown" task with trigger rule is part of the DAG list(dag.tasks) >> watcher() - from tests_common.test_utils.system_tests import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_dataflow.py b/providers/google/tests/unit/google/cloud/hooks/test_dataflow.py index 2ecb4fb284421..8a13b1d8485c0 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_dataflow.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_dataflow.py @@ -983,7 +983,7 @@ def test_check_dataflow_job_state_wait_until_finished( multiple_jobs=True, wait_until_finished=wait_until_finished, ) - result = dataflow_job._check_dataflow_job_state(job) + result = dataflow_job.job_reached_terminal_state(job, wait_until_finished) assert result == expected_result @pytest.mark.parametrize( @@ -1055,7 +1055,7 @@ def 
test_check_dataflow_job_state_without_job_type_changed_on_terminal_state( result = False for current_job in jobs: job = {"id": "id-2", "name": "name-2", "type": current_job[0], "currentState": current_job[1]} - result = dataflow_job._check_dataflow_job_state(job) + result = dataflow_job.job_reached_terminal_state(job, wait_until_finished) assert result == expected_result @pytest.mark.parametrize( @@ -1088,7 +1088,7 @@ def test_check_dataflow_job_state_without_job_type(self, job_state, wait_until_f multiple_jobs=True, wait_until_finished=wait_until_finished, ) - result = dataflow_job._check_dataflow_job_state(job) + result = dataflow_job.job_reached_terminal_state(job, wait_until_finished) assert result == expected_result @pytest.mark.parametrize( @@ -1159,7 +1159,7 @@ def test_check_dataflow_job_state_terminal_state(self, job_type, job_state, exce multiple_jobs=True, ) with pytest.raises(AirflowException, match=exception_regex): - dataflow_job._check_dataflow_job_state(job) + dataflow_job.job_reached_terminal_state(job) @pytest.mark.parametrize( "job_type, expected_terminal_state, match", @@ -1200,7 +1200,7 @@ def test_check_dataflow_job_state__invalid_expected_state(self, job_type, expect expected_terminal_state=expected_terminal_state, ) with pytest.raises(AirflowException, match=match): - dataflow_job._check_dataflow_job_state(job) + dataflow_job.job_reached_terminal_state(job, custom_terminal_state=expected_terminal_state) def test_dataflow_job_cancel_job(self): mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataflow.py b/providers/google/tests/unit/google/cloud/triggers/test_dataflow.py index 67b2f41009c25..1312a1dfb2aca 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataflow.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataflow.py @@ -29,6 +29,7 @@ DataflowJobAutoScalingEventTrigger, DataflowJobMessagesTrigger, DataflowJobMetricsTrigger, + DataflowJobStateCompleteTrigger, DataflowJobStatusTrigger, DataflowStartYamlJobTrigger, TemplateJobStartTrigger, @@ -42,6 +43,7 @@ POLL_SLEEP = 20 IMPERSONATION_CHAIN = ["impersonate", "this"] CANCEL_TIMEOUT = 10 * 420 +WAIT_UNTIL_FINISHED = None @pytest.fixture @@ -122,6 +124,19 @@ def dataflow_start_yaml_job_trigger(): ) +@pytest.fixture +def dataflow_job_state_complete_trigger(): + return DataflowJobStateCompleteTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + wait_until_finished=WAIT_UNTIL_FINISHED, + ) + + @pytest.fixture def test_dataflow_batch_job(): return Job(id=JOB_ID, current_state=JobState.JOB_STATE_DONE, type_=JobType.JOB_TYPE_BATCH) @@ -860,3 +875,132 @@ async def test_run_loop_is_still_running( await asyncio.sleep(0.5) assert task.done() is False task.cancel() + + +class TestDataflowJobStateCompleteTrigger: + """Test case for DataflowJobStatusTrigger""" + + def test_serialize(self, dataflow_job_state_complete_trigger): + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowJobStateCompleteTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "wait_until_finished": WAIT_UNTIL_FINISHED, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "impersonation_chain": IMPERSONATION_CHAIN, + }, + ) + actual_data = dataflow_job_state_complete_trigger.serialize() + assert actual_data == expected_data + + 
@pytest.mark.parametrize(
+        "attr, expected",
+        [
+            ("gcp_conn_id", GCP_CONN_ID),
+            ("poll_sleep", POLL_SLEEP),
+            ("impersonation_chain", IMPERSONATION_CHAIN),
+        ],
+    )
+    def test_async_hook(self, dataflow_job_state_complete_trigger, attr, expected):
+        hook = dataflow_job_state_complete_trigger.async_hook
+        actual = hook._hook_kwargs.get(attr)
+        assert actual == expected
+
+    @pytest.mark.parametrize(
+        "job_status_value",
+        [
+            JobState.JOB_STATE_FAILED,
+            JobState.JOB_STATE_CANCELLED,
+            JobState.JOB_STATE_DRAINED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+    async def test_run_yields_error_on_failed_state(
+        self,
+        mock_job,
+        job_status_value,
+        dataflow_job_state_complete_trigger,
+    ):
+        mock_job.return_value = Job(
+            id=JOB_ID, current_state=job_status_value, type_=JobType.JOB_TYPE_STREAMING
+        )
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Job with id '{JOB_ID}' is in failed terminal state: {job_status_value.name}",
+            }
+        )
+        actual_event = await dataflow_job_state_complete_trigger.run().asend(None)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "job_item",
+        [
+            Job(id=JOB_ID, current_state=JobState.JOB_STATE_DONE, type_=JobType.JOB_TYPE_BATCH),
+            Job(id=JOB_ID, current_state=JobState.JOB_STATE_RUNNING, type_=JobType.JOB_TYPE_STREAMING),
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+    async def test_run_yields_success_event_if_expected_job_status(
+        self,
+        mock_job,
+        job_item,
+        dataflow_job_state_complete_trigger,
+    ):
+        dataflow_job_state_complete_trigger.expected_statuses = {
+            DataflowJobStatus.JOB_STATE_DONE,
+            DataflowJobStatus.JOB_STATE_RUNNING,
+        }
+        mock_job.return_value = job_item
+
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Job with id '{JOB_ID}' has reached successful final state:"
+                f" {job_item.current_state.name}",
+            }
+        )
+        actual_event = await dataflow_job_state_complete_trigger.run().asend(None)
+        assert actual_event == expected_event
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+    async def test_run_loop_is_still_running_if_state_is_not_terminal_or_expected(
+        self,
+        mock_job,
+        dataflow_job_state_complete_trigger,
+    ):
+        """
+        Test that DataflowJobStateCompleteTrigger keeps looping if the job status is neither
+        terminal nor expected.
+        """
+        mock_job.return_value = Job(
+            id=JOB_ID,
+            current_state=JobState.JOB_STATE_PENDING,
+            type_=JobType.JOB_TYPE_STREAMING,
+        )
+        task = asyncio.create_task(dataflow_job_state_complete_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is False
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job")
+    async def test_run_raises_exception(self, mock_job, dataflow_job_state_complete_trigger):
+        """
+        Test that DataflowJobStateCompleteTrigger yields an error event if an exception is raised.
+        """
+        mock_job.side_effect = mock.AsyncMock(side_effect=Exception("Test exception"))
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": "Test exception",
+            }
+        )
+        actual_event = await dataflow_job_state_complete_trigger.run().asend(None)
+        assert expected_event == actual_event
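Taken together, the operator changes in this patch hinge on two coupled patterns: the google provider import is optional, and the trigger class used when deferring is chosen from the installed provider version. Below is a condensed sketch of that selection logic under those assumptions; build_dataflow_trigger is a hypothetical helper written for illustration only, while the real logic lives inline in execute_on_dataflow().

# Sketch of the optional-dependency and version-gating pattern used by the Beam operators.
# build_dataflow_trigger() is a hypothetical helper for illustration, not part of this patch.
from packaging.version import parse as parse_version

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import ProvidersManager

try:
    from airflow.providers.google.cloud.triggers.dataflow import (
        DataflowJobStateCompleteTrigger,
        DataflowJobStatus,
        DataflowJobStatusTrigger,
    )

    GOOGLE_PROVIDER_VERSION = ProvidersManager().providers["apache-airflow-providers-google"].version
except ImportError:
    GOOGLE_PROVIDER_VERSION = ""


def build_dataflow_trigger(job_id, project_id, location, gcp_conn_id, wait_until_finished=None):
    """Pick the trigger the operator defers to, based on the installed google provider."""
    if not GOOGLE_PROVIDER_VERSION:
        raise AirflowOptionalProviderFeatureException(
            "The apache-airflow-providers-google package is required to monitor Dataflow jobs."
        )
    common = {"job_id": job_id, "project_id": project_id, "location": location, "gcp_conn_id": gcp_conn_id}
    if parse_version(GOOGLE_PROVIDER_VERSION) < parse_version("16.0.0"):
        # Older google providers only ship the status-polling trigger.
        return DataflowJobStatusTrigger(expected_statuses={DataflowJobStatus.JOB_STATE_DONE}, **common)
    # google provider >= 16.0.0 adds DataflowJobStateCompleteTrigger, which distinguishes
    # streaming from batch terminal states and honors the wait_until_finished option.
    return DataflowJobStateCompleteTrigger(wait_until_finished=wait_until_finished, **common)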