Implement Dataflow operators #172

Draft: wants to merge 2 commits into main
124 changes: 120 additions & 4 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -29,7 +29,13 @@
from copy import deepcopy
from typing import Any, Callable, Generator, Sequence, TypeVar, cast

from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
from google.cloud.dataflow_v1beta3 import (
GetJobRequest,
Job,
JobState,
JobsV1Beta3AsyncClient,
JobView,
)
from googleapiclient.discovery import build

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -139,6 +145,7 @@ class DataflowJobStatus:
JOB_STATE_PENDING = "JOB_STATE_PENDING"
JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING"
JOB_STATE_QUEUED = "JOB_STATE_QUEUED"
TURN_UP_STATES = {JOB_STATE_PENDING, JOB_STATE_QUEUED, JOB_STATE_RUNNING}
FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED}
SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED}
TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
@@ -285,6 +292,54 @@ def fetch_job_metrics_by_id(self, job_id: str) -> dict:
self.log.debug("fetch_job_metrics_by_id %s:\n%s", job_id, result)
return result

def update_job(self, job_id: str, update_mask: str, body: dict) -> dict:
"""Update the job with the given ID, changing only the fields listed in ``update_mask``."""
result = (
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body=body,
updateMask=update_mask,
)
.execute(num_retries=self._num_retries)
)
self.log.info("Job update result: %s", result)
return result

def create_job_snapshot(
self,
job_id: str,
snapshot_ttl: str = "604800s",
snapshot_sources: bool = False,
description: str | None = None,
) -> dict:
"""Create a snapshot of the job with the given ID."""
result = (
self._dataflow.projects()
.locations()
.jobs()
.snapshot(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={
"ttl": snapshot_ttl,
"snapshotSources": snapshot_sources,
"description": description,
},
)
.execute(num_retries=self._num_retries)
)
self.log.info("Job snapshot result: %s", result)
return result

def fetch_active_jobs(self) -> list[dict]:
"""Fetch all jobs in the project and location that are pending, queued or running."""
return [
job for job in self._fetch_all_jobs() if job["currentState"] in DataflowJobStatus.TURN_UP_STATES
]

def _fetch_list_job_messages_responses(self, job_id: str) -> Generator[dict, None, None]:
"""
Helper method to fetch ListJobMessagesResponse with the specified Job ID.
@@ -419,9 +474,7 @@ def _check_dataflow_job_state(self, job) -> bool:
"JOB_STATE_DRAINED while it is a batch job"
)

if current_state == self._expected_terminal_state:
if self._expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING:
return not self._wait_until_finished
if not self._wait_until_finished and current_state == self._expected_terminal_state:
return True

if current_state in DataflowJobStatus.AWAITING_STATES:
@@ -559,6 +612,17 @@ def get_conn(self) -> build:
http_authorized = self._authorize()
return build("dataflow", "v1b3", http=http_authorized, cache_discovery=False)

def initialize_client(self, client_class):
# This does not yet work for the jobs client, which appears to require the
# "dataflow-cloud-router" service user; other clients, e.g. the snapshots client, work fine.
# client_options = ClientOptions(api_endpoint="dataflow.googleapis.com:443")
credentials = self.get_credentials()
return client_class(
credentials=credentials,
# client_info=CLIENT_INFO,
# client_options=client_options
)
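# Illustrative usage sketch, not part of this diff: building a snapshots client from the
# hook's credentials with the generic initializer above. It assumes SnapshotsV1Beta3Client
# is exposed by the installed google.cloud.dataflow_v1beta3 package and that the hook is
# instantiated elsewhere (e.g. inside an operator).
from google.cloud.dataflow_v1beta3 import SnapshotsV1Beta3Client

hook = DataflowHook(gcp_conn_id="google_cloud_default")
snapshots_client = hook.initialize_client(SnapshotsV1Beta3Client)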

@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
@@ -1094,6 +1158,21 @@ def get_job(
)
return jobs_controller.fetch_job_by_id(job_id)

@GoogleBaseHook.fallback_to_default_project_id
def update_job(
self,
job_id: str,
update_mask: str,
body: dict,
project_id: str = PROVIDE_PROJECT_ID,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict:
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
location=location,
)
return jobs_controller.update_job(job_id=job_id, update_mask=update_mask, body=body)
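# Illustrative call sketch for update_job above. The job ID and project are placeholders,
# and the "runtime_updatable_params.max_num_workers" field mask is an assumption: the exact
# set of fields the jobs.update API accepts in updateMask may differ.
hook = DataflowHook(gcp_conn_id="google_cloud_default")
updated = hook.update_job(
    job_id="2023-01-01_00_00_00-1234567890123456789",
    update_mask="runtime_updatable_params.max_num_workers",
    body={"runtimeUpdatableParams": {"maxNumWorkers": 10}},
    project_id="example-project",
    location="us-central1",
)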

@GoogleBaseHook.fallback_to_default_project_id
def fetch_job_metrics_by_id(
self,
@@ -1167,6 +1246,43 @@ def fetch_job_autoscaling_events_by_id(
)
return jobs_controller.fetch_job_autoscaling_events_by_id(job_id)

@GoogleBaseHook.fallback_to_default_project_id
def check_active_jobs(
self,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> list[dict]:
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
location=location,
)
result = jobs_controller.fetch_active_jobs()
return result

@GoogleBaseHook.fallback_to_default_project_id
def create_job_snapshot(
self,
job_id: str,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
snapshot_ttl: str = "604800s",
snapshot_sources: bool = False,
description: str | None = None,
) -> dict:
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
location=location,
)
result = jobs_controller.create_job_snapshot(
job_id=job_id,
snapshot_ttl=snapshot_ttl,
snapshot_sources=snapshot_sources,
description=description,
)
return result
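# Illustrative call sketch for create_job_snapshot above (job and project IDs are
# placeholders); "604800s" is the 7-day default TTL used throughout this change.
hook = DataflowHook(gcp_conn_id="google_cloud_default")
snapshot = hook.create_job_snapshot(
    job_id="2023-01-01_00_00_00-1234567890123456789",
    project_id="example-project",
    location="us-central1",
    snapshot_ttl="604800s",
    snapshot_sources=False,
    description="pre-update snapshot",
)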

@GoogleBaseHook.fallback_to_default_project_id
def wait_for_done(
self,
164 changes: 163 additions & 1 deletion airflow/providers/google/cloud/operators/dataflow.py
@@ -1016,7 +1016,7 @@ def set_current_job(current_job):
project_id=self.project_id,
on_new_job_callback=set_current_job,
)

self.xcom_push(context, key="job_id", value=job["id"])
return job
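# Illustrative downstream usage of the "job_id" XCom pushed above; task IDs and the
# upstream task name are placeholders. A DataflowJobStatusSensor can wait on the started
# job without re-deriving its ID.
from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor

wait_for_job = DataflowJobStatusSensor(
    task_id="wait_for_dataflow_job",
    job_id="{{ task_instance.xcom_pull(task_ids='start_dataflow_job', key='job_id') }}",
    expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
    project_id="example-project",
    location="us-central1",
)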

def on_kill(self) -> None:
@@ -1334,3 +1334,165 @@ def execute(self, context: Context) -> None:
self.log.info("No jobs to stop")

return None


class DataflowListActiveJobsOperator(GoogleCloudBaseOperator):
"""
List the Dataflow jobs that are currently active (pending, queued or running)
in the given project and region. The active jobs found are logged and returned.

:param project_id: Optional, the Google Cloud project ID in which to look for jobs.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param location: Optional, job location. If set to None or missing, "us-central1" will be used.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param poll_sleep: The time in seconds to sleep between polling Google
Cloud Platform for the Dataflow job status.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param drain_pipeline: Optional, whether the hook used by this operator should drain (True)
or cancel (False) a streaming pipeline when stopping it.
See: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
"""

def __init__(
self,
project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
poll_sleep: int = 10,
impersonation_chain: str | Sequence[str] | None = None,
drain_pipeline: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poll_sleep = poll_sleep
self.project_id = project_id
self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: DataflowHook | None = None
self.drain_pipeline = drain_pipeline

def execute(self, context: Context) -> list[dict]:
self.dataflow_hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
drain_pipeline=self.drain_pipeline,
)
active_jobs = self.dataflow_hook.check_active_jobs(project_id=self.project_id, location=self.location)
self.log.info("Active jobs in %s: %s", self.location, active_jobs)
return active_jobs

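# Illustrative DAG usage of the operator above (DAG context and project ID are placeholders);
# the operator logs and returns the active jobs it finds.
list_active_jobs = DataflowListActiveJobsOperator(
    task_id="list_active_dataflow_jobs",
    project_id="example-project",
    location="us-central1",
)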

class DataflowUpdateJobOperator(GoogleCloudBaseOperator):
"""Update a running Dataflow job, applying only the fields listed in ``update_mask``."""

template_fields: Sequence[str] = ("job_id",)

def __init__(
self,
job_id: str,
update_mask: str,
updated_body: dict[str, Any],
project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
poll_sleep: int = 10,
impersonation_chain: str | Sequence[str] | None = None,
stop_timeout: int | None = 10 * 60,
drain_pipeline: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.update_mask = update_mask
self.updated_body = updated_body
self.job_id = job_id
self.project_id = project_id
self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: DataflowHook | None = None
self.drain_pipeline = drain_pipeline
self.poll_sleep = poll_sleep
self.stop_timeout = stop_timeout

def execute(self, context: Context) -> dict:
self.dataflow_hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.stop_timeout,
drain_pipeline=self.drain_pipeline,
)

# Streaming jobs can be updated; whether batch jobs can be is unclear: testing did not find a field mask they accept.
updated_job = self.dataflow_hook.update_job(
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
update_mask=self.update_mask,
body=self.updated_body,
)

return updated_job
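# Illustrative DAG usage of the operator above; the field mask and body mirror the
# hook-level sketch and are assumptions rather than a tested configuration.
update_job = DataflowUpdateJobOperator(
    task_id="update_dataflow_job",
    job_id="{{ task_instance.xcom_pull(task_ids='start_dataflow_job', key='job_id') }}",
    update_mask="runtime_updatable_params.max_num_workers",
    updated_body={"runtimeUpdatableParams": {"maxNumWorkers": 10}},
    project_id="example-project",
    location="us-central1",
)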


class DataflowCreateJobSnapshotOperator(GoogleCloudBaseOperator):
"""Create a snapshot of a running streaming Dataflow job."""

template_fields: Sequence[str] = ("job_id",)

def __init__(
self,
job_id: str,
project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
snapshot_ttl: str = "604800s",
snapshot_sources: bool = False,
description: str | None = None,
gcp_conn_id: str = "google_cloud_default",
poll_sleep: int = 10,
impersonation_chain: str | Sequence[str] | None = None,
stop_timeout: int | None = 10 * 60,
drain_pipeline: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poll_sleep = poll_sleep
self.stop_timeout = stop_timeout
self.job_id = job_id
self.project_id = project_id
self.location = location
self.snapshot_ttl = snapshot_ttl
self.snapshot_sources = snapshot_sources
self.description = description
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook: DataflowHook | None = None
self.drain_pipeline = drain_pipeline

def execute(self, context: Context) -> dict:
self.dataflow_hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
impersonation_chain=self.impersonation_chain,
cancel_timeout=self.stop_timeout,
drain_pipeline=self.drain_pipeline,
)
# A snapshot can only be taken of a streaming job.

snapshot = self.dataflow_hook.create_job_snapshot(
project_id=self.project_id,
location=self.location,
job_id=self.job_id,
snapshot_ttl=self.snapshot_ttl,
snapshot_sources=self.snapshot_sources,
description=self.description,
)
# For a batch job the snapshot request fails; snapshots of batch jobs are not supported.

return snapshot
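# Illustrative DAG usage of the operator above (IDs are placeholders); as noted in
# execute(), snapshots can only be taken of streaming jobs.
snapshot_job = DataflowCreateJobSnapshotOperator(
    task_id="snapshot_dataflow_job",
    job_id="{{ task_instance.xcom_pull(task_ids='start_dataflow_job', key='job_id') }}",
    project_id="example-project",
    location="us-central1",
    snapshot_ttl="604800s",
    description="pre-maintenance snapshot",
)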