Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions providers/google/docs/operators/cloud/cloud_composer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,15 @@ or you can define the same sensor in the deferrable mode:
:dedent: 4
:start-after: [START howto_sensor_dag_run_deferrable_mode]
:end-before: [END howto_sensor_dag_run_deferrable_mode]

Trigger a DAG run
-----------------

You can trigger a DAG in another Composer environment, use:
:class:`~airflow.providers.google.cloud.operators.cloud_composer.CloudComposerTriggerDAGRunOperator`

.. exampleinclude:: /../../google/tests/system/google/cloud/composer/example_cloud_composer.py
:language: python
:dedent: 4
:start-after: [START howto_operator_trigger_dag_run]
:end-before: [END howto_operator_trigger_dag_run]
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from __future__ import annotations

import asyncio
import json
import time
from collections.abc import MutableSequence, Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.auth.transport.requests import AuthorizedSession
from google.cloud.orchestration.airflow.service_v1 import (
EnvironmentsAsyncClient,
EnvironmentsClient,
Expand Down Expand Up @@ -76,6 +79,34 @@ def get_image_versions_client(self) -> ImageVersionsClient:
client_options=self.client_options,
)

def make_composer_airflow_api_request(
self,
method: str,
airflow_uri: str,
path: str,
data: Any | None = None,
timeout: float | None = None,
):
"""
Make a request to Cloud Composer environment's web server.

:param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE').
:param airflow_uri: The URI of the Apache Airflow Web UI hosted within this environment.
:param path: The path to send the request.
:param data: Dictionary, list of tuples, bytes, or file-like object to send in the body of the request.
:param timeout: The timeout for this request.
"""
authed_session = AuthorizedSession(self.get_credentials())

resp = authed_session.request(
method=method,
url=urljoin(airflow_uri, path),
data=data,
headers={"Content-Type": "application/json"},
timeout=timeout,
)
return resp

def get_operation(self, operation_name):
return self.get_environment_client().transport.operations_client.get_operation(name=operation_name)

Expand Down Expand Up @@ -408,6 +439,39 @@ def wait_command_execution_result(
self.log.info("Waiting for result...")
time.sleep(poll_interval)

def trigger_dag_run(
self,
composer_airflow_uri: str,
composer_dag_id: str,
composer_dag_conf: dict | None = None,
timeout: float | None = None,
) -> dict:
"""
Trigger DAG run for provided Apache Airflow Web UI hosted within Composer environment.

:param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
:param composer_dag_id: The ID of DAG which will be triggered.
:param composer_dag_conf: Configuration parameters for the DAG run.
:param timeout: The timeout for this request.
"""
response = self.make_composer_airflow_api_request(
method="POST",
airflow_uri=composer_airflow_uri,
path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
data=json.dumps(
{
"conf": composer_dag_conf or {},
}
),
timeout=timeout,
)

if response.status_code != 200:
self.log.error(response.text)
response.raise_for_status()

return response.json()


class CloudComposerAsyncHook(GoogleBaseHook):
"""Hook for Google Cloud Composer async APIs."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.orchestration.airflow.service_v1 import ImageVersion
from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
Expand Down Expand Up @@ -798,3 +798,86 @@ def _merge_cmd_output_result(self, result) -> str:
"""Merge output to one string."""
result_str = "\n".join(line_dict["content"] for line_dict in result["output"])
return result_str


class CloudComposerTriggerDAGRunOperator(GoogleCloudBaseOperator):
"""
Trigger DAG run for provided Composer environment.

:param project_id: The ID of the Google Cloud project that the service belongs to.
:param region: The ID of the Google Cloud region that the service belongs to.
:param environment_id: The ID of the Google Cloud environment that the service belongs to.
:param composer_dag_id: The ID of DAG which will be triggered.
:param composer_dag_conf: Configuration parameters for the DAG run.
:param timeout: The timeout for this request.
:param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = (
"project_id",
"region",
"environment_id",
"composer_dag_id",
"impersonation_chain",
)

def __init__(
self,
*,
project_id: str,
region: str,
environment_id: str,
composer_dag_id: str,
composer_dag_conf: dict | None = None,
timeout: float | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.region = region
self.environment_id = environment_id
self.composer_dag_id = composer_dag_id
self.composer_dag_conf = composer_dag_conf or {}
self.timeout = timeout
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
hook = CloudComposerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
try:
environment = hook.get_environment(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
timeout=self.timeout,
)
except NotFound as not_found_err:
self.log.info("The Composer environment %s does not exist.", self.environment_id)
raise AirflowException(not_found_err)
composer_airflow_uri = environment.config.airflow_uri

self.log.info(
"Triggering the DAG %s on the %s environment...", self.composer_dag_id, self.environment_id
)
dag_run = hook.trigger_dag_run(
composer_airflow_uri=composer_airflow_uri,
composer_dag_id=self.composer_dag_id,
composer_dag_conf=self.composer_dag_conf,
timeout=self.timeout,
)
self.log.info("The DAG %s was triggered with Run ID: %s", self.composer_dag_id, dag_run["dag_run_id"])

return dag_run
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
CloudComposerListEnvironmentsOperator,
CloudComposerListImageVersionsOperator,
CloudComposerRunAirflowCLICommandOperator,
CloudComposerTriggerDAGRunOperator,
CloudComposerUpdateEnvironmentOperator,
)
from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor
Expand Down Expand Up @@ -218,6 +219,16 @@ def get_project_number():
)
# [END howto_sensor_dag_run_deferrable_mode]

# [START howto_operator_trigger_dag_run]
trigger_dag_run = CloudComposerTriggerDAGRunOperator(
task_id="trigger_dag_run",
project_id=PROJECT_ID,
region=REGION,
environment_id=ENVIRONMENT_ID,
composer_dag_id="airflow_monitoring",
)
# [END howto_operator_trigger_dag_run]

# [START howto_operator_delete_composer_environment]
delete_env = CloudComposerDeleteEnvironmentOperator(
task_id="delete_env",
Expand Down Expand Up @@ -250,6 +261,7 @@ def get_project_number():
[update_env, defer_update_env],
[run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
[dag_run_sensor, defer_dag_run_sensor],
trigger_dag_run,
# TEST TEARDOWN
[delete_env, defer_delete_env],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import json
from unittest import mock
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -56,6 +57,10 @@
TEST_PARENT = "test-parent"
TEST_NAME = "test-name"

TEST_COMPOSER_AIRFLOW_URI = "test-composer-airflow-uri"
TEST_COMPOSER_DAG_ID = "test-composer-dag-id"
TEST_COMPOSER_DAG_CONF = {"test-key": "test-value"}

BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
COMPOSER_STRING = "airflow.providers.google.cloud.hooks.cloud_composer.{}"

Expand Down Expand Up @@ -257,6 +262,27 @@ def test_poll_airflow_command(self, mock_client) -> None:
metadata=TEST_METADATA,
)

@mock.patch(COMPOSER_STRING.format("CloudComposerHook.make_composer_airflow_api_request"))
def test_trigger_dag_run(self, mock_composer_airflow_api_request) -> None:
self.hook.get_credentials = mock.MagicMock()
self.hook.trigger_dag_run(
composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
composer_dag_id=TEST_COMPOSER_DAG_ID,
composer_dag_conf=TEST_COMPOSER_DAG_CONF,
timeout=TEST_TIMEOUT,
)
mock_composer_airflow_api_request.assert_called_once_with(
method="POST",
airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
data=json.dumps(
{
"conf": TEST_COMPOSER_DAG_CONF,
}
),
timeout=TEST_TIMEOUT,
)


class TestCloudComposerAsyncHook:
def setup_method(self, method):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CloudComposerListEnvironmentsOperator,
CloudComposerListImageVersionsOperator,
CloudComposerRunAirflowCLICommandOperator,
CloudComposerTriggerDAGRunOperator,
CloudComposerUpdateEnvironmentOperator,
)
from airflow.providers.google.cloud.triggers.cloud_composer import (
Expand Down Expand Up @@ -67,6 +68,9 @@
TEST_PARENT = "test-parent"
TEST_NAME = "test-name"

TEST_COMPOSER_DAG_ID = "test-composer-dag-id"
TEST_COMPOSER_DAG_CONF = {"test-key": "test-value"}

COMPOSER_STRING = "airflow.providers.google.cloud.operators.cloud_composer.{}"
COMPOSER_TRIGGERS_STRING = "airflow.providers.google.cloud.triggers.cloud_composer.{}"

Expand Down Expand Up @@ -375,3 +379,35 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode):

assert isinstance(exc.value.trigger, CloudComposerAirflowCLICommandTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME


class TestCloudComposerTriggerDAGRunOperator:
@mock.patch(COMPOSER_STRING.format("CloudComposerHook"))
def test_execute(self, mock_hook) -> None:
op = CloudComposerTriggerDAGRunOperator(
task_id=TASK_ID,
project_id=TEST_GCP_PROJECT,
region=TEST_GCP_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id=TEST_COMPOSER_DAG_ID,
composer_dag_conf=TEST_COMPOSER_DAG_CONF,
gcp_conn_id=TEST_GCP_CONN_ID,
timeout=TEST_TIMEOUT,
)
op.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.get_environment.assert_called_once_with(
project_id=TEST_GCP_PROJECT,
region=TEST_GCP_REGION,
environment_id=TEST_ENVIRONMENT_ID,
timeout=TEST_TIMEOUT,
)
mock_hook.return_value.trigger_dag_run.assert_called_once_with(
composer_airflow_uri=mock_hook.return_value.get_environment.return_value.config.airflow_uri,
composer_dag_id=TEST_COMPOSER_DAG_ID,
composer_dag_conf=TEST_COMPOSER_DAG_CONF,
timeout=TEST_TIMEOUT,
)