Add direct-to-triggerer for DataprocSubmitJobOperator #52005
Changes to `airflow.providers.google.cloud.triggers.dataproc`:

```python
# @@ -27,6 +27,7 @@
from asgiref.sync import sync_to_async
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

from airflow.exceptions import AirflowException
# @@ -40,6 +41,7 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from google.api_core.retry import Retry
    from sqlalchemy.orm.session import Session
# @@ -214,6 +216,167 @@ async def run(self):
            raise e


class DataprocSubmitJobTrigger(DataprocBaseTrigger):
    """DataprocSubmitJobTrigger runs on the trigger worker to submit a Dataproc job and poll it until completion."""

    def __init__(
        self,
        job: dict,
        request_id: str | None = None,
        retry: Retry | _MethodDefault = DEFAULT,
        timeout: float | None = None,
        metadata: Sequence[tuple[str, str]] = (),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.job = job
        self.request_id = request_id
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.job_id = None  # Initialize job_id to None
```
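For context, `start_from_trigger=True` means the task never runs `execute()` on a worker; the scheduler hands the serialized trigger kwargs straight to the triggerer via `StartTriggerArgs`, which is why every constructor argument above must survive JSON serialization. A minimal sketch of the operator-side wiring, with illustrative values (the operator changes themselves are not part of this excerpt):

```python
# Sketch of how an operator exposes a trigger for start_from_trigger (Airflow 2.10+).
# The concrete values below are illustrative; the operator-side changes of this PR
# are not shown in this diff excerpt.
from airflow.triggers.base import StartTriggerArgs

start_trigger_args = StartTriggerArgs(
    trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger",
    next_method="execute_complete",
    # Everything here has to be JSON-serializable, which is why Retry/DEFAULT objects
    # cannot be passed through unchanged.
    trigger_kwargs={"project_id": "my-project", "region": "europe-west1", "job": {}},
    next_kwargs=None,
    timeout=None,
)
```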
```python
    def _normalize_retry_value(self, retry_value):
        """
        Normalize retry value for serialization and API calls.

        Since DEFAULT and Retry objects don't serialize well, we convert them to None.
        """
        if retry_value is DEFAULT or retry_value is None:
            return None
        # For other retry objects (like Retry instances), use None as fallback
        # since they are complex objects that don't serialize well
        return None
```
**Contributor:** I don't think that is what users expect when providing retries. There should be a way to pass it; in the worst case I'd prefer an exception stating that such params are not supported, so please provide None explicitly.

**Contributor:** The docs mention this limitation as
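One possible way to address this concern, sketched here only as an assumption and not something this PR implements, is to carry primitive retry settings across serialization and rebuild the `Retry` object inside the trigger; the names `retry_kwargs` and `_build_retry` are hypothetical:

```python
# Hypothetical alternative (not part of this PR): carry primitive retry settings across
# trigger serialization and rebuild the Retry object on the triggerer side.
from google.api_core.retry import Retry


def _build_retry(retry_kwargs: dict | None) -> Retry | None:
    """Rebuild a Retry object from JSON-serializable settings."""
    if not retry_kwargs:
        return None
    # Retry accepts plain floats for these settings, so they survive serialization.
    return Retry(
        initial=retry_kwargs.get("initial", 1.0),
        maximum=retry_kwargs.get("maximum", 60.0),
        multiplier=retry_kwargs.get("multiplier", 2.0),
        timeout=retry_kwargs.get("timeout", 120.0),
    )
```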
```python
    def serialize(self):
        return (
            "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger",
            {
                "project_id": self.project_id,
                "region": self.region,
                "job": self.job,
                "request_id": self.request_id,
                "retry": self._normalize_retry_value(self.retry),
                "timeout": self.timeout,
                "metadata": self.metadata,
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
                "polling_interval_seconds": self.polling_interval_seconds,
                "cancel_on_kill": self.cancel_on_kill,
            },
        )
```
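For reference, the triggerer consumes this tuple by importing the dotted classpath and re-instantiating the trigger from the kwargs dict, so every value must be JSON-serializable. A simplified round-trip sketch (the constructor arguments are illustrative; the real triggerer does this through its own machinery):

```python
# Simplified illustration of how the serialized form is consumed; not code from the PR.
import importlib

from airflow.providers.google.cloud.triggers.dataproc import DataprocSubmitJobTrigger

trigger = DataprocSubmitJobTrigger(
    job={"placement": {"cluster_name": "my-cluster"}},  # illustrative values
    project_id="my-project",
    region="europe-west1",
)
classpath, kwargs = trigger.serialize()

# The triggerer re-imports the class from its dotted path and rebuilds the trigger
# from the kwargs dict, so everything in kwargs has to be JSON-serializable.
module_path, class_name = classpath.rsplit(".", 1)
trigger_cls = getattr(importlib.import_module(module_path), class_name)
rehydrated = trigger_cls(**kwargs)
```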
```python
    @provide_session
    def get_task_instance(self, session: Session) -> TaskInstance:
        """
        Get the task instance for the current task.

        :param session: Sqlalchemy session
        """
        query = session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.task_instance.dag_id,
            TaskInstance.task_id == self.task_instance.task_id,
            TaskInstance.run_id == self.task_instance.run_id,
            TaskInstance.map_index == self.task_instance.map_index,
        )
        task_instance = query.one_or_none()
```
Comment on lines +276 to +282

**Contributor:** This code will not be compatible with AF3. If you were looking for a solution for this functionality, I think it is better to use an already existing trigger, if I understand correctly. Can you please also show your system test results for this code, both for AF2 and AF3, as screenshots? Thanks.

**Contributor (author):** Thank you for reviewing.

**Contributor (author):** @VladaZakharova Hi, I wrote a system test for the new feature and it passed. But when I ran it in a real Airflow environment using Breeze I got the error below, and it is raised not only for start_from_trigger but also for the plain deferrable operator. Is there something wrong with the AF3 deferrable path?

**Contributor:** As for the system test, I managed to run it (example_dataproc_start_from_trigger.py) successfully (AF3).
```python
        if task_instance is None:
            raise AirflowException(
                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                self.task_instance.dag_id,
                self.task_instance.task_id,
                self.task_instance.run_id,
                self.task_instance.map_index,
            )
        return task_instance

    async def get_task_state(self):
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
            dag_id=self.task_instance.dag_id,
            task_ids=[self.task_instance.task_id],
            run_ids=[self.task_instance.run_id],
            map_index=self.task_instance.map_index,
        )
        try:
            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
        except Exception:
            raise AirflowException(
                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
                self.task_instance.dag_id,
                self.task_instance.task_id,
                self.task_instance.run_id,
                self.task_instance.map_index,
            )
        return task_state

    async def safe_to_cancel(self) -> bool:
        """
        Whether it is safe to cancel the external job which is being executed by this trigger.

        This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
        Because in those cases, we should NOT cancel the external job.
        """
        # Database query is needed to get the latest state of the task instance.
        if AIRFLOW_V_3_0_PLUS:
            task_state = await self.get_task_state()
        else:
            # Database query is needed to get the latest state of the task instance.
            task_instance = self.get_task_instance()  # type: ignore[call-arg]
            task_state = task_instance.state
        return task_state != TaskInstanceState.DEFERRED

    async def run(self):
        try:
            # Create a new Dataproc job
            job = await self.get_async_hook().submit_job(
                project_id=self.project_id,
                region=self.region,
                job=self.job,
                request_id=self.request_id,
                retry=self._normalize_retry_value(self.retry),
                timeout=self.timeout,
                metadata=self.metadata,
            )
            self.job_id = job.reference.job_id
            while True:
                job = await self.get_async_hook().get_job(
                    project_id=self.project_id, region=self.region, job_id=self.job_id
                )
                state = job.status.state
                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                    break
                await asyncio.sleep(self.polling_interval_seconds)
            yield TriggerEvent({"job_id": self.job_id, "job_state": str(state), "job": str(job)})
        except asyncio.CancelledError:
            self.log.info("Task got cancelled.")
            try:
                if (
                    hasattr(self, "job_id")
                    and self.job_id
                    and self.cancel_on_kill
                    and await self.safe_to_cancel()
                ):
                    self.log.info(
                        "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                        " in deferred state."
                    )
                    self.log.info("Cancelling the job: %s", self.job_id)
                    self.get_sync_hook().cancel_job(
                        job_id=self.job_id, project_id=self.project_id, region=self.region
                    )
                    self.log.info("Job: %s is cancelled", self.job_id)
                    yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
            except Exception as e:
                if hasattr(self, "job_id") and self.job_id:
                    self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
                else:
                    self.log.error("Failed to cancel the job (no job_id available) with error : %s", str(e))
                raise e
```
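The terminal `TriggerEvent` above carries `job_id`, a stringified `job_state`, and a stringified `job`, so whichever method the task resumes in has to work with strings rather than enum members. A hedged sketch of a completion handler for this payload; the function name and checks are illustrative and not necessarily what `DataprocSubmitJobOperator` does:

```python
# Illustrative handler for the event payload yielded above; the actual operator
# hook-up may differ.
from airflow.exceptions import AirflowException


def handle_submit_job_event(event: dict) -> str:
    job_id = event["job_id"]
    # job_state was stringified in the trigger (e.g. "State.DONE"), so compare on the
    # state name rather than on the enum member.
    job_state = str(event["job_state"])
    if "DONE" not in job_state:
        raise AirflowException(f"Dataproc job {job_id} finished in state {job_state}: {event.get('job')}")
    return job_id
```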
The new trigger class is inserted immediately before the existing `DataprocClusterTrigger` class in the same module.
New system test DAG, `example_dataproc_start_from_trigger.py` (new file, 127 lines):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger.
"""

from __future__ import annotations

import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocCreateClusterOperator,
    DataprocDeleteClusterOperator,
    DataprocSubmitJobOperator,
)
from airflow.utils.trigger_rule import TriggerRule

from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
DAG_ID = "dataproc_start_from_trigger"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-")
CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-")
CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else CLUSTER_NAME_FULL

REGION = "europe-west1"

# Cluster definition
CLUSTER_CONFIG = {
    "master_config": {
        "num_instances": 1,
        "machine_type_uri": "n1-standard-4",
        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32},
    },
    "worker_config": {
        "num_instances": 2,
        "machine_type_uri": "n1-standard-4",
        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32},
    },
}

# Jobs definitions
SPARK_JOB = {
    "reference": {"project_id": PROJECT_ID},
    "placement": {"cluster_name": CLUSTER_NAME},
    "spark_job": {
        "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
        "main_class": "org.apache.spark.examples.SparkPi",
    },
}

# Create DAG
with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2023, 1, 1),
    catchup=False,
    tags=["dataproc", "start_from_trigger"],
) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id=PROJECT_ID,
        cluster_config=CLUSTER_CONFIG,
        region=REGION,
        cluster_name=CLUSTER_NAME,
        retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
        num_retries_if_resource_is_not_ready=3,
    )

    spark_job_with_start_from_trigger = DataprocSubmitJobOperator(
        task_id="spark_job_with_start_from_trigger",
        job=SPARK_JOB,
        region=REGION,
        project_id=PROJECT_ID,
        start_from_trigger=True,
    )

    delete_cluster = DataprocDeleteClusterOperator(
        task_id="delete_cluster",
        project_id=PROJECT_ID,
        region=REGION,
        cluster_name=CLUSTER_NAME,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    # Define task dependencies
    (
        # TEST SETUP
        create_cluster
        # TEST BODY
        >> spark_job_with_start_from_trigger
        # TEST TEARDOWN
        >> delete_cluster
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "teardown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
```
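For comparison, the classic `deferrable=True` mode still starts `execute()` on a worker and only then defers to the triggerer, whereas `start_from_trigger=True` skips the worker entirely. A sketch of the deferrable variant of the same task (the task id is illustrative and not part of this system test):

```python
# Classic deferrable mode for comparison: the task still occupies a worker slot to submit
# the job, then defers to the triggerer to wait for completion.
spark_job_deferrable = DataprocSubmitJobOperator(
    task_id="spark_job_deferrable",  # illustrative task id
    job=SPARK_JOB,
    region=REGION,
    project_id=PROJECT_ID,
    deferrable=True,
)
```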
**Contributor:** If we have to use this one exclusively for `start_from_trigger`, I suggest highlighting that in the class name and docstring.
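A small sketch of what that suggestion could look like; the class name `DataprocSubmitJobStartTrigger` is purely hypothetical and not part of this PR:

```python
# Hypothetical rename reflecting the reviewer's suggestion; not part of the PR.
from airflow.providers.google.cloud.triggers.dataproc import DataprocBaseTrigger


class DataprocSubmitJobStartTrigger(DataprocBaseTrigger):
    """
    Trigger used exclusively by DataprocSubmitJobOperator with ``start_from_trigger=True``.

    It submits the Dataproc job from the triggerer, without using a worker slot, and then
    polls the job until it reaches a terminal state.
    """
```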