From 3b02a56dab5a8596c43981fab82bdab32828aa3a Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Mon, 13 Nov 2023 15:12:58 +0000 Subject: [PATCH 1/2] Implement Dataflow operators to update job, create snapshot and list active jobs --- .../providers/google/cloud/hooks/dataflow.py | 106 +++++++++++- .../google/cloud/operators/dataflow.py | 158 +++++++++++++++++- 2 files changed, 259 insertions(+), 5 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index e1eb1d048ba37..009464d2a1788 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -28,9 +28,12 @@ import warnings from copy import deepcopy from typing import Any, Callable, Generator, Sequence, TypeVar, cast +from google.api_core.client_options import ClientOptions -from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView +from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobsV1Beta3Client, JobView, SnapshotsV1Beta3Client from googleapiclient.discovery import build +from google.cloud.dataflow_v1beta3.types import CheckActiveJobsRequest, CheckActiveJobsResponse +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args @@ -139,6 +142,7 @@ class DataflowJobStatus: JOB_STATE_PENDING = "JOB_STATE_PENDING" JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING" JOB_STATE_QUEUED = "JOB_STATE_QUEUED" + TURN_UP_STATES = {JOB_STATE_PENDING, JOB_STATE_QUEUED, JOB_STATE_RUNNING} FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED} SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED} TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES @@ -285,6 +289,47 @@ def fetch_job_metrics_by_id(self, job_id: str) -> dict: self.log.debug("fetch_job_metrics_by_id %s:\n%s", job_id, result) return result + def update_job(self, job_id, update_mask, body): + result = ( + self._dataflow.projects() + .locations() + .jobs() + .update( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + body=body, + updateMask=update_mask + ) + .execute(num_retries=self._num_retries) + ) + self.log.info("result: %s", result) + return result + + def create_job_snapshot(self, job_id): + result = ( + self._dataflow.projects() + .locations() + .jobs() + .snapshot( + projectId=self._project_number, + location=self._job_location, + jobId=job_id, + body={ + # pass correct parameters here + "ttl": "3.5s", + "snapshotSources": True, + # "description": environment, + }, + ) + .execute(num_retries=self._num_retries) + ) + self.log.info("result: %s", result) + return result + + def fetch_active_jobs(self): + return [job for job in self._fetch_all_jobs() if job["currentState"] in DataflowJobStatus.TURN_UP_STATES] + def _fetch_list_job_messages_responses(self, job_id: str) -> Generator[dict, None, None]: """ Helper method to fetch ListJobMessagesResponse with the specified Job ID. 
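Note: the controller methods added above wrap the `projects.locations.jobs.update` and `projects.locations.jobs.snapshot` methods of the Dataflow v1b3 REST API, and are exposed later in this patch through `DataflowHook.update_job`, `DataflowHook.create_job_snapshot` and `DataflowHook.check_active_jobs`. A minimal usage sketch follows; the job ID, project ID, update-mask field path and body field names are illustrative assumptions, not values taken from this patch, so verify them against the Dataflow API documentation before relying on them.

    # Hypothetical usage sketch of the hook methods introduced in this patch.
    # The job ID, project ID, update-mask path and body field names below are
    # assumptions for illustration only; jobs.update accepts only specific
    # runtime-updatable fields, so check the v1b3 docs for exact paths/casing.
    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")

    JOB_ID = "2023-11-13_00_00_00-1234567890123456789"  # assumed job ID
    PROJECT_ID = "my-project"                           # assumed project
    LOCATION = "us-central1"

    # In-place update of a running job: `update_mask` is a FieldMask string and
    # `body` is a partial Job resource containing the fields named by the mask.
    hook.update_job(
        job_id=JOB_ID,
        project_id=PROJECT_ID,
        location=LOCATION,
        update_mask="runtime_updatable_params.max_num_workers",
        body={"runtimeUpdatableParams": {"maxNumWorkers": 10}},
    )

    # Snapshot the same streaming job (snapshots are not supported for batch jobs).
    hook.create_job_snapshot(job_id=JOB_ID, project_id=PROJECT_ID, location=LOCATION)

    # List jobs in the region that are still pending, queued or running.
    active_jobs = hook.check_active_jobs(project_id=PROJECT_ID, location=LOCATION)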
@@ -419,9 +464,7 @@ def _check_dataflow_job_state(self, job) -> bool: "JOB_STATE_DRAINED while it is a batch job" ) - if current_state == self._expected_terminal_state: - if self._expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING: - return not self._wait_until_finished + if not self._wait_until_finished and current_state == self._expected_terminal_state: return True if current_state in DataflowJobStatus.AWAITING_STATES: @@ -559,6 +602,17 @@ def get_conn(self) -> build: http_authorized = self._authorize() return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False) + def initialize_client(self, client_class): + # this method doesn't work for JobClient, requires some user "dataflow-cloud-router" + # for other clients like SnapshotClient works fine + # client_options = ClientOptions(api_endpoint=f"dataflow.googleapis.com:443") + credentials = self.get_credentials() + return client_class( + credentials=credentials, + # client_info=CLIENT_INFO, + # client_options=client_options + ) + @_fallback_to_location_from_variables @_fallback_to_project_id_from_variables @GoogleBaseHook.fallback_to_default_project_id @@ -1094,6 +1148,21 @@ def get_job( ) return jobs_controller.fetch_job_by_id(job_id) + def update_job( + self, + job_id, + update_mask, + body, + project_id: str = PROVIDE_PROJECT_ID, + location: str = DEFAULT_DATAFLOW_LOCATION, + ): + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + return jobs_controller.update_job(job_id=job_id, update_mask=update_mask, body=body) + @GoogleBaseHook.fallback_to_default_project_id def fetch_job_metrics_by_id( self, @@ -1167,6 +1236,35 @@ def fetch_job_autoscaling_events_by_id( ) return jobs_controller.fetch_job_autoscaling_events_by_id(job_id) + @GoogleBaseHook.fallback_to_default_project_id + def check_active_jobs( + self, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> list[dict]: + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + result = jobs_controller.fetch_active_jobs() + return result + + @GoogleBaseHook.fallback_to_default_project_id + def create_job_snapshot( + self, + job_id: str, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> dict: + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + location=location, + ) + result = jobs_controller.create_job_snapshot(job_id=job_id) + return result + @GoogleBaseHook.fallback_to_default_project_id def wait_for_done( self, diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 59b27158862b3..8ed9f8ab7922f 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -1016,7 +1016,7 @@ def set_current_job(current_job): project_id=self.project_id, on_new_job_callback=set_current_job, ) - + self.xcom_push(context, key="job_id", value=job["id"]) return job def on_kill(self) -> None: @@ -1334,3 +1334,159 @@ def execute(self, context: Context) -> None: self.log.info("No jobs to stop") return None + + +class DataflowListActiveJobsOperator(GoogleCloudBaseOperator): + """ + Check for existence of active jobs in the given project across the given region. + + :param job_name_prefix: Name prefix specifying which jobs are to be stopped. + :param job_id: Job ID specifying which jobs are to be stopped. 
+ :param project_id: Optional, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param location: Optional, Job location. If set to None or missing, "us-central1" will be used. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status to confirm it's stopped. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :param drain_pipeline: Optional, set to False if want to stop streaming job by canceling it + instead of draining. See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :param stop_timeout: wait time in seconds for successful job canceling/draining + """ + + def __init__( + self, + project_id: str | None = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + drain_pipeline: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poll_sleep = poll_sleep + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook: DataflowHook | None = None + self.drain_pipeline = drain_pipeline + + def execute(self, context: Context) -> None: + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + drain_pipeline=self.drain_pipeline, + ) + active_jobs = self.dataflow_hook.check_active_jobs(project_id=self.project_id, location=self.location) + self.log.info("Active jobs in %s: %s", active_jobs, self.location) + return None + +class DataflowUpdateJobOperator(GoogleCloudBaseOperator): + template_fields: Sequence[str] = ( + "job_id", + ) + + def __init__( + self, + job_id: str, + update_mask: str, + updated_body: dict[str, Any], + project_id: str | None = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + stop_timeout: int | None = 10 * 60, + drain_pipeline: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.update_mask = update_mask + self.updated_body = updated_body + self.job_id = job_id + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook: DataflowHook | None = None + self.drain_pipeline = drain_pipeline + self.poll_sleep = poll_sleep + self.stop_timeout = stop_timeout + + def execute(self, context: Context) -> None: + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.stop_timeout, + drain_pipeline=self.drain_pipeline, + ) + + # streaming 
jobs can be updated, but batch jobs? Failed to provide correct Field Mask while testing + updated_job = self.dataflow_hook.update_job( + project_id=self.project_id, + location=self.location, + job_id=self.job_id, + update_mask=self.update_mask, + body=self.updated_body, + ) + + return updated_job + + +class DataflowCreateJobSnapshotOperator(GoogleCloudBaseOperator): + template_fields: Sequence[str] = ( + "job_id", + ) + + def __init__( + self, + job_id: str, + project_id: str | None = None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + impersonation_chain: str | Sequence[str] | None = None, + stop_timeout: int | None = 10 * 60, + drain_pipeline: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poll_sleep = poll_sleep + self.stop_timeout = stop_timeout + self.job_id = job_id + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook: DataflowHook | None = None + self.drain_pipeline = drain_pipeline + + def execute(self, context: Context): + self.dataflow_hook = DataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.stop_timeout, + drain_pipeline=self.drain_pipeline, + ) + # if streaming job exists, then.. + + snapshot = self.dataflow_hook.create_job_snapshot( + project_id=self.project_id, + location=self.location, + job_id=self.job_id, + ) + # else "Can't create snapshot of batch job + + return snapshot From a7dc17141696caa4899d7af1ec72fb08ec070535 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Wed, 15 Nov 2023 16:36:58 +0000 Subject: [PATCH 2/2] Add system test for Dataflow Job Snapshots service --- .../providers/google/cloud/hooks/dataflow.py | 42 +++-- .../google/cloud/operators/dataflow.py | 18 +- .../dataflow/example_dataflow_job_snapshot.py | 155 ++++++++++++++++++ 3 files changed, 197 insertions(+), 18 deletions(-) create mode 100644 tests/system/providers/google/cloud/dataflow/example_dataflow_job_snapshot.py diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 009464d2a1788..04cd0042e9384 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -28,12 +28,15 @@ import warnings from copy import deepcopy from typing import Any, Callable, Generator, Sequence, TypeVar, cast -from google.api_core.client_options import ClientOptions -from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobsV1Beta3Client, JobView, SnapshotsV1Beta3Client +from google.cloud.dataflow_v1beta3 import ( + GetJobRequest, + Job, + JobState, + JobsV1Beta3AsyncClient, + JobView, +) from googleapiclient.discovery import build -from google.cloud.dataflow_v1beta3.types import CheckActiveJobsRequest, CheckActiveJobsResponse -from airflow.providers.google.common.consts import CLIENT_INFO from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args @@ -299,14 +302,20 @@ def update_job(self, job_id, update_mask, body): location=self._job_location, jobId=job_id, body=body, - updateMask=update_mask + updateMask=update_mask, ) .execute(num_retries=self._num_retries) ) self.log.info("result: %s", result) return result - def create_job_snapshot(self, job_id): 
+ def create_job_snapshot( + self, + job_id, + snapshot_ttl: str = "604800s", + snapshot_sources: bool = False, + description: str | None = None, + ): result = ( self._dataflow.projects() .locations() @@ -316,10 +325,9 @@ def create_job_snapshot(self, job_id): location=self._job_location, jobId=job_id, body={ - # pass correct parameters here - "ttl": "3.5s", - "snapshotSources": True, - # "description": environment, + "ttl": snapshot_ttl, + "snapshotSources": snapshot_sources, + "description": description, }, ) .execute(num_retries=self._num_retries) @@ -328,7 +336,9 @@ def create_job_snapshot(self, job_id): return result def fetch_active_jobs(self): - return [job for job in self._fetch_all_jobs() if job["currentState"] in DataflowJobStatus.TURN_UP_STATES] + return [ + job for job in self._fetch_all_jobs() if job["currentState"] in DataflowJobStatus.TURN_UP_STATES + ] def _fetch_list_job_messages_responses(self, job_id: str) -> Generator[dict, None, None]: """ @@ -1256,13 +1266,21 @@ def create_job_snapshot( job_id: str, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, + snapshot_ttl: str = "604800s", + snapshot_sources: bool = False, + description: str | None = None, ) -> dict: jobs_controller = _DataflowJobsController( dataflow=self.get_conn(), project_number=project_id, location=location, ) - result = jobs_controller.create_job_snapshot(job_id=job_id) + result = jobs_controller.create_job_snapshot( + job_id=job_id, + snapshot_ttl=snapshot_ttl, + snapshot_sources=snapshot_sources, + description=description, + ) return result @GoogleBaseHook.fallback_to_default_project_id diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 8ed9f8ab7922f..b87ffcc409047 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -1391,10 +1391,9 @@ def execute(self, context: Context) -> None: self.log.info("Active jobs in %s: %s", active_jobs, self.location) return None + class DataflowUpdateJobOperator(GoogleCloudBaseOperator): - template_fields: Sequence[str] = ( - "job_id", - ) + template_fields: Sequence[str] = ("job_id",) def __init__( self, @@ -1445,15 +1444,16 @@ def execute(self, context: Context) -> None: class DataflowCreateJobSnapshotOperator(GoogleCloudBaseOperator): - template_fields: Sequence[str] = ( - "job_id", - ) + template_fields: Sequence[str] = ("job_id",) def __init__( self, job_id: str, project_id: str | None = None, location: str = DEFAULT_DATAFLOW_LOCATION, + snapshot_ttl: str = "604800s", + snapshot_sources: bool = False, + description: str | None = None, gcp_conn_id: str = "google_cloud_default", poll_sleep: int = 10, impersonation_chain: str | Sequence[str] | None = None, @@ -1467,6 +1467,9 @@ def __init__( self.job_id = job_id self.project_id = project_id self.location = location + self.snapshot_ttl = snapshot_ttl + self.snapshot_sources = snapshot_sources + self.description = description self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.hook: DataflowHook | None = None @@ -1486,6 +1489,9 @@ def execute(self, context: Context): project_id=self.project_id, location=self.location, job_id=self.job_id, + snapshot_ttl=self.snapshot_ttl, + snapshot_sources=self.snapshot_sources, + description=self.description, ) # else "Can't create snapshot of batch job diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_job_snapshot.py 
b/tests/system/providers/google/cloud/dataflow/example_dataflow_job_snapshot.py new file mode 100644 index 0000000000000..cbd18e59545a7 --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_job_snapshot.py @@ -0,0 +1,155 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for testing Google Dataflow +""" +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryDeleteTableOperator, +) +from airflow.providers.google.cloud.operators.dataflow import ( + DataflowCreateJobSnapshotOperator, + DataflowListActiveJobsOperator, + DataflowTemplatedJobStartOperator, + DataflowUpdateJobOperator, +) +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "dataflow_job_snapshot" + +BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") +DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" +TABLE_NAME = "realtime" +GCS_TMP = f"gs://{BUCKET_NAME}/temp/" + +LOCATION = "us-central1" + +default_args = { + "dataflow_default_options": { + "tempLocation": GCS_TMP, + } +} + + +with DAG( + DAG_ID, + default_args=default_args, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "dataflow"], +) as dag: + create_bucket = GCSCreateBucketOperator(task_id="create_bucket", bucket_name=BUCKET_NAME) + + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME) + + create_table = BigQueryCreateEmptyTableOperator( + task_id="create_table", + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + schema_fields=[ + {"name": "ride_id", "type": "STRING", "mode": "NULLABLE"}, + {"name": "point_idx", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "latitude", "type": "FLOAT", "mode": "NULLABLE"}, + {"name": "longitude", "type": "FLOAT", "mode": "NULLABLE"}, + {"name": "timestamp", "type": "TIMESTAMP", "mode": "NULLABLE"}, + {"name": "meter_reading", "type": "FLOAT", "mode": "NULLABLE"}, + {"name": "meter_increment", "type": "FLOAT", "mode": "NULLABLE"}, + {"name": "ride_status", "type": "STRING", "mode": "NULLABLE"}, + {"name": "passenger_count", "type": "INTEGER", "mode": "NULLABLE"}, + ], + time_partitioning={"type": "DAY", "field": "timestamp"}, + ) + + create_streaming_job = DataflowTemplatedJobStartOperator( + task_id="create_streaming_job", + project_id=PROJECT_ID, + 
template=f"gs://dataflow-templates-{LOCATION}/latest/PubSub_to_BigQuery", + parameters={ + "inputTopic": "projects/pubsub-public-data/topics/taxirides-realtime", + "outputTableSpec": f"{PROJECT_ID}:{DATASET_NAME}.{TABLE_NAME}", + }, + location=LOCATION, + ) + + create_job_snapshot = DataflowCreateJobSnapshotOperator( + task_id="create_job_snapshot", + job_id="{{ task_instance.xcom_pull(task_ids='create_streaming_job') }}", + project_id=PROJECT_ID, + location=LOCATION, + ) + + update_job = DataflowUpdateJobOperator( + task_id="update_job", + job_id="test_job", + update_mask="", + updated_body={}, + ) + + list_job = DataflowListActiveJobsOperator(task_id="list_job") + + delete_table = BigQueryDeleteTableOperator( + task_id="delete_table", + deletion_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}", + ) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + delete_dataset.trigger_rule = TriggerRule.ALL_DONE + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + create_bucket + >> create_dataset + >> create_table + >> create_streaming_job + >> create_job_snapshot + >> update_job + >> list_job + >> delete_table + >> delete_dataset + >> delete_bucket + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)