Changes from all commits (28 commits)
741dafb  add: DataprocSubmitJobTrigger (kgw7401, Jun 21, 2025)
fcc4d93  add: unit test (kgw7401, Jun 22, 2025)
df9359b  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jun 22, 2025)
01d686a  add: mock_log (kgw7401, Jun 22, 2025)
e2fbfd4  fix: TestDataprocCreateBatchOperator test_execute_openlineage_all_inf… (kgw7401, Jun 22, 2025)
a6de84a  fix: test_dataproc pre-commit (kgw7401, Jun 22, 2025)
826f4fa  fix: Move Retry into type-checking block (kgw7401, Jun 22, 2025)
498cbbb  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jun 22, 2025)
961eb3c  fix: add task_state for af3 (kgw7401, Jun 29, 2025)
1a2fdc3  add: start_from_trigger system test (kgw7401, Jun 29, 2025)
ed2fd80  fix: system-test (kgw7401, Jun 29, 2025)
f2b84e2  fix: add await to safe_to_cancel (kgw7401, Jun 29, 2025)
8109f29  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jun 29, 2025)
8954e74  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jun 30, 2025)
d5da7e8  fix: test_execute_openlineage_transport_info_injection (kgw7401, Jun 30, 2025)
b1321c6  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jun 30, 2025)
3bc8823  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jul 1, 2025)
8f7c711  fix: test (kgw7401, Jul 1, 2025)
de092b4  fix (kgw7401, Jul 2, 2025)
cd45912  fix: example_dataproc_start_from_trigger pre-commit (kgw7401, Jul 2, 2025)
fc9b963  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jul 20, 2025)
a93d55b  change to async hook (kgw7401, Jul 20, 2025)
8f60603  add: normalize_retry_value (kgw7401, Jul 20, 2025)
a8e6e14  fix: test code (kgw7401, Jul 20, 2025)
c464d69  Merge branch 'main' into feat/dataprocsubmitjoboperator-direct-to-tri… (kgw7401, Jul 22, 2025)
56e4c43  add: TEST_RUNNING_CLUSTER, TEST_ERROR_CLUSTER (kgw7401, Jul 22, 2025)
0f67393  Merge branch 'feat/dataprocsubmitjoboperator-direct-to-triggerer' of … (kgw7401, Jul 22, 2025)
7e7c2a5  fix: mock_log (kgw7401, Jul 22, 2025)
providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -74,6 +74,20 @@

from airflow.utils.context import Context

try:
from airflow.triggers.base import StartTriggerArgs
except ImportError:
# TODO: Remove this when min airflow version is 2.10.0 for standard provider
@dataclass
class StartTriggerArgs: # type: ignore[no-redef]
"""Arguments required for start task execution from triggerer."""

trigger_cls: str
next_method: str
trigger_kwargs: dict[str, Any] | None = None
next_kwargs: dict[str, Any] | None = None
timeout: timedelta | None = None


class PreemptibilityType(Enum):
"""Contains possible Type values of Preemptibility applicable for every secondary worker of Cluster."""
@@ -1830,6 +1844,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator):

operator_extra_links = (DataprocJobLink(),)

start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = False

def __init__(
self,
*,
@@ -1844,6 +1867,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
asynchronous: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
start_from_trigger: bool = False,
polling_interval_seconds: int = 10,
cancel_on_kill: bool = True,
wait_timeout: int | None = None,
@@ -1876,6 +1900,16 @@ def __init__(
self.wait_timeout = wait_timeout
self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
self.openlineage_inject_transport_info = openlineage_inject_transport_info
self.start_trigger_args.trigger_kwargs = {
"project_id": self.project_id,
"region": self.region,
"job": self.job,
"request_id": self.request_id,
"retry": self.retry,
"timeout": self.timeout,
"metadata": self.metadata,
}
self.start_from_trigger = start_from_trigger

def execute(self, context: Context):
self.log.info("Submitting job")
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -27,6 +27,7 @@

from asgiref.sync import sync_to_async
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

from airflow.exceptions import AirflowException
@@ -40,6 +41,7 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from google.api_core.retry import Retry
from sqlalchemy.orm.session import Session


@@ -214,6 +216,167 @@ async def run(self):
raise e


class DataprocSubmitJobTrigger(DataprocBaseTrigger):
Contributor:
If we have to use this one exclusively for start_from_trigger, I suggest highlighting that in the class name and docstring.

"""DataprocSubmitJobTrigger runs on the trigger worker to perform Build operation."""

def __init__(
self,
job: dict,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
**kwargs,
):
super().__init__(**kwargs)
self.job = job
self.request_id = request_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
self.job_id = None # Initialize job_id to None

def _normalize_retry_value(self, retry_value):
"""
Normalize retry value for serialization and API calls.

Since DEFAULT and Retry objects don't serialize well, we convert them to None.
"""
if retry_value is DEFAULT or retry_value is None:
return None
# For other retry objects (like Retry instances), use None as fallback
# since they are complex objects that don't serialize well
return None
Contributor (@olegkachur-e, Jul 25, 2025):
I don't think this is what users expect when they provide retries; None often means wait forever.
There should be a way to pass the retry through. In the worst case I'd prefer an exception saying such parameters are not supported, so the user has to provide None explicitly.

Contributor:
The docs mention this limitation for trigger_kwargs: the keyword arguments passed to the trigger_cls when it is initialized all need to be serializable by Airflow, which is the main limitation of this feature.
If we don't need it, maybe just avoid passing it?

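One possible way to address the concern above, sketched here only as an illustration (names like retry_config and _build_retry are hypothetical, not part of this PR): pass plain, JSON-serializable retry settings through trigger_kwargs and rebuild the Retry object inside the trigger, instead of silently mapping every Retry to None.

from google.api_core.retry import Retry

def _build_retry(retry_config: dict | None) -> Retry | None:
    """Rebuild a Retry object from a serializable dict of settings, or return None."""
    if not retry_config:
        return None
    return Retry(
        initial=retry_config.get("initial", 1.0),
        maximum=retry_config.get("maximum", 60.0),
        multiplier=retry_config.get("multiplier", 2.0),
    )

# Operator side (hypothetical): trigger_kwargs["retry_config"] = {"initial": 10.0, "maximum": 100.0}
# Trigger side (hypothetical):  retry = _build_retry(self.retry_config)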

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger",
{
"project_id": self.project_id,
"region": self.region,
"job": self.job,
"request_id": self.request_id,
"retry": self._normalize_retry_value(self.retry),
"timeout": self.timeout,
"metadata": self.metadata,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
"cancel_on_kill": self.cancel_on_kill,
},
)
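For reference, a simplified sketch of why these kwargs must stay serializable: the triggerer persists the classpath and kwargs returned by serialize() and later rebuilds the trigger from them. The real machinery lives in Airflow's triggerer job runner; this is only an approximation.

import importlib

def rehydrate(classpath: str, kwargs: dict):
    # Resolve the trigger class from its dotted path and re-instantiate it,
    # which only works if every value in kwargs survived serialization.
    module_name, class_name = classpath.rsplit(".", 1)
    trigger_cls = getattr(importlib.import_module(module_name), class_name)
    return trigger_cls(**kwargs)

# classpath, kwargs = trigger.serialize()   # trigger: a DataprocSubmitJobTrigger instance
# rebuilt = rehydrate(classpath, kwargs)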

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
"""
Get the task instance for the current task.

:param session: Sqlalchemy session
"""
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
task_instance = query.one_or_none()
Comment on lines +276 to +282
Contributor:
This code will not be compatible with AF3. If you were looking for a solution for this functionality, I think it is better to use an already existing trigger, if I understand correctly. Could you also show your system test results for this code, for both AF2 and AF3, as screenshots? Thanks.

Contributor Author:
Thank you for reviewing. Actually, when I ran the system test it failed, so I need to fix the code. Also, I can't find a reference explaining why the code you mentioned wouldn't be compatible with AF3. Could you point me to one?

Contributor Author (@kgw7401, Jul 2, 2025):
@VladaZakharova Hi, I wrote a system test for the new feature and it all passed, but when I ran it in a real Airflow environment using Breeze I got the error below. It is raised not only with start_from_trigger but also with the plain deferrable operator. Is there something wrong with deferrable execution on AF3?

Traceback (most recent call last):

  File "/opt/airflow/airflow-core/src/airflow/jobs/triggerer_job_runner.py", line 923, in cleanup_finished_triggers
    result = details["task"].result()
             ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/airflow-core/src/airflow/jobs/triggerer_job_runner.py", line 1032, in run_trigger
    async for event in trigger.run():

  File "/opt/airflow/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py", line 187, in run
    job = await self.get_async_hook().get_job(
                ^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py", line 71, in get_async_hook
    return DataprocAsyncHook(
           ^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py", line 1286, in __init__
    super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)

  File "/opt/airflow/providers/google/src/airflow/providers/google/common/hooks/base_google.py", line 280, in __init__
    self.extras: dict = self.get_connection(self.gcp_conn_id).extra_dejson
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/airflow-core/src/airflow/hooks/base.py", line 64, in get_connection
    conn = Connection.get_connection_from_secrets(conn_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/airflow-core/src/airflow/models/connection.py", line 481, in get_connection_from_secrets
    conn = TaskSDKConnection.get(conn_id=conn_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/task-sdk/src/airflow/sdk/definitions/connection.py", line 152, in get
    return _get_connection(conn_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/task-sdk/src/airflow/sdk/execution_time/context.py", line 155, in _get_connection
    msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/opt/airflow/airflow-core/src/airflow/jobs/triggerer_job_runner.py", line 708, in send
    return async_to_sync(self.asend)(msg)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.11/site-packages/asgiref/sync.py", line 186, in __call__
    raise RuntimeError(

RuntimeError: You cannot use AsyncToSync in the same thread as an async event loop - just await the async function directly.

Contributor:
As for the system test, I managed to run it (example_dataproc_start_from_trigger.py) successfully (AF3).

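The RuntimeError in the traceback comes from blocking, connection-fetching code (hook construction) running on the triggerer's event loop on AF3. As a general pattern, and not necessarily the fix adopted in this PR, such blocking calls can be pushed off the loop, for example with asgiref's sync_to_async, while the async hook's coroutines are awaited directly. A hedged sketch (the helper name is hypothetical; get_sync_hook and cancel_job are the existing calls used later in run()):

from asgiref.sync import sync_to_async

async def cancel_job_off_loop(trigger) -> None:
    # Hypothetical helper. get_sync_hook() may fetch the Airflow connection
    # synchronously, which must not happen on the running event loop.
    hook = await sync_to_async(trigger.get_sync_hook)()
    await sync_to_async(hook.cancel_job)(
        job_id=trigger.job_id, project_id=trigger.project_id, region=trigger.region
    )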
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
self.task_instance.dag_id,
self.task_instance.task_id,
self.task_instance.run_id,
self.task_instance.map_index,
)
return task_instance

async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
dag_id=self.task_instance.dag_id,
task_ids=[self.task_instance.task_id],
run_ids=[self.task_instance.run_id],
map_index=self.task_instance.map_index,
)
try:
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
except Exception:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
self.task_instance.dag_id,
self.task_instance.task_id,
self.task_instance.run_id,
self.task_instance.map_index,
)
return task_state

async def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.

This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
if AIRFLOW_V_3_0_PLUS:
task_state = await self.get_task_state()
else:
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
task_state = task_instance.state
return task_state != TaskInstanceState.DEFERRED

async def run(self):
try:
# Create a new Dataproc job
job = await self.get_async_hook().submit_job(
project_id=self.project_id,
region=self.region,
job=self.job,
request_id=self.request_id,
retry=self._normalize_retry_value(self.retry),
timeout=self.timeout,
metadata=self.metadata,
)
self.job_id = job.reference.job_id
while True:
job = await self.get_async_hook().get_job(
project_id=self.project_id, region=self.region, job_id=self.job_id
)
state = job.status.state
self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
break
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": str(state), "job": str(job)})
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
if (
hasattr(self, "job_id")
and self.job_id
and self.cancel_on_kill
and await self.safe_to_cancel()
):
self.log.info(
"Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
" in deferred state."
)
self.log.info("Cancelling the job: %s", self.job_id)
self.get_sync_hook().cancel_job(
job_id=self.job_id, project_id=self.project_id, region=self.region
)
self.log.info("Job: %s is cancelled", self.job_id)
yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
except Exception as e:
if hasattr(self, "job_id") and self.job_id:
self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
else:
self.log.error("Failed to cancel the job (no job_id available) with error : %s", str(e))
raise e


class DataprocClusterTrigger(DataprocBaseTrigger):
"""
DataprocClusterTrigger run on the trigger worker to perform create Build operation.
example_dataproc_start_from_trigger.py (new file)
@@ -0,0 +1,127 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger.
"""

from __future__ import annotations

import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateClusterOperator,
DataprocDeleteClusterOperator,
DataprocSubmitJobOperator,
)
from airflow.utils.trigger_rule import TriggerRule

from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
DAG_ID = "dataproc_start_from_trigger"
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-")
CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-")
CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else CLUSTER_NAME_FULL

REGION = "europe-west1"

# Cluster definition
CLUSTER_CONFIG = {
"master_config": {
"num_instances": 1,
"machine_type_uri": "n1-standard-4",
"disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32},
},
"worker_config": {
"num_instances": 2,
"machine_type_uri": "n1-standard-4",
"disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32},
},
}

# Jobs definitions
SPARK_JOB = {
"reference": {"project_id": PROJECT_ID},
"placement": {"cluster_name": CLUSTER_NAME},
"spark_job": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"main_class": "org.apache.spark.examples.SparkPi",
},
}

# Create DAG
with DAG(
dag_id=DAG_ID,
schedule="@once",
start_date=datetime(2023, 1, 1),
catchup=False,
tags=["dataproc", "start_from_trigger"],
) as dag:
create_cluster = DataprocCreateClusterOperator(
task_id="create_cluster",
project_id=PROJECT_ID,
cluster_config=CLUSTER_CONFIG,
region=REGION,
cluster_name=CLUSTER_NAME,
retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
num_retries_if_resource_is_not_ready=3,
)

spark_job_with_start_from_trigger = DataprocSubmitJobOperator(
task_id="spark_job_with_start_from_trigger",
job=SPARK_JOB,
region=REGION,
project_id=PROJECT_ID,
start_from_trigger=True,
)

delete_cluster = DataprocDeleteClusterOperator(
task_id="delete_cluster",
project_id=PROJECT_ID,
region=REGION,
cluster_name=CLUSTER_NAME,
trigger_rule=TriggerRule.ALL_DONE,
)

# Define task dependencies
(
# TEST SETUP
create_cluster
# TEST BODY
>> spark_job_with_start_from_trigger
# TEST TEARDOWN
>> delete_cluster
)

from tests_common.test_utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "teardown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)