From 1d9b127ee726ba0e21c3629eb45f2336b8391994 Mon Sep 17 00:00:00 2001 From: Lennox Stevenson Date: Wed, 23 Oct 2024 09:50:50 -0400 Subject: [PATCH] feat: sensor to check status of Dataform action (#43055) Adds a new sensor to check the status of a WorkflowInvocationAction in Google Cloud Dataform. Heavily based on the DataformWorkflowInvocationStateSensor which already exists. Useful for checking the status of a specific target within a Dataform workflow invocation and taking action based on the status. --- .../operators/cloud/dataform.rst | 12 ++ .../google/cloud/sensors/dataform.py | 75 +++++++++ .../google/cloud/sensors/test_dataform.py | 150 ++++++++++++++++++ .../google/cloud/dataform/example_dataform.py | 40 ++++- 4 files changed, 275 insertions(+), 2 deletions(-) create mode 100644 providers/tests/google/cloud/sensors/test_dataform.py diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst index 09d8a6e6b8f93..208035af53c3c 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataform.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataform.rst @@ -95,6 +95,12 @@ We have possibility to run this operation in the sync mode and async, for async a sensor: :class:`~airflow.providers.google.cloud.operators.dataform.DataformWorkflowInvocationStateSensor` +We also have a sensor to check the status of a particular action for a workflow invocation triggered +asynchronously. + +:class:`~airflow.providers.google.cloud.sensors.dataform.DataformWorkflowInvocationActionStateSensor` + + .. exampleinclude:: /../../providers/tests/system/google/cloud/dataform/example_dataform.py :language: python :dedent: 4 @@ -107,6 +113,12 @@ a sensor: :start-after: [START howto_operator_create_workflow_invocation_async] :end-before: [END howto_operator_create_workflow_invocation_async] +.. 
exampleinclude:: /../../providers/tests/system/google/cloud/dataform/example_dataform.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_create_workflow_invocation_action_async] + :end-before: [END howto_operator_create_workflow_invocation_action_async] + Get Workflow Invocation ----------------------- diff --git a/providers/src/airflow/providers/google/cloud/sensors/dataform.py b/providers/src/airflow/providers/google/cloud/sensors/dataform.py index 0e4676749eb47..17d1351404fa4 100644 --- a/providers/src/airflow/providers/google/cloud/sensors/dataform.py +++ b/providers/src/airflow/providers/google/cloud/sensors/dataform.py @@ -103,3 +103,78 @@ def poke(self, context: Context) -> bool: raise AirflowException(message) return workflow_status in self.expected_statuses + + +class DataformWorkflowInvocationActionStateSensor(BaseSensorOperator): + """ + Checks for the status of a Workflow Invocation Action in Google Cloud Dataform. + + :param project_id: Required, the Google Cloud project ID in which to start a job. + If set to None or missing, the default project_id from the Google Cloud connection is used. + :param region: Required, The location of the Dataform workflow invocation (for example europe-west1). + :param repository_id: Required. The ID of the Dataform repository that the task belongs to. + :param workflow_invocation_id: Required, ID of the workflow invocation to be checked. + :param target_name: Required. The name of the target to be checked in the workflow. + :param expected_statuses: The expected state of the action. + See: + https://cloud.google.com/python/docs/reference/dataform/latest/google.cloud.dataform_v1beta1.types.WorkflowInvocationAction.State + :param failure_statuses: State that will terminate the sensor with an exception + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ("workflow_invocation_id",) + + def __init__( + self, + *, + project_id: str, + region: str, + repository_id: str, + workflow_invocation_id: str, + target_name: str, + expected_statuses: Iterable[int], + failure_statuses: Iterable[int], + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.repository_id = repository_id + self.workflow_invocation_id = workflow_invocation_id + self.project_id = project_id + self.region = region + self.target_name = target_name + self.expected_statuses = expected_statuses + self.failure_statuses = failure_statuses + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.hook: DataformHook | None = None + + def poke(self, context: Context) -> bool: + self.hook = DataformHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + + workflow_invocation_actions = self.hook.query_workflow_invocation_actions( + project_id=self.project_id, + region=self.region, + repository_id=self.repository_id, + workflow_invocation_id=self.workflow_invocation_id, + ) + + for workflow_invocation_action in workflow_invocation_actions: + if workflow_invocation_action.target.name == self.target_name: + state = workflow_invocation_action.state + if state in 
self.failure_statuses: + raise AirflowException( + f"Workflow Invocation Action target {self.target_name} state is: {state}." + ) + return state in self.expected_statuses + + raise AirflowException(f"Workflow Invocation Action target {self.target_name} not found.") diff --git a/providers/tests/google/cloud/sensors/test_dataform.py b/providers/tests/google/cloud/sensors/test_dataform.py new file mode 100644 index 0000000000000..d3bcd8b6c9aad --- /dev/null +++ b/providers/tests/google/cloud/sensors/test_dataform.py @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from unittest import mock + +import pytest +from google.cloud.dataform_v1beta1.types import Target, WorkflowInvocationAction + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.sensors.dataform import DataformWorkflowInvocationActionStateSensor + +TEST_TASK_ID = "task_id" +TEST_PROJECT_ID = "test_project" +TEST_REGION = "us-central1" +TEST_REPOSITORY_ID = "test_repository_id" +TEST_WORKFLOW_INVOCATION_ID = "test_workflow_invocation_id" +TEST_TARGET_NAME = "test_target_name" +TEST_GCP_CONN_ID = "test_gcp_conn_id" +TEST_IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] + + +class TestDataformWorkflowInvocationActionStateSensor: + @pytest.mark.parametrize( + "expected_status, current_status, sensor_return", + [ + (WorkflowInvocationAction.State.SUCCEEDED, WorkflowInvocationAction.State.SUCCEEDED, True), + (WorkflowInvocationAction.State.SUCCEEDED, WorkflowInvocationAction.State.RUNNING, False), + ], + ) + @mock.patch("airflow.providers.google.cloud.sensors.dataform.DataformHook") + def test_poke( + self, + mock_hook: mock.MagicMock, + expected_status: WorkflowInvocationAction.State, + current_status: WorkflowInvocationAction.State, + sensor_return: bool, + ): + target = Target(database="", schema="", name=TEST_TARGET_NAME) + workflow_invocation_action = WorkflowInvocationAction(target=target, state=current_status) + mock_query_workflow_invocation_actions = mock_hook.return_value.query_workflow_invocation_actions + mock_query_workflow_invocation_actions.return_value = [workflow_invocation_action] + + task = DataformWorkflowInvocationActionStateSensor( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + target_name=TEST_TARGET_NAME, + expected_statuses=[expected_status], + failure_statuses=[], + gcp_conn_id=TEST_GCP_CONN_ID, + 
impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + results = task.poke(mock.MagicMock()) + + assert sensor_return == results + + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN + ) + mock_query_workflow_invocation_actions.assert_called_once_with( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + ) + + @mock.patch("airflow.providers.google.cloud.sensors.dataform.DataformHook") + def test_target_state_failure_raises_exception(self, mock_hook: mock.MagicMock): + target = Target(database="", schema="", name=TEST_TARGET_NAME) + workflow_invocation_action = WorkflowInvocationAction( + target=target, state=WorkflowInvocationAction.State.FAILED + ) + mock_query_workflow_invocation_actions = mock_hook.return_value.query_workflow_invocation_actions + mock_query_workflow_invocation_actions.return_value = [workflow_invocation_action] + + task = DataformWorkflowInvocationActionStateSensor( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + target_name=TEST_TARGET_NAME, + expected_statuses=[WorkflowInvocationAction.State.SUCCEEDED], + failure_statuses=[WorkflowInvocationAction.State.FAILED], + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + with pytest.raises(AirflowException): + task.poke(mock.MagicMock()) + + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN + ) + mock_query_workflow_invocation_actions.assert_called_once_with( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + ) + + @mock.patch("airflow.providers.google.cloud.sensors.dataform.DataformHook") + def test_target_not_found_raises_exception(self, 
mock_hook: mock.MagicMock): + mock_query_workflow_invocation_actions = mock_hook.return_value.query_workflow_invocation_actions + mock_query_workflow_invocation_actions.return_value = [] + + task = DataformWorkflowInvocationActionStateSensor( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + target_name=TEST_TARGET_NAME, + expected_statuses=[WorkflowInvocationAction.State.SUCCEEDED], + failure_statuses=[WorkflowInvocationAction.State.FAILED], + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + with pytest.raises(AirflowException): + task.poke(mock.MagicMock()) + + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN + ) + mock_query_workflow_invocation_actions.assert_called_once_with( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + repository_id=TEST_REPOSITORY_ID, + workflow_invocation_id=TEST_WORKFLOW_INVOCATION_ID, + ) diff --git a/providers/tests/system/google/cloud/dataform/example_dataform.py b/providers/tests/system/google/cloud/dataform/example_dataform.py index e88a37caef2e6..f15e629b0f4b6 100644 --- a/providers/tests/system/google/cloud/dataform/example_dataform.py +++ b/providers/tests/system/google/cloud/dataform/example_dataform.py @@ -24,7 +24,7 @@ import os from datetime import datetime -from google.cloud.dataform_v1beta1 import WorkflowInvocation +from google.cloud.dataform_v1beta1 import WorkflowInvocation, WorkflowInvocationAction from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.bigquery import BigQueryDeleteDatasetOperator @@ -45,7 +45,10 @@ DataformRemoveFileOperator, DataformWriteFileOperator, ) -from airflow.providers.google.cloud.sensors.dataform import DataformWorkflowInvocationStateSensor +from airflow.providers.google.cloud.sensors.dataform import ( + 
DataformWorkflowInvocationActionStateSensor, + DataformWorkflowInvocationStateSensor, +) from airflow.providers.google.cloud.utils.dataform import make_initialization_workspace_flow from airflow.utils.trigger_rule import TriggerRule @@ -174,6 +177,37 @@ ) # [END howto_operator_create_workflow_invocation_async] + # [START howto_operator_create_workflow_invocation_action_async] + create_workflow_invocation_async_action = DataformCreateWorkflowInvocationOperator( + task_id="create-workflow-invocation-async", + project_id=PROJECT_ID, + region=REGION, + repository_id=REPOSITORY_ID, + asynchronous=True, + workflow_invocation={ + "compilation_result": "{{ task_instance.xcom_pull('create-compilation-result')['name'] }}" + }, + ) + + is_workflow_invocation_action_done = DataformWorkflowInvocationActionStateSensor( + task_id="is-workflow-invocation-done", + project_id=PROJECT_ID, + region=REGION, + repository_id=REPOSITORY_ID, + workflow_invocation_id=( + "{{ task_instance.xcom_pull('create-workflow-invocation')['name'].split('/')[-1] }}" + ), + target_name="YOUR_TARGET_HERE", + expected_statuses={WorkflowInvocationAction.State.SUCCEEDED}, + failure_statuses={ + WorkflowInvocationAction.State.SKIPPED, + WorkflowInvocationAction.State.DISABLED, + WorkflowInvocationAction.State.CANCELLED, + WorkflowInvocationAction.State.FAILED, + }, + ) + # [END howto_operator_create_workflow_invocation_action_async] + # [START howto_operator_get_workflow_invocation] get_workflow_invocation = DataformGetWorkflowInvocationOperator( task_id="get-workflow-invocation", @@ -314,6 +348,8 @@ >> query_workflow_invocation_actions >> create_workflow_invocation_async >> is_workflow_invocation_done + >> create_workflow_invocation_async_action + >> is_workflow_invocation_action_done >> create_workflow_invocation_for_cancel >> cancel_workflow_invocation >> make_test_directory