diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
index fccb4aec3c994..073f7bb750c30 100644
--- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
+++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py
@@ -276,6 +276,14 @@ def _init_pipeline_options(
             is_dataflow_job_id_exist_callback,
         )
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.dataflow_config.project_id,
+            "region": self.dataflow_config.location,
+            "job_id": self.dataflow_job_id,
+        }
+
     def execute_complete(self, context: Context, event: dict[str, Any]):
         """
         Execute when the trigger fires - returns immediately.
@@ -443,13 +451,7 @@ def execute_on_dataflow(self, context: Context):
             )
 
             location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
-            DataflowJobLink.persist(
-                self,
-                context,
-                self.dataflow_config.project_id,
-                location,
-                self.dataflow_job_id,
-            )
+            DataflowJobLink.persist(context=context, region=location)
 
             if self.deferrable:
                 trigger_args = {
@@ -626,13 +628,7 @@ def execute_on_dataflow(self, context: Context):
                 is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
             )
             if self.dataflow_job_name and self.dataflow_config.location:
-                DataflowJobLink.persist(
-                    self,
-                    context,
-                    self.dataflow_config.project_id,
-                    self.dataflow_config.location,
-                    self.dataflow_job_id,
-                )
+                DataflowJobLink.persist(context=context)
             if self.deferrable:
                 trigger_args = {
                     "job_id": self.dataflow_job_id,
@@ -795,14 +791,7 @@ def execute(self, context: Context):
                 variables=snake_case_pipeline_options,
                 process_line_callback=process_line_callback,
             )
-
-            DataflowJobLink.persist(
-                self,
-                context,
-                self.dataflow_config.project_id,
-                self.dataflow_config.location,
-                self.dataflow_job_id,
-            )
+            DataflowJobLink.persist(context=context)
             if dataflow_job_name and self.dataflow_config.location:
                 self.dataflow_hook.wait_for_done(
                     job_name=dataflow_job_name,
diff --git a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
index dcb766373bfd9..dffce7778ee8b 100644
--- a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
+++ b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py
@@ -23,7 +23,7 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.apache.beam.operators.beam import (
     BeamBasePipelineOperator,
     BeamRunGoPipelineOperator,
@@ -32,6 +32,7 @@
 )
 from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
 from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.version import version
 
 TASK_ID = "test-beam-operator"
@@ -233,13 +234,7 @@ def test_exec_dataflow_runner(
         }
         gcs_provide_file.assert_any_call(object_url=PY_FILE)
         gcs_provide_file.assert_any_call(object_url=REQURIEMENTS_FILE)
-        persist_link_mock.assert_called_once_with(
-            op,
-            {},
-            expected_options["project"],
-            expected_options["region"],
-            op.dataflow_job_id,
-        )
+        persist_link_mock.assert_called_once_with(context={}, region="us-central1")
         beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
             variables=expected_options,
             py_file=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -446,13 +441,7 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             "output": "gs://test/output",
             "impersonateServiceAccount": TEST_IMPERSONATION_ACCOUNT,
         }
-        persist_link_mock.assert_called_once_with(
-            op,
-            {},
-            expected_options["project"],
-            expected_options["region"],
-            op.dataflow_job_id,
-        )
+        persist_link_mock.assert_called_once_with(context={})
         beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
             variables=expected_options,
             jar=gcs_provide_file.return_value.__enter__.return_value.name,
@@ -753,13 +742,7 @@ def test_exec_dataflow_runner_with_go_file(
             "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
             "region": "us-central1",
         }
-        persist_link_mock.assert_called_once_with(
-            op,
-            {},
-            expected_options["project"],
-            expected_options["region"],
-            op.dataflow_job_id,
-        )
+        persist_link_mock.assert_called_once_with(context={})
         expected_go_file = "/tmp/apache-beam-go/main.go"
         gcs_download_method.assert_called_once_with(
             bucket_name="my-bucket", object_name="example/main.go", filename=expected_go_file
@@ -859,13 +842,7 @@ def gcs_download_side_effect(bucket_name: str, object_name: str, filename: str)
             worker_binary=expected_worker_binary,
             process_line_callback=mock.ANY,
         )
-        mock_persist_link.assert_called_once_with(
-            operator,
-            {},
-            dataflow_config.project_id,
-            dataflow_config.location,
-            operator.dataflow_job_id,
-        )
+        mock_persist_link.assert_called_once_with(context={})
         wait_for_done_method.assert_called_once_with(
             job_name=expected_job_name,
             location=dataflow_config.location,
@@ -970,8 +947,20 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
             **self.default_op_kwargs,
         )
         magic_mock = mock.MagicMock()
-        with pytest.raises(TaskDeferred):
-            op.execute(context=magic_mock)
+        if AIRFLOW_V_3_0_PLUS:
+            with pytest.raises(TaskDeferred):
+                op.execute(context=magic_mock)
+        else:
+            exception_msg = (
+                "GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
+                " The method calls (only with context) needs to be removed after the Airflow 3 Migration"
+                " completed!"
+            )
+            with (
+                pytest.raises(TaskDeferred),
+                pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
+            ):
+                op.execute(context=magic_mock)
 
         dataflow_hook_mock.assert_called_once_with(
             gcp_conn_id=dataflow_config.gcp_conn_id,
@@ -1005,8 +994,20 @@ def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __, ___):
     def test_on_kill_direct_runner(self, _, dataflow_mock, __):
         dataflow_cancel_job = dataflow_mock.return_value.cancel_job
         op = BeamRunPythonPipelineOperator(runner="DataflowRunner", **self.default_op_kwargs)
-        with pytest.raises(TaskDeferred):
-            op.execute(mock.MagicMock())
+        if AIRFLOW_V_3_0_PLUS:
+            with pytest.raises(TaskDeferred):
+                op.execute(mock.MagicMock())
+        else:
+            exception_msg = (
+                "GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
+                " The method calls (only with context) needs to be removed after the Airflow 3 Migration"
+                " completed!"
+            )
+            with (
+                pytest.raises(TaskDeferred),
+                pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
+            ):
+                op.execute(mock.MagicMock())
         op.on_kill()
         dataflow_cancel_job.assert_not_called()
 
@@ -1075,8 +1076,20 @@ def test_exec_dataflow_runner(self, gcs_hook_mock, dataflow_hook_mock, beam_hook
         )
         dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
         magic_mock = mock.MagicMock()
-        with pytest.raises(TaskDeferred):
-            op.execute(context=magic_mock)
+        if AIRFLOW_V_3_0_PLUS:
+            with pytest.raises(TaskDeferred):
+                op.execute(context=magic_mock)
+        else:
+            exception_msg = (
+                "GoogleBaseLink.persist method call with no extra value is Deprecated for Airflow 3."
+                " The method calls (only with context) needs to be removed after the Airflow 3 Migration"
+                " completed!"
+            )
+            with (
+                pytest.raises(TaskDeferred),
+                pytest.warns(AirflowProviderDeprecationWarning, match=exception_msg),
+            ):
+                op.execute(context=magic_mock)
 
         dataflow_hook_mock.assert_called_once_with(
             gcp_conn_id=dataflow_config.gcp_conn_id,
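Reviewer note: the simplified persist() calls above work because BeamBasePipelineOperator now exposes extra_links_params, which BaseGoogleLink.persist (see the base.py change below) merges into the pushed XCom value on Airflow 2, and which get_link reads directly off the operator on Airflow 3. A minimal sketch of the pattern from the operator side; the class name and literal values here are illustrative only, not part of this change:

    # Illustrative sketch only; assumes the extra_links_params contract added in this PR.
    from typing import Any

    from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
    from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator


    class MyDataflowOperator(GoogleCloudBaseOperator):  # hypothetical operator
        @property
        def extra_links_params(self) -> dict[str, Any]:
            # Same shape as the property added to BeamBasePipelineOperator above.
            return {
                "project_id": self.dataflow_config.project_id,
                "region": self.dataflow_config.location,
                "job_id": self.dataflow_job_id,
            }

        def execute(self, context):
            # Only values that differ from extra_links_params (here: a region that
            # may have fallen back to a default) still need to be passed explicitly.
            DataflowJobLink.persist(context=context, region="us-central1")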
diff --git a/providers/google/src/airflow/providers/google/cloud/links/alloy_db.py b/providers/google/src/airflow/providers/google/cloud/links/alloy_db.py
index 05d9e9fa5a873..5321bad431dd7 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/alloy_db.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/alloy_db.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 ALLOY_DB_BASE_LINK = "/alloydb"
 ALLOY_DB_CLUSTER_LINK = (
     ALLOY_DB_BASE_LINK + "/locations/{location_id}/clusters/{cluster_id}?project={project_id}"
@@ -44,20 +38,6 @@ class AlloyDBClusterLink(BaseGoogleLink):
     key = "alloy_db_cluster"
     format_str = ALLOY_DB_CLUSTER_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        location_id: str,
-        cluster_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=AlloyDBClusterLink.key,
-            value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
-        )
-
 
 class AlloyDBUsersLink(BaseGoogleLink):
     """Helper class for constructing AlloyDB users Link."""
@@ -66,20 +46,6 @@ class AlloyDBUsersLink(BaseGoogleLink):
     key = "alloy_db_users"
     format_str = ALLOY_DB_USERS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        location_id: str,
-        cluster_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=AlloyDBUsersLink.key,
-            value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id},
-        )
-
 
 class AlloyDBBackupsLink(BaseGoogleLink):
     """Helper class for constructing AlloyDB backups Link."""
@@ -87,15 +53,3 @@ class AlloyDBBackupsLink(BaseGoogleLink):
     name = "AlloyDB Backups"
     key = "alloy_db_backups"
     format_str = ALLOY_DB_BACKUPS_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=AlloyDBBackupsLink.key,
-            value={"project_id": project_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py b/providers/google/src/airflow/providers/google/cloud/links/base.py
index 3bf9114cc9b9f..da5d68bfbe98b 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/base.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/base.py
@@ -24,6 +24,9 @@
 if TYPE_CHECKING:
     from airflow.models import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+    from airflow.sdk import BaseSensorOperator
+    from airflow.utils.context import Context
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk import BaseOperatorLink
@@ -39,6 +42,12 @@ class BaseGoogleLink(BaseOperatorLink):
     """
     Base class for all Google links.
 
+    When you inherit this class in a link class:
+    - You can call the ``persist`` method to push data to XCom for later use in the ``get_link`` method.
+    - If you have an operator that inherits from ``GoogleCloudBaseOperator`` or ``BaseSensorOperator``,
+      you can define an ``extra_links_params`` property on the operator to pass operator
+      properties to the ``get_link`` method.
+
     :meta private:
     """
 
@@ -46,15 +55,69 @@ class BaseGoogleLink(BaseOperatorLink):
     key: ClassVar[str]
     format_str: ClassVar[str]
 
+    @property
+    def xcom_key(self) -> str:
+        # NOTE: Airflow 3 expects an xcom_key property on the Link class.
+        # Since we already have the key property, this is just a proxy that reuses
+        # the same key as in Airflow 2.
+        return self.key
+
+    @classmethod
+    def persist(cls, context: Context, **value):
+        """
+        Push arguments to XCom for later use when formatting the link in the ``get_link`` method.
+
+        Note: on Airflow 2 this method may be called with the context only, provided the
+        operator defines the ``extra_links_params`` property.
+        """
+        params = {}
+        # TODO: remove after Airflow v2 support dropped
+        if not AIRFLOW_V_3_0_PLUS:
+            common_params = getattr(context["task"], "extra_links_params", None)
+            if common_params:
+                params.update(common_params)
+
+        context["ti"].xcom_push(
+            key=cls.key,
+            value={
+                **params,
+                **value,
+            },
+        )
+
+    def get_config(self, operator, ti_key):
+        conf = {}
+        conf.update(getattr(operator, "extra_links_params", {}))
+        conf.update(XCom.get_value(key=self.key, ti_key=ti_key) or {})
+
+        # If no config was defined at all, return None to stop URL formatting.
+        if not conf:
+            return None
+
+        # Add a default value for the 'namespace' parameter for backward compatibility.
+        # This is needed for the Data Fusion links.
+        conf.setdefault("namespace", "default")
+        return conf
+
     def get_link(
         self,
         operator: BaseOperator,
         *,
         ti_key: TaskInstanceKey,
     ) -> str:
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
+        if TYPE_CHECKING:
+            assert isinstance(operator, (GoogleCloudBaseOperator, BaseSensorOperator))
+
+        conf = self.get_config(operator, ti_key)
         if not conf:
             return ""
-        if self.format_str.startswith("http"):
-            return self.format_str.format(**conf)
-        return BASE_LINK + self.format_str.format(**conf)
+        return self._format_link(**conf)
+
+    def _format_link(self, **kwargs):
+        try:
+            formatted_str = self.format_str.format(**kwargs)
+            if formatted_str.startswith("http"):
+                return formatted_str
+            return BASE_LINK + formatted_str
+        except KeyError:
+            return ""
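A quick walk-through of the new base-class flow, using assumed values (none of the literals below appear in this diff). On Airflow 2, persist() merges the operator's extra_links_params into the XCom value it pushes; on Airflow 3, get_config() reads the same params straight off the operator, so a bare persist(context=context) is enough when no extra values are needed:

    # Assumed values for illustration only.
    params = {"project_id": "my-project", "region": "us-central1"}  # operator.extra_links_params
    pushed = {"job_id": "job-123"}                                  # kwargs passed to persist()
    conf = {**params, **pushed}                                     # what get_config() assembles

    # _format_link() only prepends the console base URL for relative templates:
    url = "/dataflow/jobs/{region}/{job_id}?project={project_id}".format(**conf)
    assert url == "/dataflow/jobs/us-central1/job-123?project=my-project"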
diff --git a/providers/google/src/airflow/providers/google/cloud/links/bigquery.py b/providers/google/src/airflow/providers/google/cloud/links/bigquery.py
index 8b3e95a29dea3..c71d2a2174497 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/bigquery.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 BIGQUERY_BASE_LINK = "/bigquery"
 BIGQUERY_DATASET_LINK = (
     BIGQUERY_BASE_LINK + "?referrer=search&project={project_id}&d={dataset_id}&p={project_id}&page=dataset"
@@ -47,19 +41,6 @@ class BigQueryDatasetLink(BaseGoogleLink):
     key = "bigquery_dataset"
     format_str = BIGQUERY_DATASET_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        dataset_id: str,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=BigQueryDatasetLink.key,
-            value={"dataset_id": dataset_id, "project_id": project_id},
-        )
-
 
 class BigQueryTableLink(BaseGoogleLink):
     """Helper class for constructing BigQuery Table Link."""
@@ -68,20 +49,6 @@ class BigQueryTableLink(BaseGoogleLink):
     key = "bigquery_table"
     format_str = BIGQUERY_TABLE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str,
-        table_id: str,
-        dataset_id: str | None = None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=BigQueryTableLink.key,
-            value={"dataset_id": dataset_id, "project_id": project_id, "table_id": table_id},
-        )
-
 
 class BigQueryJobDetailLink(BaseGoogleLink):
     """Helper class for constructing BigQuery Job Detail Link."""
@@ -89,17 +56,3 @@ class BigQueryJobDetailLink(BaseGoogleLink):
     name = "BigQuery Job Detail"
     key = "bigquery_job_detail"
     format_str = BIGQUERY_JOB_DETAIL_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str,
-        location: str,
-        job_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=BigQueryJobDetailLink.key,
-            value={"project_id": project_id, "location": location, "job_id": job_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/bigquery_dts.py b/providers/google/src/airflow/providers/google/cloud/links/bigquery_dts.py
index e8c57543da17b..db878effccb9b 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/bigquery_dts.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/bigquery_dts.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 BIGQUERY_BASE_LINK = "/bigquery/transfers"
 BIGQUERY_DTS_LINK = BIGQUERY_BASE_LINK + "/locations/{region}/configs/{config_id}/runs?project={project_id}"
 
@@ -37,17 +31,3 @@ class BigQueryDataTransferConfigLink(BaseGoogleLink):
     name = "BigQuery Data Transfer Config"
     key = "bigquery_dts_config"
     format_str = BIGQUERY_DTS_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        region: str,
-        config_id: str,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=BigQueryDataTransferConfigLink.key,
-            value={"project_id": project_id, "region": region, "config_id": config_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/bigtable.py b/providers/google/src/airflow/providers/google/cloud/links/bigtable.py
index 47805ba348896..716b1a466c01d 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/bigtable.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/bigtable.py
@@ -16,13 +16,8 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 BIGTABLE_BASE_LINK = "/bigtable"
 BIGTABLE_INSTANCE_LINK = BIGTABLE_BASE_LINK + "/instances/{instance_id}/overview?project={project_id}"
 BIGTABLE_CLUSTER_LINK = (
@@ -38,20 +33,6 @@ class BigtableInstanceLink(BaseGoogleLink):
     key = "instance_key"
     format_str = BIGTABLE_INSTANCE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=BigtableInstanceLink.key,
-            value={
-                "instance_id": task_instance.instance_id,
-                "project_id": task_instance.project_id,
-            },
-        )
-
 
 class BigtableClusterLink(BaseGoogleLink):
     """Helper class for constructing Bigtable Cluster link."""
@@ -60,21 +41,6 @@ class BigtableClusterLink(BaseGoogleLink):
     key = "cluster_key"
     format_str = BIGTABLE_CLUSTER_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=BigtableClusterLink.key,
-            value={
-                "instance_id": task_instance.instance_id,
-                "cluster_id": task_instance.cluster_id,
-                "project_id": task_instance.project_id,
-            },
-        )
-
 
 class BigtableTablesLink(BaseGoogleLink):
     """Helper class for constructing Bigtable Tables link."""
@@ -82,17 +48,3 @@ class BigtableTablesLink(BaseGoogleLink):
     name = "Bigtable Tables"
     key = "tables_key"
     format_str = BIGTABLE_TABLES_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=BigtableTablesLink.key,
-            value={
-                "instance_id": task_instance.instance_id,
-                "project_id": task_instance.project_id,
-            },
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_build.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_build.py
index b855dba89732b..723f9e2faf671 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_build.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_build.py
@@ -16,13 +16,8 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 BUILD_BASE_LINK = "/cloud-build"
 
 BUILD_LINK = BUILD_BASE_LINK + "/builds;region={region}/{build_id}?project={project_id}"
@@ -43,24 +38,6 @@ class CloudBuildLink(BaseGoogleLink):
     key = "cloud_build_key"
     format_str = BUILD_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        build_id: str,
-        project_id: str,
-        region: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudBuildLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-                "build_id": build_id,
-            },
-        )
-
 
 class CloudBuildListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Build List link."""
@@ -69,22 +46,6 @@ class CloudBuildListLink(BaseGoogleLink):
     key = "cloud_build_list_key"
     format_str = BUILD_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        region: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudBuildListLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-            },
-        )
-
 
 class CloudBuildTriggersListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Build Triggers List link."""
@@ -93,22 +54,6 @@ class CloudBuildTriggersListLink(BaseGoogleLink):
     key = "cloud_build_triggers_list_key"
     format_str = BUILD_TRIGGERS_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        region: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudBuildTriggersListLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-            },
-        )
-
 
 class CloudBuildTriggerDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Build Trigger Details link."""
@@ -116,21 +61,3 @@ class CloudBuildTriggerDetailsLink(BaseGoogleLink):
     name = "Cloud Build Triggers Details"
     key = "cloud_build_triggers_details_key"
     format_str = BUILD_TRIGGER_DETAILS_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        region: str,
-        trigger_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudBuildTriggerDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-                "trigger_id": trigger_id,
-            },
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_functions.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_functions.py
index b0d2f92f9a1db..a0e5373bedc1c 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_functions.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_functions.py
@@ -19,15 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
-
 CLOUD_FUNCTIONS_BASE_LINK = "https://console.cloud.google.com/functions"
 
 CLOUD_FUNCTIONS_DETAILS_LINK = (
@@ -44,20 +37,6 @@ class CloudFunctionsDetailsLink(BaseGoogleLink):
     key = "cloud_functions_details"
     format_str = CLOUD_FUNCTIONS_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        function_name: str,
-        location: str,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudFunctionsDetailsLink.key,
-            value={"function_name": function_name, "location": location, "project_id": project_id},
-        )
-
 
 class CloudFunctionsListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Functions Details Link."""
@@ -65,15 +44,3 @@ class CloudFunctionsListLink(BaseGoogleLink):
     name = "Cloud Functions List"
     key = "cloud_functions_list"
     format_str = CLOUD_FUNCTIONS_LIST_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudFunctionsDetailsLink.key,
-            value={"project_id": project_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_memorystore.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_memorystore.py
index c9c9d468dd306..5bf10c6af32ad 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_memorystore.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_memorystore.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 BASE_LINK = "/memorystore"
 MEMCACHED_LINK = (
     BASE_LINK + "/memcached/locations/{location_id}/instances/{instance_id}/details?project={project_id}"
@@ -45,20 +39,6 @@ class MemcachedInstanceDetailsLink(BaseGoogleLink):
     key = "memcached_instance"
     format_str = MEMCACHED_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        instance_id: str,
-        location_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=MemcachedInstanceDetailsLink.key,
-            value={"instance_id": instance_id, "location_id": location_id, "project_id": project_id},
-        )
-
 
 class MemcachedInstanceListLink(BaseGoogleLink):
     """Helper class for constructing Memorystore Memcached List of Instances Link."""
@@ -67,18 +47,6 @@ class MemcachedInstanceListLink(BaseGoogleLink):
     key = "memcached_instances"
     format_str = MEMCACHED_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=MemcachedInstanceListLink.key,
-            value={"project_id": project_id},
-        )
-
 
 class RedisInstanceDetailsLink(BaseGoogleLink):
     """Helper class for constructing Memorystore Redis Instance Link."""
@@ -87,20 +55,6 @@ class RedisInstanceDetailsLink(BaseGoogleLink):
     key = "redis_instance"
     format_str = REDIS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        instance_id: str,
-        location_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=RedisInstanceDetailsLink.key,
-            value={"instance_id": instance_id, "location_id": location_id, "project_id": project_id},
-        )
-
 
 class RedisInstanceListLink(BaseGoogleLink):
     """Helper class for constructing Memorystore Redis List of Instances Link."""
@@ -108,15 +62,3 @@ class RedisInstanceListLink(BaseGoogleLink):
     name = "Memorystore Redis List of Instances"
     key = "redis_instances"
     format_str = REDIS_LIST_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=RedisInstanceListLink.key,
-            value={"project_id": project_id},
-        )
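With the per-link persist methods gone, every call site collapses onto the shared classmethod. A hedged before/after sketch of one call site (the surrounding operator code is assumed here; it is not part of this diff):

    from airflow.providers.google.cloud.links.cloud_memorystore import MemcachedInstanceDetailsLink

    def execute(self, context):  # sketch of an operator method, not a real operator
        # Before: each link had its own static persist() and every value was passed by hand:
        #   MemcachedInstanceDetailsLink.persist(
        #       context, self, instance_id=..., location_id=..., project_id=...,
        #   )
        # After: values covered by the operator's extra_links_params can be omitted;
        # anything link-specific goes through as a keyword argument.
        MemcachedInstanceDetailsLink.persist(context=context, instance_id=self.instance_id)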
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_run.py
index 55283eef650e7..11efa9d80ba14 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_run.py
@@ -16,20 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.models.taskinstancekey import TaskInstanceKey
-    from airflow.utils.context import Context
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models.xcom import XCom  # type: ignore[no-redef]
 
 
 class CloudRunJobLoggingLink(BaseGoogleLink):
@@ -37,23 +24,4 @@ class CloudRunJobLoggingLink(BaseGoogleLink):
 
     name = "Cloud Run Job Logging"
     key = "log_uri"
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        log_uri: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudRunJobLoggingLink.key,
-            value=log_uri,
-        )
-
-    def get_link(
-        self,
-        operator: BaseOperator,
-        *,
-        ti_key: TaskInstanceKey,
-    ) -> str:
-        return XCom.get_value(key=self.key, ti_key=ti_key)
+    format_str = "{log_uri}"
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_sql.py
index a2373e0d1619b..33a16fb1987d3 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_sql.py
@@ -19,15 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
-
 CLOUD_SQL_BASE_LINK = "/sql"
 CLOUD_SQL_INSTANCE_LINK = CLOUD_SQL_BASE_LINK + "/instances/{instance}/overview?project={project_id}"
 CLOUD_SQL_INSTANCE_DATABASE_LINK = (
@@ -42,19 +35,6 @@ class CloudSQLInstanceLink(BaseGoogleLink):
     key = "cloud_sql_instance"
     format_str = CLOUD_SQL_INSTANCE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        cloud_sql_instance: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudSQLInstanceLink.key,
-            value={"instance": cloud_sql_instance, "project_id": project_id},
-        )
-
 
 class CloudSQLInstanceDatabaseLink(BaseGoogleLink):
     """Helper class for constructing Cloud SQL Instance Database Link."""
@@ -62,16 +42,3 @@ class CloudSQLInstanceDatabaseLink(BaseGoogleLink):
     name = "Cloud SQL Instance Database"
     key = "cloud_sql_instance_database"
     format_str = CLOUD_SQL_INSTANCE_DATABASE_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        cloud_sql_instance: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudSQLInstanceDatabaseLink.key,
-            value={"instance": cloud_sql_instance, "project_id": project_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_storage_transfer.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_storage_transfer.py
index 159f4a0dfe6c1..0e30cb7c8dfc4 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_storage_transfer.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_storage_transfer.py
@@ -60,18 +60,6 @@ class CloudStorageTransferListLink(BaseGoogleLink):
     key = "cloud_storage_transfer"
     format_str = CLOUD_STORAGE_TRANSFER_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=CloudStorageTransferListLink.key,
-            value={"project_id": project_id},
-        )
-
 
 class CloudStorageTransferJobLink(BaseGoogleLink):
     """Helper class for constructing Storage Transfer Job Link."""
@@ -80,22 +68,6 @@ class CloudStorageTransferJobLink(BaseGoogleLink):
     key = "cloud_storage_transfer_job"
     format_str = CLOUD_STORAGE_TRANSFER_JOB_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        project_id: str,
-        job_name: str,
-    ):
-        job_name = job_name.split("/")[1] if job_name else ""
-
-        context["ti"].xcom_push(
-            key=CloudStorageTransferJobLink.key,
-            value={
-                "project_id": project_id,
-                "transfer_job": job_name,
-            },
-        )
-
 
 class CloudStorageTransferDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Storage Transfer Operation Link."""
@@ -105,20 +77,21 @@ class CloudStorageTransferDetailsLink(BaseGoogleLink):
     format_str = CLOUD_STORAGE_TRANSFER_OPERATION_LINK
 
     @staticmethod
-    def persist(
-        task_instance,
-        context: Context,
-        project_id: str,
-        operation_name: str,
-    ):
-        transfer_operation, transfer_job = CloudStorageTransferLinkHelper.extract_parts(operation_name)
-
-        task_instance.xcom_push(
+    def extract_parts(operation_name: str | None):
+        if not operation_name:
+            return "", ""
+        transfer_operation = operation_name.split("/")[1]
+        transfer_job = operation_name.split("-")[1]
+        return transfer_operation, transfer_job
+
+    @classmethod
+    def persist(cls, context: Context, **value):
+        operation_name = value.get("operation_name")
+        transfer_operation, transfer_job = cls.extract_parts(operation_name)
+
+        super().persist(
             context,
-            key=CloudStorageTransferDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "transfer_job": transfer_job,
-                "transfer_operation": transfer_operation,
-            },
+            project_id=value.get("project_id"),
+            transfer_job=transfer_job,
+            transfer_operation=transfer_operation,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/cloud_tasks.py b/providers/google/src/airflow/providers/google/cloud/links/cloud_tasks.py
index 10a37b27e64c0..44f8330425381 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/cloud_tasks.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/cloud_tasks.py
@@ -24,7 +24,6 @@
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
 if TYPE_CHECKING:
-    from airflow.models import BaseOperator
     from airflow.utils.context import Context
 
 CLOUD_TASKS_BASE_LINK = "/cloudtasks"
@@ -51,18 +50,12 @@ def extract_parts(queue_name: str | None):
         parts = queue_name.split("/")
         return parts[1], parts[3], parts[5]
 
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        queue_name: str | None,
-    ):
-        project_id, location, queue_id = CloudTasksQueueLink.extract_parts(queue_name)
-        operator_instance.xcom_push(
-            context,
-            key=CloudTasksQueueLink.key,
-            value={"project_id": project_id, "location": location, "queue_id": queue_id},
-        )
+    @classmethod
+    def persist(cls, context: Context, **value):
+        queue_name = value.get("queue_name")
+        project_id, location, queue_id = cls.extract_parts(queue_name)
+
+        super().persist(context, project_id=project_id, location=location, queue_id=queue_id)
 
 
 class CloudTasksLink(BaseGoogleLink):
@@ -71,15 +64,3 @@ class CloudTasksLink(BaseGoogleLink):
     name = "Cloud Tasks"
     key = "cloud_task"
     format_str = CLOUD_TASKS_LINK
-
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str | None,
-    ):
-        operator_instance.xcom_push(
-            context,
-            key=CloudTasksLink.key,
-            value={"project_id": project_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/compute.py b/providers/google/src/airflow/providers/google/cloud/links/compute.py
index 12680baf816e6..6989bc44f77c8 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/compute.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/compute.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 COMPUTE_BASE_LINK = "https://console.cloud.google.com/compute"
 COMPUTE_LINK = (
     COMPUTE_BASE_LINK + "/instancesDetail/zones/{location_id}/instances/{resource_id}?project={project_id}"
@@ -44,24 +38,6 @@ class ComputeInstanceDetailsLink(BaseGoogleLink):
     key = "compute_instance_details"
     format_str = COMPUTE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        location_id: str,
-        resource_id: str | None,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=ComputeInstanceDetailsLink.key,
-            value={
-                "location_id": location_id,
-                "resource_id": resource_id,
-                "project_id": project_id,
-            },
-        )
-
 
 class ComputeInstanceTemplateDetailsLink(BaseGoogleLink):
     """Helper class for constructing Compute Instance Template details Link."""
@@ -70,22 +46,6 @@ class ComputeInstanceTemplateDetailsLink(BaseGoogleLink):
     key = "compute_instance_template_details"
     format_str = COMPUTE_TEMPLATE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        resource_id: str | None,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=ComputeInstanceTemplateDetailsLink.key,
-            value={
-                "resource_id": resource_id,
-                "project_id": project_id,
-            },
-        )
-
 
 class ComputeInstanceGroupManagerDetailsLink(BaseGoogleLink):
     """Helper class for constructing Compute Instance Group Manager details Link."""
@@ -93,21 +53,3 @@ class ComputeInstanceGroupManagerDetailsLink(BaseGoogleLink):
     name = "Compute Instance Group Manager"
     key = "compute_instance_group_manager_details"
     format_str = COMPUTE_GROUP_MANAGER_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        location_id: str,
-        resource_id: str | None,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=ComputeInstanceGroupManagerDetailsLink.key,
-            value={
-                "location_id": location_id,
-                "resource_id": resource_id,
-                "project_id": project_id,
-            },
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/data_loss_prevention.py b/providers/google/src/airflow/providers/google/cloud/links/data_loss_prevention.py
index 4bc2f0e23460f..1ea6bef696a11 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/data_loss_prevention.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/data_loss_prevention.py
@@ -17,13 +17,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
-
 BASE_LINK = "https://console.cloud.google.com"
 DLP_BASE_LINK = BASE_LINK + "/security/dlp"
 
@@ -73,20 +68,6 @@ class CloudDLPDeidentifyTemplatesListLink(BaseGoogleLink):
     key = "cloud_dlp_deidentify_templates_list_key"
     format_str = DLP_DEIDENTIFY_TEMPLATES_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPDeidentifyTemplatesListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
-
 
 class CloudDLPDeidentifyTemplateDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -95,22 +76,6 @@ class CloudDLPDeidentifyTemplateDetailsLink(BaseGoogleLink):
     key = "cloud_dlp_deidentify_template_details_key"
     format_str = DLP_DEIDENTIFY_TEMPLATE_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        template_name: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPDeidentifyTemplateDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "template_name": template_name,
-            },
-        )
-
 
 class CloudDLPJobTriggersListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -119,20 +84,6 @@ class CloudDLPJobTriggersListLink(BaseGoogleLink):
     key = "cloud_dlp_job_triggers_list_key"
     format_str = DLP_JOB_TRIGGER_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPJobTriggersListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
-
 
 class CloudDLPJobTriggerDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -141,22 +92,6 @@ class CloudDLPJobTriggerDetailsLink(BaseGoogleLink):
     key = "cloud_dlp_job_trigger_details_key"
     format_str = DLP_JOB_TRIGGER_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        trigger_name: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPJobTriggerDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "trigger_name": trigger_name,
-            },
-        )
-
 
 class CloudDLPJobsListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -165,20 +100,6 @@ class CloudDLPJobsListLink(BaseGoogleLink):
     key = "cloud_dlp_jobs_list_key"
     format_str = DLP_JOBS_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPJobsListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
-
 
 class CloudDLPJobDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -187,22 +108,6 @@ class CloudDLPJobDetailsLink(BaseGoogleLink):
     key = "cloud_dlp_job_details_key"
     format_str = DLP_JOB_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        job_name: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPJobDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "job_name": job_name,
-            },
-        )
-
 
 class CloudDLPInspectTemplatesListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -211,20 +116,6 @@ class CloudDLPInspectTemplatesListLink(BaseGoogleLink):
     key = "cloud_dlp_inspect_templates_list_key"
     format_str = DLP_INSPECT_TEMPLATES_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPInspectTemplatesListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
-
 
 class CloudDLPInspectTemplateDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -233,22 +124,6 @@ class CloudDLPInspectTemplateDetailsLink(BaseGoogleLink):
     key = "cloud_dlp_inspect_template_details_key"
     format_str = DLP_INSPECT_TEMPLATE_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        template_name: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPInspectTemplateDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "template_name": template_name,
-            },
-        )
-
 
 class CloudDLPInfoTypesListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -257,20 +132,6 @@ class CloudDLPInfoTypesListLink(BaseGoogleLink):
     key = "cloud_dlp_info_types_list_key"
     format_str = DLP_INFO_TYPES_LIST_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPInfoTypesListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
-
 
 class CloudDLPInfoTypeDetailsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -279,22 +140,6 @@ class CloudDLPInfoTypeDetailsLink(BaseGoogleLink):
     key = "cloud_dlp_info_type_details_key"
     format_str = DLP_INFO_TYPE_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-        info_type_name: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPInfoTypeDetailsLink.key,
-            value={
-                "project_id": project_id,
-                "info_type_name": info_type_name,
-            },
-        )
-
 
 class CloudDLPPossibleInfoTypesListLink(BaseGoogleLink):
     """Helper class for constructing Cloud Data Loss Prevention link."""
@@ -302,17 +147,3 @@ class CloudDLPPossibleInfoTypesListLink(BaseGoogleLink):
     name = "Cloud DLP Possible Info Types List"
     key = "cloud_dlp_possible_info_types_list_key"
     format_str = DLP_POSSIBLE_INFO_TYPES_LIST_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=CloudDLPPossibleInfoTypesListLink.key,
-            value={
-                "project_id": project_id,
-            },
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/datacatalog.py b/providers/google/src/airflow/providers/google/cloud/links/datacatalog.py
index d55687e92a5f6..00b9071d77dcc 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/datacatalog.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/datacatalog.py
@@ -19,16 +19,10 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 from airflow.providers.google.common.deprecated import deprecated
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 DATACATALOG_BASE_LINK = "/datacatalog"
 ENTRY_GROUP_LINK = (
     DATACATALOG_BASE_LINK
@@ -59,20 +53,6 @@ class DataCatalogEntryGroupLink(BaseGoogleLink):
     key = "data_catalog_entry_group"
     format_str = ENTRY_GROUP_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        entry_group_id: str,
-        location_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=DataCatalogEntryGroupLink.key,
-            value={"entry_group_id": entry_group_id, "location_id": location_id, "project_id": project_id},
-        )
-
 
 @deprecated(
     planned_removal_date="January 30, 2026",
@@ -88,26 +68,6 @@ class DataCatalogEntryLink(BaseGoogleLink):
     key = "data_catalog_entry"
     format_str = ENTRY_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        entry_id: str,
-        entry_group_id: str,
-        location_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=DataCatalogEntryLink.key,
-            value={
-                "entry_id": entry_id,
-                "entry_group_id": entry_group_id,
-                "location_id": location_id,
-                "project_id": project_id,
-            },
-        )
-
 
 @deprecated(
     planned_removal_date="January 30, 2026",
@@ -122,17 +82,3 @@ class DataCatalogTagTemplateLink(BaseGoogleLink):
     name = "Data Catalog Tag Template"
     key = "data_catalog_tag_template"
     format_str = TAG_TEMPLATE_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        tag_template_id: str,
-        location_id: str,
-        project_id: str | None,
-    ):
-        task_instance.xcom_push(
-            context,
-            key=DataCatalogTagTemplateLink.key,
-            value={"tag_template_id": tag_template_id, "location_id": location_id, "project_id": project_id},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataflow.py b/providers/google/src/airflow/providers/google/cloud/links/dataflow.py
index 26cc00f8aebe7..3693433992e0e 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/dataflow.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/dataflow.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 DATAFLOW_BASE_LINK = "/dataflow/jobs"
 DATAFLOW_JOB_LINK = DATAFLOW_BASE_LINK + "/{region}/{job_id}?project={project_id}"
 
@@ -41,20 +35,6 @@ class DataflowJobLink(BaseGoogleLink):
     key = "dataflow_job_config"
     format_str = DATAFLOW_JOB_LINK
 
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str | None,
-        region: str | None,
-        job_id: str | None,
-    ):
-        operator_instance.xcom_push(
-            context,
-            key=DataflowJobLink.key,
-            value={"project_id": project_id, "region": region, "job_id": job_id},
-        )
-
 
 class DataflowPipelineLink(BaseGoogleLink):
     """Helper class for constructing Dataflow Pipeline Link."""
@@ -62,17 +42,3 @@ class DataflowPipelineLink(BaseGoogleLink):
     name = "Dataflow Pipeline"
     key = "dataflow_pipeline_config"
     format_str = DATAFLOW_PIPELINE_LINK
-
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str | None,
-        location: str | None,
-        pipeline_name: str | None,
-    ):
-        operator_instance.xcom_push(
-            context,
-            key=DataflowPipelineLink.key,
-            value={"project_id": project_id, "location": location, "pipeline_name": pipeline_name},
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataform.py b/providers/google/src/airflow/providers/google/cloud/links/dataform.py
index 3d94f4c5b142a..f6dde4f5dcfd4 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/dataform.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/dataform.py
@@ -19,14 +19,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.utils.context import Context
-
 DATAFORM_BASE_LINK = "/bigquery/dataform"
 DATAFORM_WORKFLOW_INVOCATION_LINK = (
     DATAFORM_BASE_LINK
@@ -53,26 +47,6 @@ class DataformWorkflowInvocationLink(BaseGoogleLink):
     key = "dataform_workflow_invocation_config"
     format_str = DATAFORM_WORKFLOW_INVOCATION_LINK
 
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str,
-        region: str,
-        repository_id: str,
-        workflow_invocation_id: str,
-    ):
-        operator_instance.xcom_push(
-            context,
-            key=DataformWorkflowInvocationLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-                "repository_id": repository_id,
-                "workflow_invocation_id": workflow_invocation_id,
-            },
-        )
-
 
 class DataformRepositoryLink(BaseGoogleLink):
     """Helper class for constructing Dataflow repository link."""
@@ -81,24 +55,6 @@ class DataformRepositoryLink(BaseGoogleLink):
     key = "dataform_repository"
     format_str = DATAFORM_REPOSITORY_LINK
 
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str,
-        region: str,
-        repository_id: str,
-    ) -> None:
-        operator_instance.xcom_push(
-            context=context,
-            key=DataformRepositoryLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-                "repository_id": repository_id,
-            },
-        )
-
 
 class DataformWorkspaceLink(BaseGoogleLink):
     """Helper class for constructing Dataform workspace link."""
@@ -106,23 +62,3 @@ class DataformWorkspaceLink(BaseGoogleLink):
     name = "Dataform Workspace"
     key = "dataform_workspace"
     format_str = DATAFORM_WORKSPACE_LINK
-
-    @staticmethod
-    def persist(
-        operator_instance: BaseOperator,
-        context: Context,
-        project_id: str,
-        region: str,
-        repository_id: str,
-        workspace_id: str,
-    ) -> None:
-        operator_instance.xcom_push(
-            context=context,
-            key=DataformWorkspaceLink.key,
-            value={
-                "project_id": project_id,
-                "region": region,
-                "repository_id": repository_id,
-                "workspace_id": workspace_id,
-            },
-        )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/datafusion.py b/providers/google/src/airflow/providers/google/cloud/links/datafusion.py
index c3e2351e2adc0..f6e996434fc21 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/datafusion.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/datafusion.py
@@ -19,21 +19,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, ClassVar
-
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseOperatorLink
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models import XCom  # type: ignore[no-redef]
-    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
-
-if TYPE_CHECKING:
-    from airflow.models import BaseOperator
-    from airflow.models.taskinstancekey import TaskInstanceKey
-    from airflow.utils.context import Context
+from airflow.providers.google.cloud.links.base import BaseGoogleLink
 
 BASE_LINK = "https://console.cloud.google.com/data-fusion"
 DATAFUSION_INSTANCE_LINK = BASE_LINK + "/locations/{region}/instances/{instance_name}?project={project_id}"
@@ -41,35 +27,6 @@
 DATAFUSION_PIPELINE_LINK = "{uri}/pipelines/ns/{namespace}/view/{pipeline_name}"
 
 
-class BaseGoogleLink(BaseOperatorLink):
-    """
-    Link for Google operators.
-
-    Prevent adding ``https://console.cloud.google.com`` in front of every link
-    where URI is used.
-    """
-
-    name: ClassVar[str]
-    key: ClassVar[str]
-    format_str: ClassVar[str]
-
-    def get_link(
-        self,
-        operator: BaseOperator,
-        *,
-        ti_key: TaskInstanceKey,
-    ) -> str:
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
-
-        if not conf:
-            return ""
-
-        # Add a default value for the 'namespace' parameter for backward compatibility.
-        conf.setdefault("namespace", "default")
-
-        return self.format_str.format(**conf)
-
-
 class DataFusionInstanceLink(BaseGoogleLink):
     """Helper class for constructing Data Fusion Instance link."""
 
@@ -77,24 +34,6 @@ class DataFusionInstanceLink(BaseGoogleLink):
     key = "instance_conf"
     format_str = DATAFUSION_INSTANCE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        location: str,
-        instance_name: str,
-        project_id: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=DataFusionInstanceLink.key,
-            value={
-                "region": location,
-                "instance_name": instance_name,
-                "project_id": project_id,
-            },
-        )
-
 
 class DataFusionPipelineLink(BaseGoogleLink):
     """Helper class for constructing Data Fusion Pipeline link."""
@@ -103,24 +42,6 @@ class DataFusionPipelineLink(BaseGoogleLink):
     key = "pipeline_conf"
     format_str = DATAFUSION_PIPELINE_LINK
 
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        uri: str,
-        pipeline_name: str,
-        namespace: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=DataFusionPipelineLink.key,
-            value={
-                "uri": uri,
-                "pipeline_name": pipeline_name,
-                "namespace": namespace,
-            },
-        )
-
 
 class DataFusionPipelinesLink(BaseGoogleLink):
     """Helper class for constructing list of Data Fusion Pipelines link."""
@@ -128,19 +49,3 @@ class DataFusionPipelinesLink(BaseGoogleLink):
     name = "Data Fusion Pipelines List"
     key = "pipelines_conf"
     format_str = DATAFUSION_PIPELINES_LINK
-
-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: BaseOperator,
-        uri: str,
-        namespace: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=DataFusionPipelinesLink.key,
-            value={
-                "uri": uri,
-                "namespace": namespace,
-            },
-        )
@staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexTasksLink.key, - value={ - "project_id": task_instance.project_id, - "lake_id": task_instance.lake_id, - "region": task_instance.region, - }, - ) - class DataplexLakeLink(BaseGoogleLink): """Helper class for constructing Dataplex Lake link.""" @@ -100,21 +64,6 @@ class DataplexLakeLink(BaseGoogleLink): key = "dataplex_lake_key" format_str = DATAPLEX_LAKE_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexLakeLink.key, - value={ - "lake_id": task_instance.lake_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogEntryGroupLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog EntryGroup link.""" @@ -123,21 +72,6 @@ class DataplexCatalogEntryGroupLink(BaseGoogleLink): key = "dataplex_catalog_entry_group_key" format_str = DATAPLEX_CATALOG_ENTRY_GROUP_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogEntryGroupLink.key, - value={ - "entry_group_id": task_instance.entry_group_id, - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogEntryGroupsLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog EntryGroups link.""" @@ -146,20 +80,6 @@ class DataplexCatalogEntryGroupsLink(BaseGoogleLink): key = "dataplex_catalog_entry_groups_key" format_str = DATAPLEX_CATALOG_ENTRY_GROUPS_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogEntryGroupsLink.key, - value={ - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogEntryTypeLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog EntryType link.""" @@ -168,21 +88,6 @@ class DataplexCatalogEntryTypeLink(BaseGoogleLink): key = "dataplex_catalog_entry_type_key" format_str = DATAPLEX_CATALOG_ENTRY_TYPE_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogEntryTypeLink.key, - value={ - "entry_type_id": task_instance.entry_type_id, - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogEntryTypesLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog EntryTypes link.""" @@ -191,20 +96,6 @@ class DataplexCatalogEntryTypesLink(BaseGoogleLink): key = "dataplex_catalog_entry_types_key" format_str = DATAPLEX_CATALOG_ENTRY_TYPES_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogEntryTypesLink.key, - value={ - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogAspectTypeLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog AspectType link.""" @@ -213,21 +104,6 @@ class DataplexCatalogAspectTypeLink(BaseGoogleLink): key = "dataplex_catalog_aspect_type_key" format_str = DATAPLEX_CATALOG_ASPECT_TYPE_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogAspectTypeLink.key, - value={ - "aspect_type_id": 
task_instance.aspect_type_id, - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogAspectTypesLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog AspectTypes link.""" @@ -236,20 +112,6 @@ class DataplexCatalogAspectTypesLink(BaseGoogleLink): key = "dataplex_catalog_aspect_types_key" format_str = DATAPLEX_CATALOG_ASPECT_TYPES_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogAspectTypesLink.key, - value={ - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) - class DataplexCatalogEntryLink(BaseGoogleLink): """Helper class for constructing Dataplex Catalog Entry link.""" @@ -257,19 +119,3 @@ class DataplexCatalogEntryLink(BaseGoogleLink): name = "Dataplex Catalog Entry" key = "dataplex_catalog_entry_key" format_str = DATAPLEX_CATALOG_ENTRY_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=DataplexCatalogEntryLink.key, - value={ - "entry_id": task_instance.entry_id, - "entry_group_id": task_instance.entry_group_id, - "location": task_instance.location, - "project_id": task_instance.project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataprep.py b/providers/google/src/airflow/providers/google/cloud/links/dataprep.py index 66caf1cfe8933..8714243f86779 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/dataprep.py +++ b/providers/google/src/airflow/providers/google/cloud/links/dataprep.py @@ -16,13 +16,8 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.utils.context import Context - BASE_LINK = "https://clouddataprep.com" DATAPREP_FLOW_LINK = BASE_LINK + "/flows/{flow_id}?projectId={project_id}" DATAPREP_JOB_GROUP_LINK = BASE_LINK + "/jobs/{job_group_id}?projectId={project_id}" @@ -35,14 +30,6 @@ class DataprepFlowLink(BaseGoogleLink): key = "dataprep_flow_page" format_str = DATAPREP_FLOW_LINK - @staticmethod - def persist(context: Context, task_instance, project_id: str, flow_id: int): - task_instance.xcom_push( - context=context, - key=DataprepFlowLink.key, - value={"project_id": project_id, "flow_id": flow_id}, - ) - class DataprepJobGroupLink(BaseGoogleLink): """Helper class for constructing Dataprep job group link.""" @@ -50,14 +37,3 @@ class DataprepJobGroupLink(BaseGoogleLink): name = "Job group details page" key = "dataprep_job_group_page" format_str = DATAPREP_JOB_GROUP_LINK - - @staticmethod - def persist(context: Context, task_instance, project_id: str, job_group_id: int): - task_instance.xcom_push( - context=context, - key=DataprepJobGroupLink.key, - value={ - "project_id": project_id, - "job_group_id": job_group_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py index 4fafd9da90fe5..832c66bf132e0 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py @@ -189,20 +189,6 @@ class DataprocClusterLink(BaseGoogleLink): key = "dataproc_cluster" format_str = DATAPROC_CLUSTER_LINK - @staticmethod - def persist( - context: Context, - operator: BaseOperator, - cluster_id: str, - 
region: str, - project_id: str, - ): - operator.xcom_push( - context, - key=DataprocClusterLink.key, - value={"cluster_id": cluster_id, "region": region, "project_id": project_id}, - ) - class DataprocJobLink(BaseGoogleLink): """Helper class for constructing Dataproc Job Link.""" @@ -211,20 +197,6 @@ class DataprocJobLink(BaseGoogleLink): key = "dataproc_job" format_str = DATAPROC_JOB_LINK - @staticmethod - def persist( - context: Context, - operator: BaseOperator, - job_id: str, - region: str, - project_id: str, - ): - operator.xcom_push( - context, - key=DataprocJobLink.key, - value={"job_id": job_id, "region": region, "project_id": project_id}, - ) - class DataprocWorkflowLink(BaseGoogleLink): """Helper class for constructing Dataproc Workflow Link.""" @@ -233,14 +205,6 @@ class DataprocWorkflowLink(BaseGoogleLink): key = "dataproc_workflow" format_str = DATAPROC_WORKFLOW_LINK - @staticmethod - def persist(context: Context, operator: BaseOperator, workflow_id: str, project_id: str, region: str): - operator.xcom_push( - context, - key=DataprocWorkflowLink.key, - value={"workflow_id": workflow_id, "region": region, "project_id": project_id}, - ) - class DataprocWorkflowTemplateLink(BaseGoogleLink): """Helper class for constructing Dataproc Workflow Template Link.""" @@ -249,20 +213,6 @@ class DataprocWorkflowTemplateLink(BaseGoogleLink): key = "dataproc_workflow_template" format_str = DATAPROC_WORKFLOW_TEMPLATE_LINK - @staticmethod - def persist( - context: Context, - operator: BaseOperator, - workflow_template_id: str, - project_id: str, - region: str, - ): - operator.xcom_push( - context, - key=DataprocWorkflowTemplateLink.key, - value={"workflow_template_id": workflow_template_id, "region": region, "project_id": project_id}, - ) - class DataprocBatchLink(BaseGoogleLink): """Helper class for constructing Dataproc Batch Link.""" @@ -271,20 +221,6 @@ class DataprocBatchLink(BaseGoogleLink): key = "dataproc_batch" format_str = DATAPROC_BATCH_LINK - @staticmethod - def persist( - context: Context, - operator: BaseOperator, - batch_id: str, - project_id: str, - region: str, - ): - operator.xcom_push( - context, - key=DataprocBatchLink.key, - value={"batch_id": batch_id, "region": region, "project_id": project_id}, - ) - class DataprocBatchesListLink(BaseGoogleLink): """Helper class for constructing Dataproc Batches List Link.""" @@ -292,15 +228,3 @@ class DataprocBatchesListLink(BaseGoogleLink): name = "Dataproc Batches List" key = "dataproc_batches_list" format_str = DATAPROC_BATCHES_LINK - - @staticmethod - def persist( - context: Context, - operator: BaseOperator, - project_id: str, - ): - operator.xcom_push( - context, - key=DataprocBatchesListLink.key, - value={"project_id": project_id}, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/datastore.py b/providers/google/src/airflow/providers/google/cloud/links/datastore.py index a75893d793e40..992c5d2a0b020 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/datastore.py +++ b/providers/google/src/airflow/providers/google/cloud/links/datastore.py @@ -16,13 +16,8 @@ # under the License. 
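
With the per-link static `persist` helpers gone, the Dataproc links rely on the operator to supply `project_id` and `region`. A hedged sketch of how a consolidated `persist(context=...)` can merge the operator's `extra_links_params` with call-site keyword overrides (all names below are invented stand-ins, not the provider's actual implementation):

    from __future__ import annotations

    from typing import Any


    class FakeOperator:
        def __init__(self) -> None:
            self.pushed: dict[str, Any] = {}  # stand-in XCom store

        @property
        def extra_links_params(self) -> dict[str, Any]:
            return {"project_id": "my-project", "region": "us-central1"}

        def xcom_push(self, key: str, value: dict[str, Any]) -> None:
            self.pushed[key] = value


    def persist(link_key: str, context: dict[str, Any], **value: Any) -> None:
        task = context["task"]
        # operator-level defaults first, call-site kwargs win on conflict
        params = {**getattr(task, "extra_links_params", {}), **value}
        task.xcom_push(key=link_key, value=params)


    op = FakeOperator()
    persist("dataproc_cluster", {"task": op}, cluster_id="my-cluster")
    print(op.pushed["dataproc_cluster"])
    # {'project_id': 'my-project', 'region': 'us-central1', 'cluster_id': 'my-cluster'}

This is why most call sites in this diff shrink to `SomeLink.persist(context=context)` plus at most one or two link-specific kwargs.
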
from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.utils.context import Context - DATASTORE_BASE_LINK = "/datastore" DATASTORE_IMPORT_EXPORT_LINK = DATASTORE_BASE_LINK + "/import-export?project={project_id}" DATASTORE_EXPORT_ENTITIES_LINK = "/storage/browser/{bucket_name}/{export_name}?project={project_id}" @@ -36,19 +31,6 @@ class CloudDatastoreImportExportLink(BaseGoogleLink): key = "import_export_conf" format_str = DATASTORE_IMPORT_EXPORT_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=CloudDatastoreImportExportLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class CloudDatastoreEntitiesLink(BaseGoogleLink): """Helper class for constructing Cloud Datastore Entities Link.""" @@ -56,16 +38,3 @@ class CloudDatastoreEntitiesLink(BaseGoogleLink): name = "Entities" key = "entities_conf" format_str = DATASTORE_ENTITIES_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=CloudDatastoreEntitiesLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/links/kubernetes_engine.py index 4169a48133f78..2b94326cbba0a 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/kubernetes_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/links/kubernetes_engine.py @@ -51,19 +51,15 @@ class KubernetesEngineClusterLink(BaseGoogleLink): key = "kubernetes_cluster_conf" format_str = KUBERNETES_CLUSTER_LINK - @staticmethod - def persist(context: Context, task_instance, cluster: dict | Cluster | None): + @classmethod + def persist(cls, context: Context, **value): + cluster = value.get("cluster") if isinstance(cluster, dict): cluster = Cluster.from_json(json.dumps(cluster)) - task_instance.xcom_push( + super().persist( context=context, - key=KubernetesEngineClusterLink.key, - value={ - "location": task_instance.location, - "cluster_name": cluster.name, # type: ignore - "project_id": task_instance.project_id, - }, + cluster_name=cluster.name, ) @@ -74,23 +70,6 @@ class KubernetesEnginePodLink(BaseGoogleLink): key = "kubernetes_pod_conf" format_str = KUBERNETES_POD_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=KubernetesEnginePodLink.key, - value={ - "location": task_instance.location, - "cluster_name": task_instance.cluster_name, - "namespace": task_instance.pod.metadata.namespace, - "pod_name": task_instance.pod.metadata.name, - "project_id": task_instance.project_id, - }, - ) - class KubernetesEngineJobLink(BaseGoogleLink): """Helper class for constructing Kubernetes Engine Job Link.""" @@ -99,23 +78,6 @@ class KubernetesEngineJobLink(BaseGoogleLink): key = "kubernetes_job_conf" format_str = KUBERNETES_JOB_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=KubernetesEngineJobLink.key, - value={ - "location": task_instance.location, - "cluster_name": task_instance.cluster_name, - "namespace": task_instance.job.metadata.namespace, - "job_name": task_instance.job.metadata.name, - "project_id": task_instance.project_id, - }, - ) - class 
KubernetesEngineWorkloadsLink(BaseGoogleLink): """Helper class for constructing Kubernetes Engine Workloads Link.""" @@ -123,19 +85,3 @@ class KubernetesEngineWorkloadsLink(BaseGoogleLink): name = "Kubernetes Workloads" key = "kubernetes_workloads_conf" format_str = KUBERNETES_WORKLOADS_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=KubernetesEngineWorkloadsLink.key, - value={ - "location": task_instance.location, - "cluster_name": task_instance.cluster_name, - "namespace": task_instance.namespace, - "project_id": task_instance.project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/life_sciences.py b/providers/google/src/airflow/providers/google/cloud/links/life_sciences.py index 948142023777d..e7ac84cd67082 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/life_sciences.py +++ b/providers/google/src/airflow/providers/google/cloud/links/life_sciences.py @@ -16,13 +16,8 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.utils.context import Context - BASE_LINK = "https://console.cloud.google.com/lifesciences" LIFESCIENCES_LIST_LINK = BASE_LINK + "/pipelines?project={project_id}" @@ -33,17 +28,3 @@ class LifeSciencesLink(BaseGoogleLink): name = "Life Sciences" key = "lifesciences_key" format_str = LIFESCIENCES_LIST_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context=context, - key=LifeSciencesLink.key, - value={ - "project_id": project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py index 45b62901c5515..5d60bcc23af9b 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py +++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py @@ -16,13 +16,8 @@ # under the License. 
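
The one GKE link that keeps a `persist` override only normalizes its input before delegating: `KubernetesEngineClusterLink.persist` accepts either a `Cluster` proto or a plain dict. A runnable sketch of that normalization, with a dataclass standing in for `google.cloud.container_v1.types.Cluster`:

    from __future__ import annotations

    import json
    from dataclasses import dataclass


    @dataclass
    class ClusterStub:
        """Stand-in for google.cloud.container_v1.types.Cluster."""

        name: str

        @classmethod
        def from_json(cls, payload: str) -> ClusterStub:
            return cls(name=json.loads(payload)["name"])


    def normalize_cluster(cluster: dict | ClusterStub) -> str:
        # the override round-trips dicts through JSON to get a typed object
        if isinstance(cluster, dict):
            cluster = ClusterStub.from_json(json.dumps(cluster))
        return cluster.name


    print(normalize_cluster({"name": "gke-demo"}))         # gke-demo
    print(normalize_cluster(ClusterStub(name="gke-two")))  # gke-two
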
from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.utils.context import Context - MANAGED_KAFKA_BASE_LINK = "/managedkafka" MANAGED_KAFKA_CLUSTER_LINK = ( MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}?project={project_id}" @@ -44,22 +39,6 @@ class ApacheKafkaClusterLink(BaseGoogleLink): key = "cluster_conf" format_str = MANAGED_KAFKA_CLUSTER_LINK - @staticmethod - def persist( - context: Context, - task_instance, - cluster_id: str, - ): - task_instance.xcom_push( - context=context, - key=ApacheKafkaClusterLink.key, - value={ - "location": task_instance.location, - "cluster_id": cluster_id, - "project_id": task_instance.project_id, - }, - ) - class ApacheKafkaClusterListLink(BaseGoogleLink): """Helper class for constructing Apache Kafka Clusters link.""" @@ -68,19 +47,6 @@ class ApacheKafkaClusterListLink(BaseGoogleLink): key = "cluster_list_conf" format_str = MANAGED_KAFKA_CLUSTER_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=ApacheKafkaClusterListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class ApacheKafkaTopicLink(BaseGoogleLink): """Helper class for constructing Apache Kafka Topic link.""" @@ -89,24 +55,6 @@ class ApacheKafkaTopicLink(BaseGoogleLink): key = "topic_conf" format_str = MANAGED_KAFKA_TOPIC_LINK - @staticmethod - def persist( - context: Context, - task_instance, - cluster_id: str, - topic_id: str, - ): - task_instance.xcom_push( - context=context, - key=ApacheKafkaTopicLink.key, - value={ - "location": task_instance.location, - "cluster_id": cluster_id, - "topic_id": topic_id, - "project_id": task_instance.project_id, - }, - ) - class ApacheKafkaConsumerGroupLink(BaseGoogleLink): """Helper class for constructing Apache Kafka Consumer Group link.""" @@ -114,21 +62,3 @@ class ApacheKafkaConsumerGroupLink(BaseGoogleLink): name = "Apache Kafka Consumer Group" key = "consumer_group_conf" format_str = MANAGED_KAFKA_CONSUMER_GROUP_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - cluster_id: str, - consumer_group_id: str, - ): - task_instance.xcom_push( - context=context, - key=ApacheKafkaConsumerGroupLink.key, - value={ - "location": task_instance.location, - "cluster_id": cluster_id, - "consumer_group_id": consumer_group_id, - "project_id": task_instance.project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/mlengine.py b/providers/google/src/airflow/providers/google/cloud/links/mlengine.py index 88a111a819d3a..040a4181020a2 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/mlengine.py +++ b/providers/google/src/airflow/providers/google/cloud/links/mlengine.py @@ -19,14 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.utils.context import Context - - MLENGINE_BASE_LINK = "https://console.cloud.google.com/ai-platform" MLENGINE_MODEL_DETAILS_LINK = MLENGINE_BASE_LINK + "/models/{model_id}/versions?project={project_id}" MLENGINE_MODEL_VERSION_DETAILS_LINK = ( @@ -44,19 +38,6 @@ class MLEngineModelLink(BaseGoogleLink): key = "ml_engine_model" format_str = MLENGINE_MODEL_DETAILS_LINK - @staticmethod - def persist( - context: Context, - task_instance, - model_id: str, - project_id: str, - ): - task_instance.xcom_push( 
- context, - key=MLEngineModelLink.key, - value={"model_id": model_id, "project_id": project_id}, - ) - class MLEngineModelsListLink(BaseGoogleLink): """Helper class for constructing ML Engine link.""" @@ -65,18 +46,6 @@ class MLEngineModelsListLink(BaseGoogleLink): key = "ml_engine_models_list" format_str = MLENGINE_MODELS_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=MLEngineModelsListLink.key, - value={"project_id": project_id}, - ) - class MLEngineJobDetailsLink(BaseGoogleLink): """Helper class for constructing ML Engine link.""" @@ -85,19 +54,6 @@ class MLEngineJobDetailsLink(BaseGoogleLink): key = "ml_engine_job_details" format_str = MLENGINE_JOB_DETAILS_LINK - @staticmethod - def persist( - context: Context, - task_instance, - job_id: str, - project_id: str, - ): - task_instance.xcom_push( - context, - key=MLEngineJobDetailsLink.key, - value={"job_id": job_id, "project_id": project_id}, - ) - class MLEngineModelVersionDetailsLink(BaseGoogleLink): """Helper class for constructing ML Engine link.""" @@ -106,20 +62,6 @@ class MLEngineModelVersionDetailsLink(BaseGoogleLink): key = "ml_engine_version_details" format_str = MLENGINE_MODEL_VERSION_DETAILS_LINK - @staticmethod - def persist( - context: Context, - task_instance, - model_id: str, - project_id: str, - version_id: str, - ): - task_instance.xcom_push( - context, - key=MLEngineModelVersionDetailsLink.key, - value={"model_id": model_id, "project_id": project_id, "version_id": version_id}, - ) - class MLEngineJobSListLink(BaseGoogleLink): """Helper class for constructing ML Engine link.""" @@ -127,15 +69,3 @@ class MLEngineJobSListLink(BaseGoogleLink): name = "MLEngine Jobs List" key = "ml_engine_jobs_list" format_str = MLENGINE_JOBS_LIST_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=MLEngineJobSListLink.key, - value={"project_id": project_id}, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/pubsub.py b/providers/google/src/airflow/providers/google/cloud/links/pubsub.py index 0ab8dd97e23e5..d5b3134254c66 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/links/pubsub.py @@ -19,14 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.utils.context import Context - PUBSUB_BASE_LINK = "/cloudpubsub" PUBSUB_TOPIC_LINK = PUBSUB_BASE_LINK + "/topic/detail/{topic_id}?project={project_id}" PUBSUB_SUBSCRIPTION_LINK = PUBSUB_BASE_LINK + "/subscription/detail/{subscription_id}?project={project_id}" @@ -39,19 +33,6 @@ class PubSubTopicLink(BaseGoogleLink): key = "pubsub_topic" format_str = PUBSUB_TOPIC_LINK - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - topic_id: str, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=PubSubTopicLink.key, - value={"topic_id": topic_id, "project_id": project_id}, - ) - class PubSubSubscriptionLink(BaseGoogleLink): """Helper class for constructing Pub/Sub Subscription Link.""" @@ -59,16 +40,3 @@ class PubSubSubscriptionLink(BaseGoogleLink): name = "Pub/Sub Subscription" key = "pubsub_subscription" format_str = PUBSUB_SUBSCRIPTION_LINK - - @staticmethod - def persist( - context: 
Context, - task_instance: BaseOperator, - subscription_id: str | None, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=PubSubSubscriptionLink.key, - value={"subscription_id": subscription_id, "project_id": project_id}, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/spanner.py b/providers/google/src/airflow/providers/google/cloud/links/spanner.py index 2a9d5e09bf155..77f3df4a1a464 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/spanner.py +++ b/providers/google/src/airflow/providers/google/cloud/links/spanner.py @@ -19,14 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.utils.context import Context - SPANNER_BASE_LINK = "/spanner/instances" SPANNER_INSTANCE_LINK = SPANNER_BASE_LINK + "/{instance_id}/details/databases?project={project_id}" SPANNER_DATABASE_LINK = ( @@ -41,19 +35,6 @@ class SpannerInstanceLink(BaseGoogleLink): key = "spanner_instance" format_str = SPANNER_INSTANCE_LINK - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - instance_id: str, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=SpannerInstanceLink.key, - value={"instance_id": instance_id, "project_id": project_id}, - ) - class SpannerDatabaseLink(BaseGoogleLink): """Helper class for constructing Spanner Database Link.""" @@ -61,17 +42,3 @@ class SpannerDatabaseLink(BaseGoogleLink): name = "Spanner Database" key = "spanner_database" format_str = SPANNER_DATABASE_LINK - - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - instance_id: str, - database_id: str, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=SpannerDatabaseLink.key, - value={"instance_id": instance_id, "database_id": database_id, "project_id": project_id}, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/stackdriver.py b/providers/google/src/airflow/providers/google/cloud/links/stackdriver.py index 0dbc27a0d5cf0..44393e979b664 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/stackdriver.py +++ b/providers/google/src/airflow/providers/google/cloud/links/stackdriver.py @@ -19,14 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.utils.context import Context - STACKDRIVER_BASE_LINK = "/monitoring/alerting" STACKDRIVER_NOTIFICATIONS_LINK = STACKDRIVER_BASE_LINK + "/notifications?project={project_id}" STACKDRIVER_POLICIES_LINK = STACKDRIVER_BASE_LINK + "/policies?project={project_id}" @@ -39,18 +33,6 @@ class StackdriverNotificationsLink(BaseGoogleLink): key = "stackdriver_notifications" format_str = STACKDRIVER_NOTIFICATIONS_LINK - @staticmethod - def persist( - operator_instance: BaseOperator, - context: Context, - project_id: str | None, - ): - operator_instance.xcom_push( - context, - key=StackdriverNotificationsLink.key, - value={"project_id": project_id}, - ) - class StackdriverPoliciesLink(BaseGoogleLink): """Helper class for constructing Stackdriver Policies Link.""" @@ -58,15 +40,3 @@ class StackdriverPoliciesLink(BaseGoogleLink): name = "Cloud Monitoring Policies" key = "stackdriver_policies" format_str = STACKDRIVER_POLICIES_LINK - - @staticmethod - def persist( - 
operator_instance: BaseOperator, - context: Context, - project_id: str | None, - ): - operator_instance.xcom_push( - context, - key=StackdriverPoliciesLink.key, - value={"project_id": project_id}, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/translate.py b/providers/google/src/airflow/providers/google/cloud/links/translate.py index 4415b058829fb..bbd8febe37af7 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/google/src/airflow/providers/google/cloud/links/translate.py @@ -70,19 +70,6 @@ class TranslationLegacyDatasetLink(BaseGoogleLink): key = "translation_legacy_dataset" format_str = TRANSLATION_LEGACY_DATASET_LINK - @staticmethod - def persist( - context: Context, - task_instance, - dataset_id: str, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationLegacyDatasetLink.key, - value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, - ) - class TranslationDatasetListLink(BaseGoogleLink): """Helper class for constructing Translation Dataset List link.""" @@ -91,20 +78,6 @@ class TranslationDatasetListLink(BaseGoogleLink): key = "translation_dataset_list" format_str = TRANSLATION_DATASET_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationDatasetListLink.key, - value={ - "project_id": project_id, - }, - ) - class TranslationLegacyModelLink(BaseGoogleLink): """ @@ -117,25 +90,6 @@ class TranslationLegacyModelLink(BaseGoogleLink): key = "translation_legacy_model" format_str = TRANSLATION_LEGACY_MODEL_LINK - @staticmethod - def persist( - context: Context, - task_instance, - dataset_id: str, - model_id: str, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationLegacyModelLink.key, - value={ - "location": task_instance.location, - "dataset_id": dataset_id, - "model_id": model_id, - "project_id": project_id, - }, - ) - class TranslationLegacyModelTrainLink(BaseGoogleLink): """ @@ -148,22 +102,6 @@ class TranslationLegacyModelTrainLink(BaseGoogleLink): key = "translation_legacy_model_train" format_str = TRANSLATION_LEGACY_MODEL_TRAIN_LINK - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationLegacyModelTrainLink.key, - value={ - "location": task_instance.location, - "dataset_id": task_instance.model["dataset_id"], - "project_id": project_id, - }, - ) - class TranslationLegacyModelPredictLink(BaseGoogleLink): """ @@ -176,25 +114,6 @@ class TranslationLegacyModelPredictLink(BaseGoogleLink): key = "translation_legacy_model_predict" format_str = TRANSLATION_LEGACY_MODEL_PREDICT_LINK - @staticmethod - def persist( - context: Context, - task_instance, - model_id: str, - project_id: str, - dataset_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationLegacyModelPredictLink.key, - value={ - "location": task_instance.location, - "dataset_id": dataset_id, - "model_id": model_id, - "project_id": project_id, - }, - ) - class TranslateTextBatchLink(BaseGoogleLink): """ @@ -212,20 +131,13 @@ class TranslateTextBatchLink(BaseGoogleLink): def extract_output_uri_prefix(output_config): return output_config["gcs_destination"]["output_uri_prefix"].rpartition("gs://")[-1] - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - output_config: dict, - ): - task_instance.xcom_push( - context, - 
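
Both translate batch links keep `extract_output_uri_prefix`, which strips the `gs://` scheme before the value is persisted. The helper is small enough to check directly; the sample config below is illustrative:

    def extract_output_uri_prefix(output_config: dict) -> str:
        # mirrors the helper retained on the link classes: drop the "gs://" scheme
        return output_config["gcs_destination"]["output_uri_prefix"].rpartition("gs://")[-1]


    cfg = {"gcs_destination": {"output_uri_prefix": "gs://my-bucket/translate/output"}}
    assert extract_output_uri_prefix(cfg) == "my-bucket/translate/output"
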
key=TranslateTextBatchLink.key, - value={ - "project_id": project_id, - "output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config), - }, + @classmethod + def persist(cls, context: Context, **value): + output_config = value.get("output_config") + super().persist( + context=context, + project_id=value.get("project_id"), + output_uri_prefix=cls.extract_output_uri_prefix(output_config), ) @@ -240,19 +152,6 @@ class TranslationNativeDatasetLink(BaseGoogleLink): key = "translation_naive_dataset" format_str = TRANSLATION_NATIVE_DATASET_LINK - @staticmethod - def persist( - context: Context, - task_instance, - dataset_id: str, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationNativeDatasetLink.key, - value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, - ) - class TranslationDatasetsListLink(BaseGoogleLink): """ @@ -265,20 +164,6 @@ class TranslationDatasetsListLink(BaseGoogleLink): key = "translation_dataset_list" format_str = TRANSLATION_DATASET_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationDatasetsListLink.key, - value={ - "project_id": project_id, - }, - ) - class TranslationModelLink(BaseGoogleLink): """ @@ -291,25 +176,6 @@ class TranslationModelLink(BaseGoogleLink): key = "translation_model" format_str = TRANSLATION_NATIVE_MODEL_LINK - @staticmethod - def persist( - context: Context, - task_instance, - dataset_id: str, - model_id: str, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationLegacyModelLink.key, - value={ - "location": task_instance.location, - "dataset_id": dataset_id, - "model_id": model_id, - "project_id": project_id, - }, - ) - class TranslationModelsListLink(BaseGoogleLink): """ @@ -322,20 +188,6 @@ class TranslationModelsListLink(BaseGoogleLink): key = "translation_models_list" format_str = TRANSLATION_MODELS_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationModelsListLink.key, - value={ - "project_id": project_id, - }, - ) - class TranslateResultByOutputConfigLink(BaseGoogleLink): """ @@ -353,22 +205,14 @@ class TranslateResultByOutputConfigLink(BaseGoogleLink): def extract_output_uri_prefix(output_config): return output_config["gcs_destination"]["output_uri_prefix"].rpartition("gs://")[-1] - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - output_config: dict, - ): - task_instance.xcom_push( - context, - key=TranslateResultByOutputConfigLink.key, - value={ - "project_id": project_id, - "output_uri_prefix": TranslateResultByOutputConfigLink.extract_output_uri_prefix( - output_config - ), - }, + @classmethod + def persist(cls, context: Context, **value): + output_config = value.get("output_config") + output_uri_prefix = cls.extract_output_uri_prefix(output_config) + super().persist( + context=context, + project_id=value.get("project_id"), + output_uri_prefix=output_uri_prefix, ) @@ -382,17 +226,3 @@ class TranslationGlossariesListLink(BaseGoogleLink): name = "Translation Glossaries List" key = "translation_glossaries_list" format_str = TRANSLATION_HUB_RESOURCES_LIST_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - project_id: str, - ): - task_instance.xcom_push( - context, - key=TranslationGlossariesListLink.key, - value={ - "project_id": project_id, - }, - ) diff --git 
a/providers/google/src/airflow/providers/google/cloud/links/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/links/vertex_ai.py index d1f32d052a884..f735e07f58fe3 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/links/vertex_ai.py @@ -67,22 +67,6 @@ class VertexAIModelLink(BaseGoogleLink): key = "model_conf" format_str = VERTEX_AI_MODEL_LINK - @staticmethod - def persist( - context: Context, - task_instance, - model_id: str, - ): - task_instance.xcom_push( - context=context, - key=VertexAIModelLink.key, - value={ - "model_id": model_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class VertexAIModelListLink(BaseGoogleLink): """Helper class for constructing Vertex AI Models Link.""" @@ -91,19 +75,6 @@ class VertexAIModelListLink(BaseGoogleLink): key = "models_conf" format_str = VERTEX_AI_MODEL_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIModelListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIModelExportLink(BaseGoogleLink): """Helper class for constructing Vertex AI Model Export Link.""" @@ -117,19 +88,15 @@ def extract_bucket_name(config): """Return bucket name from output configuration.""" return config["artifact_destination"]["output_uri_prefix"].rpartition("gs://")[-1] - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( + @classmethod + def persist(cls, context: Context, **value): + output_config = value.get("output_config") + bucket_name = cls.extract_bucket_name(output_config) + super().persist( context=context, - key=VertexAIModelExportLink.key, - value={ - "project_id": task_instance.project_id, - "model_id": task_instance.model_id, - "bucket_name": VertexAIModelExportLink.extract_bucket_name(task_instance.output_config), - }, + project_id=value.get("project_id"), + model_id=value.get("model_id"), + bucket_name=bucket_name, ) @@ -140,22 +107,6 @@ class VertexAITrainingLink(BaseGoogleLink): key = "training_conf" format_str = VERTEX_AI_TRAINING_LINK - @staticmethod - def persist( - context: Context, - task_instance, - training_id: str, - ): - task_instance.xcom_push( - context=context, - key=VertexAITrainingLink.key, - value={ - "training_id": training_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class VertexAITrainingPipelinesLink(BaseGoogleLink): """Helper class for constructing Vertex AI Training Pipelines link.""" @@ -164,19 +115,6 @@ class VertexAITrainingPipelinesLink(BaseGoogleLink): key = "pipelines_conf" format_str = VERTEX_AI_TRAINING_PIPELINES_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAITrainingPipelinesLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIDatasetLink(BaseGoogleLink): """Helper class for constructing Vertex AI Dataset link.""" @@ -185,18 +123,6 @@ class VertexAIDatasetLink(BaseGoogleLink): key = "dataset_conf" format_str = VERTEX_AI_DATASET_LINK - @staticmethod - def persist(context: Context, task_instance, dataset_id: str): - task_instance.xcom_push( - context=context, - key=VertexAIDatasetLink.key, - value={ - "dataset_id": dataset_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class 
VertexAIDatasetListLink(BaseGoogleLink): """Helper class for constructing Vertex AI Datasets Link.""" @@ -205,19 +131,6 @@ class VertexAIDatasetListLink(BaseGoogleLink): key = "datasets_conf" format_str = VERTEX_AI_DATASET_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIDatasetListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIHyperparameterTuningJobListLink(BaseGoogleLink): """Helper class for constructing Vertex AI HyperparameterTuningJobs Link.""" @@ -226,19 +139,6 @@ class VertexAIHyperparameterTuningJobListLink(BaseGoogleLink): key = "hyperparameter_tuning_jobs_conf" format_str = VERTEX_AI_HYPERPARAMETER_TUNING_JOB_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIHyperparameterTuningJobListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIBatchPredictionJobLink(BaseGoogleLink): """Helper class for constructing Vertex AI BatchPredictionJob link.""" @@ -247,22 +147,6 @@ class VertexAIBatchPredictionJobLink(BaseGoogleLink): key = "batch_prediction_job_conf" format_str = VERTEX_AI_BATCH_PREDICTION_JOB_LINK - @staticmethod - def persist( - context: Context, - task_instance, - batch_prediction_job_id: str, - ): - task_instance.xcom_push( - context=context, - key=VertexAIBatchPredictionJobLink.key, - value={ - "batch_prediction_job_id": batch_prediction_job_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class VertexAIBatchPredictionJobListLink(BaseGoogleLink): """Helper class for constructing Vertex AI BatchPredictionJobList link.""" @@ -271,19 +155,6 @@ class VertexAIBatchPredictionJobListLink(BaseGoogleLink): key = "batch_prediction_jobs_conf" format_str = VERTEX_AI_BATCH_PREDICTION_JOB_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIBatchPredictionJobListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIEndpointLink(BaseGoogleLink): """Helper class for constructing Vertex AI Endpoint link.""" @@ -292,22 +163,6 @@ class VertexAIEndpointLink(BaseGoogleLink): key = "endpoint_conf" format_str = VERTEX_AI_ENDPOINT_LINK - @staticmethod - def persist( - context: Context, - task_instance, - endpoint_id: str, - ): - task_instance.xcom_push( - context=context, - key=VertexAIEndpointLink.key, - value={ - "endpoint_id": endpoint_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class VertexAIEndpointListLink(BaseGoogleLink): """Helper class for constructing Vertex AI EndpointList link.""" @@ -316,19 +171,6 @@ class VertexAIEndpointListLink(BaseGoogleLink): key = "endpoints_conf" format_str = VERTEX_AI_ENDPOINT_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIEndpointListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIPipelineJobLink(BaseGoogleLink): """Helper class for constructing Vertex AI PipelineJob link.""" @@ -337,22 +179,6 @@ class VertexAIPipelineJobLink(BaseGoogleLink): key = "pipeline_job_conf" format_str = VERTEX_AI_PIPELINE_JOB_LINK - @staticmethod - def persist( - context: Context, - task_instance, - pipeline_id: str, - ): - task_instance.xcom_push( - context=context, - 
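
`VertexAIModelExportLink` above shows the shape such overrides now take: a `@classmethod` that pre-processes its raw keyword arguments and then delegates to `super().persist`. A self-contained sketch of that cooperative pattern (these classes are stand-ins, not the provider's):

    from __future__ import annotations

    from typing import Any, ClassVar


    class BaseLinkSketch:
        key: ClassVar[str] = "base"
        store: ClassVar[dict[str, dict[str, Any]]] = {}  # stand-in XCom

        @classmethod
        def persist(cls, context: dict[str, Any], **value: Any) -> None:
            cls.store[cls.key] = value


    class ModelExportLinkSketch(BaseLinkSketch):
        key = "export_conf"

        @staticmethod
        def extract_bucket_name(config: dict[str, Any]) -> str:
            return config["artifact_destination"]["output_uri_prefix"].rpartition("gs://")[-1]

        @classmethod
        def persist(cls, context: dict[str, Any], **value: Any) -> None:
            # flatten the raw output_config into the fields the link template needs,
            # then hand off to the base implementation
            output_config = value.get("output_config")
            super().persist(
                context=context,
                project_id=value.get("project_id"),
                model_id=value.get("model_id"),
                bucket_name=cls.extract_bucket_name(output_config),
            )


    ModelExportLinkSketch.persist(
        context={},
        project_id="my-project",
        model_id="model-1",
        output_config={"artifact_destination": {"output_uri_prefix": "gs://exports/m1"}},
    )
    print(BaseLinkSketch.store["export_conf"])
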
key=VertexAIPipelineJobLink.key, - value={ - "pipeline_id": pipeline_id, - "region": task_instance.region, - "project_id": task_instance.project_id, - }, - ) - class VertexAIPipelineJobListLink(BaseGoogleLink): """Helper class for constructing Vertex AI PipelineJobList link.""" @@ -361,19 +187,6 @@ class VertexAIPipelineJobListLink(BaseGoogleLink): key = "pipeline_job_list_conf" format_str = VERTEX_AI_PIPELINE_JOB_LIST_LINK - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIPipelineJobListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) - class VertexAIRayClusterLink(BaseGoogleLink): """Helper class for constructing Vertex AI Ray Cluster link.""" @@ -382,22 +195,6 @@ class VertexAIRayClusterLink(BaseGoogleLink): key = "ray_cluster_conf" format_str = VERTEX_AI_RAY_CLUSTER_LINK - @staticmethod - def persist( - context: Context, - task_instance, - cluster_id: str, - ): - task_instance.xcom_push( - context=context, - key=VertexAIRayClusterLink.key, - value={ - "location": task_instance.location, - "cluster_id": cluster_id, - "project_id": task_instance.project_id, - }, - ) - class VertexAIRayClusterListLink(BaseGoogleLink): """Helper class for constructing Vertex AI Ray Cluster List link.""" @@ -405,16 +202,3 @@ class VertexAIRayClusterListLink(BaseGoogleLink): name = "Ray Cluster List" key = "ray_cluster_list_conf" format_str = VERTEX_AI_RAY_CLUSTER_LIST_LINK - - @staticmethod - def persist( - context: Context, - task_instance, - ): - task_instance.xcom_push( - context=context, - key=VertexAIRayClusterListLink.key, - value={ - "project_id": task_instance.project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/links/workflows.py b/providers/google/src/airflow/providers/google/cloud/links/workflows.py index f063c7de700e6..c0e3f3a445753 100644 --- a/providers/google/src/airflow/providers/google/cloud/links/workflows.py +++ b/providers/google/src/airflow/providers/google/cloud/links/workflows.py @@ -19,14 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.utils.context import Context - WORKFLOWS_BASE_LINK = "/workflows" WORKFLOW_LINK = WORKFLOWS_BASE_LINK + "/workflow/{location_id}/{workflow_id}/executions?project={project_id}" WORKFLOWS_LINK = WORKFLOWS_BASE_LINK + "?project={project_id}" @@ -43,20 +37,6 @@ class WorkflowsWorkflowDetailsLink(BaseGoogleLink): key = "workflow_details" format_str = WORKFLOW_LINK - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - location_id: str, - workflow_id: str, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=WorkflowsWorkflowDetailsLink.key, - value={"location_id": location_id, "workflow_id": workflow_id, "project_id": project_id}, - ) - class WorkflowsListOfWorkflowsLink(BaseGoogleLink): """Helper class for constructing list of Workflows Link.""" @@ -65,18 +45,6 @@ class WorkflowsListOfWorkflowsLink(BaseGoogleLink): key = "list_of_workflows" format_str = WORKFLOWS_LINK - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=WorkflowsListOfWorkflowsLink.key, - value={"project_id": project_id}, - ) - class WorkflowsExecutionLink(BaseGoogleLink): """Helper class for constructing Workflows Execution 
Link.""" @@ -84,23 +52,3 @@ class WorkflowsExecutionLink(BaseGoogleLink): name = "Workflow Execution" key = "workflow_execution" format_str = EXECUTION_LINK - - @staticmethod - def persist( - context: Context, - task_instance: BaseOperator, - location_id: str, - workflow_id: str, - execution_id: str, - project_id: str | None, - ): - task_instance.xcom_push( - context, - key=WorkflowsExecutionLink.key, - value={ - "location_id": location_id, - "workflow_id": workflow_id, - "execution_id": execution_id, - "project_id": project_id, - }, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py b/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py index 5346076699bec..4c157b5245564 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py @@ -21,7 +21,7 @@ from collections.abc import Sequence from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -228,15 +228,16 @@ def _get_cluster(self) -> proto.Message | None: return result return None - def execute(self, context: Context) -> dict | None: - AlloyDBClusterLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: + AlloyDBClusterLink.persist(context=context) if cluster := self._get_cluster(): return cluster @@ -334,14 +335,16 @@ def __init__( self.update_mask = update_mask self.allow_missing = allow_missing + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBClusterLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + AlloyDBClusterLink.persist(context=context) if self.validate_request: self.log.info("Validating an Update AlloyDB cluster request.") else: @@ -545,14 +548,16 @@ def _get_instance(self) -> proto.Message | None: return result return None + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBClusterLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + AlloyDBClusterLink.persist(context=context) if instance := self._get_instance(): return instance @@ -654,14 +659,16 @@ def __init__( self.update_mask = update_mask self.allow_missing = allow_missing + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBClusterLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + 
AlloyDBClusterLink.persist(context=context) if self.validate_request: self.log.info("Validating an Update AlloyDB instance request.") else: @@ -861,14 +868,16 @@ def _get_user(self) -> proto.Message | None: return result return None + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBUsersLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + AlloyDBUsersLink.persist(context=context) if (_user := self._get_user()) is not None: return _user @@ -968,14 +977,16 @@ def __init__( self.update_mask = update_mask self.allow_missing = allow_missing + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "location_id": self.location, + "cluster_id": self.cluster_id, + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBUsersLink.persist( - context=context, - task_instance=self, - location_id=self.location, - cluster_id=self.cluster_id, - project_id=self.project_id, - ) + AlloyDBUsersLink.persist(context=context) if self.validate_request: self.log.info("Validating an Update AlloyDB user request.") else: @@ -1159,12 +1170,14 @@ def _get_backup(self) -> proto.Message | None: return result return None + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBBackupsLink.persist( - context=context, - task_instance=self, - project_id=self.project_id, - ) + AlloyDBBackupsLink.persist(context=context) if backup := self._get_backup(): return backup @@ -1259,12 +1272,14 @@ def __init__( self.update_mask = update_mask self.allow_missing = allow_missing + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + } + def execute(self, context: Context) -> dict | None: - AlloyDBBackupsLink.persist( - context=context, - task_instance=self, - project_id=self.project_id, - ) + AlloyDBBackupsLink.persist(context=context) if self.validate_request: self.log.info("Validating an Update AlloyDB backup request.") else: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/automl.py b/providers/google/src/airflow/providers/google/cloud/operators/automl.py index cb106537bf7ea..648cee403be6a 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/automl.py @@ -153,7 +153,10 @@ def execute(self, context: Context): project_id = self.project_id or hook.project_id if project_id: TranslationLegacyModelTrainLink.persist( - context=context, task_instance=self, project_id=project_id + context=context, + dataset_id=self.model["dataset_id"], + project_id=project_id, + location=self.location, ) operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) result = Model.to_dict(operation_result) @@ -164,10 +167,10 @@ def execute(self, context: Context): if project_id: TranslationLegacyModelLink.persist( context=context, - task_instance=self, dataset_id=self.model["dataset_id"] or "-", model_id=model_id, project_id=project_id, + location=self.location, ) return result @@ -313,10 +316,10 @@ def execute(self, context: Context): if project_id and self.model_id and dataset_id: 
TranslationLegacyModelPredictLink.persist( context=context, - task_instance=self, model_id=self.model_id, dataset_id=dataset_id, project_id=project_id, + location=self.location, ) return PredictResponse.to_dict(result) @@ -417,9 +420,9 @@ def execute(self, context: Context): if project_id: TranslationLegacyDatasetLink.persist( context=context, - task_instance=self, dataset_id=dataset_id, project_id=project_id, + location=self.location, ) return result @@ -530,9 +533,9 @@ def execute(self, context: Context): if project_id: TranslationLegacyDatasetLink.persist( context=context, - task_instance=self, dataset_id=self.dataset_id, project_id=project_id, + location=self.location, ) @@ -649,9 +652,9 @@ def execute(self, context: Context): if project_id: TranslationLegacyDatasetLink.persist( context=context, - task_instance=self, dataset_id=self.dataset_id, project_id=project_id, + location=self.location, ) return result @@ -749,9 +752,9 @@ def execute(self, context: Context): if project_id: TranslationLegacyDatasetLink.persist( context=context, - task_instance=self, dataset_id=hook.extract_object_id(self.dataset), project_id=project_id, + location=self.location, ) return Dataset.to_dict(result) @@ -845,10 +848,10 @@ def execute(self, context: Context): if project_id: TranslationLegacyModelLink.persist( context=context, - task_instance=self, dataset_id=model["dataset_id"], model_id=self.model_id, project_id=project_id, + location=self.location, ) return model @@ -1154,9 +1157,9 @@ def execute(self, context: Context): if project_id: TranslationLegacyDatasetLink.persist( context=context, - task_instance=self, dataset_id=self.dataset_id, project_id=project_id, + location=self.location, ) return result @@ -1252,7 +1255,7 @@ def execute(self, context: Context): ) project_id = self.project_id or hook.project_id if project_id: - TranslationDatasetListLink.persist(context=context, task_instance=self, project_id=project_id) + TranslationDatasetListLink.persist(context=context, project_id=project_id) return result diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py index f92b1cae6cc84..b245c2d4bf5b5 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py @@ -1324,7 +1324,6 @@ def execute(self, context: Context) -> None: if self._table: persist_kwargs = { "context": context, - "task_instance": self, "project_id": self._table.to_api_repr()["tableReference"]["projectId"], "dataset_id": self._table.to_api_repr()["tableReference"]["datasetId"], "table_id": self._table.to_api_repr()["tableReference"]["tableId"], @@ -1343,7 +1342,6 @@ def execute(self, context: Context) -> None: self.log.info(error_msg) persist_kwargs = { "context": context, - "task_instance": self, "project_id": self.project_id or bq_hook.project_id, "dataset_id": self.dataset_id, "table_id": self.table_id, @@ -1608,7 +1606,6 @@ def execute(self, context: Context) -> None: if self._table: persist_kwargs = { "context": context, - "task_instance": self, "project_id": self._table.to_api_repr()["tableReference"]["projectId"], "dataset_id": self._table.to_api_repr()["tableReference"]["datasetId"], "table_id": self._table.to_api_repr()["tableReference"]["tableId"], @@ -1627,7 +1624,6 @@ def execute(self, context: Context) -> None: self.log.info(error_msg) persist_kwargs = { "context": context, - "task_instance": self, "project_id": 
self.project_id or bq_hook.project_id, "dataset_id": self.dataset_id, "table_id": self.table_id, @@ -1898,7 +1894,6 @@ def execute(self, context: Context) -> None: if self._table: BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=self._table.dataset_id, project_id=self._table.project, table_id=self._table.table_id, @@ -1957,7 +1952,6 @@ def execute(self, context: Context) -> None: if self._table: BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=self._table.dataset_id, project_id=self._table.project, table_id=self._table.table_id, @@ -2155,7 +2149,6 @@ def execute(self, context: Context) -> None: ) persist_kwargs = { "context": context, - "task_instance": self, "project_id": dataset["datasetReference"]["projectId"], "dataset_id": dataset["datasetReference"]["datasetId"], } @@ -2167,7 +2160,6 @@ def execute(self, context: Context) -> None: ) persist_kwargs = { "context": context, - "task_instance": self, "project_id": project_id, "dataset_id": dataset_id, } @@ -2239,7 +2231,6 @@ def execute(self, context: Context): dataset_api_repr = dataset.to_api_repr() BigQueryDatasetLink.persist( context=context, - task_instance=self, dataset_id=dataset_api_repr["datasetReference"]["datasetId"], project_id=dataset_api_repr["datasetReference"]["projectId"], ) @@ -2388,7 +2379,6 @@ def execute(self, context: Context): if self._table: BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=self._table["tableReference"]["datasetId"], project_id=self._table["tableReference"]["projectId"], table_id=self._table["tableReference"]["tableId"], @@ -2491,7 +2481,6 @@ def execute(self, context: Context): dataset_api_repr = dataset.to_api_repr() BigQueryDatasetLink.persist( context=context, - task_instance=self, dataset_id=dataset_api_repr["datasetReference"]["datasetId"], project_id=dataset_api_repr["datasetReference"]["projectId"], ) @@ -2663,7 +2652,6 @@ def execute(self, context: Context) -> None: if self._table: BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=self._table["tableReference"]["datasetId"], project_id=self._table["tableReference"]["projectId"], table_id=self._table["tableReference"]["tableId"], @@ -2793,7 +2781,6 @@ def execute(self, context: Context): if self._table: BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=self._table["tableReference"]["datasetId"], project_id=self._table["tableReference"]["projectId"], table_id=self._table["tableReference"]["tableId"], @@ -3039,7 +3026,6 @@ def execute(self, context: Any): table = job_configuration[job_type][table_prop] persist_kwargs = { "context": context, - "task_instance": self, "project_id": self.project_id, "table_id": table, } @@ -3061,7 +3047,6 @@ def execute(self, context: Any): persist_kwargs = { "context": context, - "task_instance": self, "project_id": self.project_id, "location": self.location, "job_id": self.job_id, diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py index bed14add1911f..387a2ab55f32c 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -134,7 +134,6 @@ def execute(self, context: Context): transfer_config = _get_transfer_config_details(response.name) BigQueryDataTransferConfigLink.persist( context=context, - task_instance=self, 
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
index bed14add1911f..387a2ab55f32c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -134,7 +134,6 @@ def execute(self, context: Context):
             transfer_config = _get_transfer_config_details(response.name)
             BigQueryDataTransferConfigLink.persist(
                 context=context,
-                task_instance=self,
                 region=transfer_config["region"],
                 config_id=transfer_config["config_id"],
                 project_id=transfer_config["project_id"],
@@ -329,7 +328,6 @@ def execute(self, context: Context):
             transfer_config = _get_transfer_config_details(response.runs[0].name)
             BigQueryDataTransferConfigLink.persist(
                 context=context,
-                task_instance=self,
                 region=transfer_config["region"],
                 config_id=transfer_config["config_id"],
                 project_id=transfer_config["project_id"],
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigtable.py b/providers/google/src/airflow/providers/google/cloud/operators/bigtable.py
index d33fa159d7b39..8c202f84b778f 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/bigtable.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/bigtable.py
@@ -20,7 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import google.api_core.exceptions
 
@@ -142,6 +142,13 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> None:
         hook = BigtableHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -155,7 +162,7 @@ def execute(self, context: Context) -> None:
                 "The instance '%s' already exists in this project. Consider it as created",
                 self.instance_id,
             )
-            BigtableInstanceLink.persist(context=context, task_instance=self)
+            BigtableInstanceLink.persist(context=context)
             return
         try:
             hook.create_instance(
@@ -171,7 +178,7 @@ def execute(self, context: Context) -> None:
                 cluster_storage_type=self.cluster_storage_type,
                 timeout=self.timeout,
             )
-            BigtableInstanceLink.persist(context=context, task_instance=self)
+            BigtableInstanceLink.persist(context=context)
         except google.api_core.exceptions.GoogleAPICallError as e:
             self.log.error("An error occurred. Exiting.")
             raise e
@@ -240,6 +247,13 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> None:
         hook = BigtableHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -258,7 +272,7 @@ def execute(self, context: Context) -> None:
                 instance_labels=self.instance_labels,
                 timeout=self.timeout,
             )
-            BigtableInstanceLink.persist(context=context, task_instance=self)
+            BigtableInstanceLink.persist(context=context)
         except google.api_core.exceptions.GoogleAPICallError as e:
             self.log.error("An error occurred. Exiting.")
             raise e
@@ -414,6 +428,13 @@ def _compare_column_families(self, hook, instance) -> bool:
                 return False
         return True
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> None:
         hook = BigtableHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -431,7 +452,7 @@ def execute(self, context: Context) -> None:
                 initial_split_keys=self.initial_split_keys,
                 column_families=self.column_families,
             )
-            BigtableTablesLink.persist(context=context, task_instance=self)
+            BigtableTablesLink.persist(context=context)
         except google.api_core.exceptions.AlreadyExists:
             if not self._compare_column_families(hook, instance):
                 raise AirflowException(
@@ -575,6 +596,14 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> None:
         hook = BigtableHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -586,7 +615,7 @@ def execute(self, context: Context) -> None:
 
         try:
             hook.update_cluster(instance=instance, cluster_id=self.cluster_id, nodes=self.nodes)
-            BigtableClusterLink.persist(context=context, task_instance=self)
+            BigtableClusterLink.persist(context=context)
         except google.api_core.exceptions.NotFound:
             raise AirflowException(
                 f"Dependency: cluster '{self.cluster_id}' does not exist for instance '{self.instance_id}'."
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
index f5c7af50eba2d..fb11a7276fb01 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
@@ -19,6 +19,8 @@
 
 from __future__ import annotations
 
+from typing import Any
+
 from google.api_core.gapic_v1.method import DEFAULT
 
 from airflow.models import BaseOperator
@@ -36,3 +38,21 @@ def __deepcopy__(self, memo):
         """
         memo[id(DEFAULT)] = DEFAULT
         return super().__deepcopy__(memo)
+
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        """
+        Override this property to supply the parameters used to format extra links.
+
+        For example, most of the links in the Google provider require ``project_id`` and
+        ``location``. To avoid repeating them at every persist call, override this property
+        and return something like the following:
+
+        .. code-block:: python
+
+            {
+                "project_id": self.project_id,
+                "location": self.location,
+            }
+
+        """
+        return {}
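A toy operator showing the override pattern this docstring describes; `MyGcpOperator` and its fields are invented for illustration:

```python
# Minimal sketch of an operator that contributes link parameters once,
# instead of repeating them at every persist() call site.
from typing import Any

from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator


class MyGcpOperator(GoogleCloudBaseOperator):
    def __init__(self, project_id: str, location: str, **kwargs):
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location

    @property
    def extra_links_params(self) -> dict[str, Any]:
        # Every link persisted by this operator receives these two values.
        return {"project_id": self.project_id, "location": self.location}
```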
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
index 7ae38b7f236f3..78d4bc2d6065c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
@@ -108,6 +108,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.cancel_build(
@@ -124,9 +130,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=result.id,
             )
         return Build.to_dict(result)
@@ -210,6 +214,12 @@ def prepare_template(self) -> None:
             if self.build_raw.endswith(".json"):
                 self.build = json.loads(file.read())
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -251,9 +261,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=cloud_build_instance_result.id,
             )
         return Build.to_dict(cloud_build_instance_result)
@@ -269,9 +277,7 @@ def execute_complete(self, context: Context, event: dict):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=event["id_"],
             )
         return event["instance"]
@@ -336,6 +342,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.create_build_trigger(
@@ -351,16 +363,12 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 trigger_id=result.id,
             )
             CloudBuildTriggersListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
             )
         return BuildTrigger.to_dict(result)
@@ -419,6 +427,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         hook.delete_build_trigger(
@@ -433,9 +447,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildTriggersListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
             )
 
@@ -493,6 +505,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.get_build(
@@ -507,9 +525,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=result.id,
             )
         return Build.to_dict(result)
@@ -569,6 +585,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.get_build_trigger(
@@ -583,9 +605,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 trigger_id=result.id,
             )
         return BuildTrigger.to_dict(result)
@@ -649,6 +669,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         results = hook.list_build_triggers(
@@ -664,9 +690,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildTriggersListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
             )
         return [BuildTrigger.to_dict(result) for result in results]
@@ -729,6 +753,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         results = hook.list_builds(
@@ -743,7 +773,8 @@ def execute(self, context: Context):
         project_id = self.project_id or hook.project_id
         if project_id:
             CloudBuildListLink.persist(
-                context=context, task_instance=self, project_id=project_id, region=self.location
+                context=context,
+                project_id=project_id,
             )
         return [Build.to_dict(result) for result in results]
@@ -805,6 +836,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.retry_build(
@@ -822,9 +859,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=result.id,
             )
         return Build.to_dict(result)
@@ -891,6 +926,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.run_build_trigger(
@@ -908,9 +949,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 build_id=result.id,
             )
         return Build.to_dict(result)
@@ -974,6 +1013,12 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.location = location
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudBuildHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         result = hook.update_build_trigger(
@@ -990,9 +1035,7 @@ def execute(self, context: Context):
         if project_id:
             CloudBuildTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
-                region=self.location,
                 trigger_id=result.id,
             )
         return BuildTrigger.to_dict(result)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py
index 77d411cc08e90..59704b3b4cb2e 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py
@@ -19,7 +19,7 @@
 
 import shlex
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from google.api_core.exceptions import AlreadyExists
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -57,25 +57,6 @@ class CloudComposerEnvironmentLink(BaseGoogleLink):
     key = "composer_conf"
     format_str = CLOUD_COMPOSER_DETAILS_LINK
 
-    @staticmethod
-    def persist(
-        operator_instance: (
-            CloudComposerCreateEnvironmentOperator
-            | CloudComposerUpdateEnvironmentOperator
-            | CloudComposerGetEnvironmentOperator
-        ),
-        context: Context,
-    ) -> None:
-        operator_instance.xcom_push(
-            context,
-            key=CloudComposerEnvironmentLink.key,
-            value={
-                "project_id": operator_instance.project_id,
-                "region": operator_instance.region,
-                "environment_id": operator_instance.environment_id,
-            },
-        )
-
 
 class CloudComposerEnvironmentsLink(BaseGoogleLink):
     """Helper class for constructing Cloud Composer Environment Link."""
@@ -84,16 +65,6 @@ class CloudComposerEnvironmentsLink(BaseGoogleLink):
     key = "composer_conf"
     format_str = CLOUD_COMPOSER_ENVIRONMENTS_LINK
 
-    @staticmethod
-    def persist(operator_instance: CloudComposerListEnvironmentsOperator, context: Context) -> None:
-        operator_instance.xcom_push(
-            context,
-            key=CloudComposerEnvironmentsLink.key,
-            value={
-                "project_id": operator_instance.project_id,
-            },
-        )
-
 
 class CloudComposerCreateEnvironmentOperator(GoogleCloudBaseOperator):
     """
@@ -159,6 +130,14 @@ def __init__(
         self.deferrable = deferrable
         self.pooling_period_seconds = pooling_period_seconds
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "region": self.region,
+            "environment_id": self.environment_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudComposerHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -171,7 +150,7 @@ def execute(self, context: Context):
         else:
             self.environment["name"] = name
 
-        CloudComposerEnvironmentLink.persist(operator_instance=self, context=context)
+        CloudComposerEnvironmentLink.persist(context=context)
         try:
             result = hook.create_environment(
                 project_id=self.project_id,
@@ -370,6 +349,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "region": self.region,
+            "environment_id": self.environment_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudComposerHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -384,8 +371,7 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-
-        CloudComposerEnvironmentLink.persist(operator_instance=self, context=context)
+        CloudComposerEnvironmentLink.persist(context=context)
 
         return Environment.to_dict(result)
@@ -445,12 +431,17 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudComposerHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        CloudComposerEnvironmentsLink.persist(operator_instance=self, context=context)
         result = hook.list_environments(
             project_id=self.project_id,
             region=self.region,
@@ -532,6 +523,14 @@ def __init__(
         self.deferrable = deferrable
         self.pooling_period_seconds = pooling_period_seconds
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "region": self.region,
+            "environment_id": self.environment_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudComposerHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -549,7 +548,7 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
 
-        CloudComposerEnvironmentLink.persist(operator_instance=self, context=context)
+        CloudComposerEnvironmentLink.persist(context=context)
         if not self.deferrable:
             environment = hook.wait_for_operation(timeout=self.timeout, operation=result)
             return Environment.to_dict(environment)
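With the static per-link `persist` helpers deleted, the generic link persistence is expected to read these values from the operator. A hypothetical pytest-style check of that behaviour (not from the provider's test suite; the merge logic below is a stand-in for the real method):

```python
# Sketch: verify that a context-only persist ends up storing the operator's
# extra_links_params under the Composer link key ("composer_conf" assumed
# from the diff above).
from unittest import mock


def test_composer_link_payload_sketch():
    operator = mock.MagicMock()
    operator.extra_links_params = {
        "project_id": "my-project",
        "region": "us-central1",
        "environment_id": "my-env",
    }
    context = {"task": operator, "ti": mock.MagicMock()}

    # Stand-in for BaseGoogleLink.persist(context=context).
    payload = {**context["task"].extra_links_params}
    context["ti"].xcom_push(key="composer_conf", value=payload)

    context["ti"].xcom_push.assert_called_once_with(
        key="composer_conf",
        value={
            "project_id": "my-project",
            "region": "us-central1",
            "environment_id": "my-env",
        },
    )
```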
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_memorystore.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_memorystore.py
index 5138ed1cf5143..4b0e5d44d0b95 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_memorystore.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_memorystore.py
@@ -27,7 +27,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.memcache_v1beta2.types import cloud_memcache
@@ -133,6 +133,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -148,9 +155,6 @@ def execute(self, context: Context):
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance_id,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
         return Instance.to_dict(result)
@@ -304,6 +308,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context) -> None:
         hook = CloudMemorystoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -320,9 +331,6 @@ def execute(self, context: Context) -> None:
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
@@ -397,6 +405,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context) -> None:
         hook = CloudMemorystoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -412,9 +427,6 @@ def execute(self, context: Context) -> None:
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
@@ -482,6 +494,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -496,9 +515,6 @@ def execute(self, context: Context):
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
         return Instance.to_dict(result)
@@ -577,6 +593,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context) -> None:
         hook = CloudMemorystoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -592,9 +615,6 @@ def execute(self, context: Context) -> None:
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
@@ -680,7 +700,6 @@ def execute(self, context: Context):
         )
         RedisInstanceListLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
         )
         instances = [Instance.to_dict(a) for a in result]
@@ -789,7 +808,6 @@ def execute(self, context: Context) -> None:
         location_id, instance_id = res.name.split("/")[-3::2]
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id or instance_id,
             location_id=self.location or location_id,
             project_id=self.project_id or hook.project_id,
@@ -882,7 +900,6 @@ def execute(self, context: Context) -> None:
         location_id, instance_id = res.name.split("/")[-3::2]
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id or instance_id,
             location_id=self.location or location_id,
             project_id=self.project_id or hook.project_id,
@@ -1002,7 +1019,6 @@ def execute(self, context: Context) -> None:
         )
         RedisInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -1171,6 +1187,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "location_id": self.location,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreMemcachedHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -1185,13 +1209,7 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        MemcachedInstanceDetailsLink.persist(
-            context=context,
-            task_instance=self,
-            instance_id=self.instance_id,
-            location_id=self.location,
-            project_id=self.project_id,
-        )
+        MemcachedInstanceDetailsLink.persist(context=context)
 
 
 class CloudMemorystoreMemcachedCreateInstanceOperator(GoogleCloudBaseOperator):
@@ -1263,6 +1281,13 @@ def __init__(
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreMemcachedHook(gcp_conn_id=self.gcp_conn_id)
         result = hook.create_instance(
@@ -1276,9 +1301,6 @@ def execute(self, context: Context):
         )
         MemcachedInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance_id,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
         return cloud_memcache.Instance.to_dict(result)
@@ -1410,6 +1432,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance,
+            "location_id": self.location,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreMemcachedHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -1424,9 +1453,6 @@ def execute(self, context: Context):
         )
         MemcachedInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            instance_id=self.instance,
-            location_id=self.location,
             project_id=self.project_id or hook.project_id,
         )
         return cloud_memcache.Instance.to_dict(result)
@@ -1506,7 +1532,6 @@ def execute(self, context: Context):
         )
         MemcachedInstanceListLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
         )
         instances = [cloud_memcache.Instance.to_dict(a) for a in result]
@@ -1612,7 +1637,6 @@ def execute(self, context: Context):
         location_id, instance_id = res.name.split("/")[-3::2]
         MemcachedInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id or instance_id,
             location_id=self.location or location_id,
             project_id=self.project_id or hook.project_id,
@@ -1688,6 +1712,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance_id": self.instance_id,
+            "location_id": self.location,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudMemorystoreMemcachedHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -1702,10 +1734,4 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        MemcachedInstanceDetailsLink.persist(
-            context=context,
-            task_instance=self,
-            instance_id=self.instance_id,
-            location_id=self.location,
-            project_id=self.project_id,
-        )
+        MemcachedInstanceDetailsLink.persist(context=context)
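Note the mapping in the Memorystore properties above: the operators store the ID as `self.instance` (or `self.instance_id`), while the link template expects `instance_id`, so the property does the rename once. A toy reduction of that pattern (class name invented):

```python
# The property adapts operator attribute names to the link's parameter names.
from typing import Any


class _MemorystoreOperatorSketch:
    def __init__(self, instance: str, location: str):
        self.instance = instance
        self.location = location

    @property
    def extra_links_params(self) -> dict[str, Any]:
        return {"instance_id": self.instance, "location_id": self.location}


params = _MemorystoreOperatorSketch("redis-1", "us-central1").extra_links_params
assert params == {"instance_id": "redis-1", "location_id": "us-central1"}
```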
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
index e74724107c5f8..b0da1debc3667 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py
@@ -317,7 +317,6 @@ def execute(self, context: Context):
         if self.operation.metadata.log_uri:
             CloudRunJobLoggingLink.persist(
                 context=context,
-                task_instance=self,
                 log_uri=self.operation.metadata.log_uri,
             )
 
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
index b1ad44b121d22..142649603c5f2 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -285,6 +285,12 @@ def _check_if_db_exists(self, db_name, hook: CloudSQLHook) -> dict | bool:
                 return False
             raise e
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "instance": self.instance,
+        }
+
     def execute(self, context: Context):
         pass
 
@@ -384,8 +390,6 @@ def execute(self, context: Context) -> None:
 
         CloudSQLInstanceLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
 
@@ -479,8 +483,6 @@ def execute(self, context: Context):
         )
         CloudSQLInstanceLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
@@ -714,8 +716,6 @@ def execute(self, context: Context) -> bool | None:
         )
         CloudSQLInstanceDatabaseLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
         if self._check_if_db_exists(database, hook):
@@ -822,8 +822,6 @@ def execute(self, context: Context) -> None:
         )
         CloudSQLInstanceDatabaseLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
         return hook.patch_database(
@@ -1004,13 +1002,10 @@ def execute(self, context: Context) -> None:
         )
         CloudSQLInstanceLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
         FileDetailsLink.persist(
             context=context,
-            task_instance=self,
             uri=self.body["exportContext"]["uri"][5:],
             project_id=self.project_id or hook.project_id,
         )
@@ -1147,13 +1142,10 @@ def execute(self, context: Context) -> None:
         )
         CloudSQLInstanceLink.persist(
             context=context,
-            task_instance=self,
-            cloud_sql_instance=self.instance,
             project_id=self.project_id or hook.project_id,
         )
         FileDetailsLink.persist(
             context=context,
-            task_instance=self,
             uri=self.body["importContext"]["uri"][5:],
             project_id=self.project_id or hook.project_id,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index a2321895b492b..93524a33b407a 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -586,7 +586,6 @@ def execute(self, context: Context) -> dict:
         if project_id:
             CloudStorageTransferDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 operation_name=self.operation_name,
             )
@@ -663,7 +662,6 @@ def execute(self, context: Context) -> list[dict]:
         if project_id:
             CloudStorageTransferListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
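The Cloud SQL changes mix both styles: `instance` comes from `extra_links_params`, while `project_id`, which is only resolved at `execute()` time (it may fall back to the hook's project), is still passed explicitly. A sketch of the merge order this implies, assuming call-time keywords override operator-level defaults:

```python
# Illustrative merge: operator-level defaults first, call-time values win.
from typing import Any


def merged_link_params(operator_params: dict[str, Any], **call_time: Any) -> dict[str, Any]:
    return {**operator_params, **call_time}


params = merged_link_params(
    {"instance": "my-sql-instance"},        # from extra_links_params
    project_id="project-from-hook",          # resolved inside execute()
)
assert params == {"instance": "my-sql-instance", "project_id": "project-from-hook"}
```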
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/compute.py b/providers/google/src/airflow/providers/google/cloud/operators/compute.py
index 411e8b666d6e6..2ba8874cc3773 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/compute.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/compute.py
@@ -74,6 +74,13 @@ def _validate_inputs(self) -> None:
         if not self.zone:
             raise AirflowException("The required parameter 'zone' is missing")
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location_id": self.zone,
+            "resource_id": self.resource_id,
+        }
+
     def execute(self, context: Context):
         pass
 
@@ -225,9 +232,6 @@ def execute(self, context: Context) -> dict:
             self.log.info("The %s Instance already exists", self.resource_id)
             ComputeInstanceDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                location_id=self.zone,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
             )
             return Instance.to_dict(existing_instance)
@@ -247,9 +251,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         return Instance.to_dict(new_instance)
@@ -397,9 +398,6 @@ def execute(self, context: Context) -> dict:
             self.log.info("The %s Instance already exists", self.resource_id)
             ComputeInstanceDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                location_id=self.zone,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
             )
             return Instance.to_dict(existing_instance)
@@ -420,9 +418,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         return Instance.to_dict(new_instance_from_template)
@@ -598,9 +593,6 @@ def execute(self, context: Context) -> None:
         )
         ComputeInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         hook.start_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id)
@@ -659,9 +651,6 @@ def execute(self, context: Context) -> None:
         )
         ComputeInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         hook.stop_instance(zone=self.zone, resource_id=self.resource_id, project_id=self.project_id)
@@ -764,9 +753,6 @@ def execute(self, context: Context) -> None:
         self._validate_all_body_fields()
         ComputeInstanceDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         hook.set_machine_type(
@@ -972,8 +958,6 @@ def execute(self, context: Context) -> dict:
             self.log.info("The %s Template already exists.", existing_template)
             ComputeInstanceTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
             )
             return InstanceTemplate.to_dict(existing_template)
@@ -991,8 +975,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceTemplateDetailsLink.persist(
             context=context,
-            task_instance=self,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         return InstanceTemplate.to_dict(new_template)
@@ -1238,7 +1220,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceTemplateDetailsLink.persist(
             context=context,
-            task_instance=self,
             resource_id=self.body_patch["name"],
             project_id=self.project_id or hook.project_id,
         )
@@ -1259,7 +1240,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceTemplateDetailsLink.persist(
             context=context,
-            task_instance=self,
             resource_id=self.body_patch["name"],
             project_id=self.project_id or hook.project_id,
         )
@@ -1390,9 +1370,6 @@ def execute(self, context: Context) -> bool | None:
             self.log.info("Calling patch instance template with updated body: %s", patch_body)
             ComputeInstanceGroupManagerDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                location_id=self.zone,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
             )
             return hook.patch_instance_group_manager(
@@ -1405,9 +1382,6 @@ def execute(self, context: Context) -> bool | None:
             # Idempotence achieved
             ComputeInstanceGroupManagerDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                location_id=self.zone,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
             )
             return True
@@ -1552,10 +1526,7 @@ def execute(self, context: Context) -> dict:
             self.log.info("The %s Instance Group Manager already exists", existing_instance_group_manager)
             ComputeInstanceGroupManagerDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                resource_id=self.resource_id,
                 project_id=self.project_id or hook.project_id,
-                location_id=self.zone,
             )
             return InstanceGroupManager.to_dict(existing_instance_group_manager)
         self._field_sanitizer.sanitize(self.body)
@@ -1574,9 +1545,6 @@ def execute(self, context: Context) -> dict:
         )
         ComputeInstanceGroupManagerDetailsLink.persist(
             context=context,
-            task_instance=self,
-            location_id=self.zone,
-            resource_id=self.resource_id,
             project_id=self.project_id or hook.project_id,
         )
         return InstanceGroupManager.to_dict(new_instance_group_manager)
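In `compute.py` the base operator defines the property once for all subclasses, while the template-patch operator still passes `resource_id=self.body_patch["name"]` explicitly. A toy version of that call-time override, with invented class names:

```python
# Base class contributes defaults; a subclass overrides one key at the call.
from typing import Any


class ComputeBaseSketch:
    zone = "us-central1-a"
    resource_id = "instance-1"

    @property
    def extra_links_params(self) -> dict[str, Any]:
        return {"location_id": self.zone, "resource_id": self.resource_id}


class TemplatePatchSketch(ComputeBaseSketch):
    body_patch = {"name": "template-v2"}

    def link_params(self) -> dict[str, Any]:
        # The explicit resource_id wins over the inherited default.
        return {**self.extra_links_params, "resource_id": self.body_patch["name"]}


assert TemplatePatchSketch().link_params()["resource_id"] == "template-v2"
```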
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
index 08a6654d44ba1..48b444f8b13b7 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
@@ -166,7 +166,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="entry_id", value=entry_id)
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry_id,
             entry_group_id=self.entry_group,
             location_id=self.location,
@@ -287,7 +286,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="entry_group_id", value=entry_group_id)
         DataCatalogEntryGroupLink.persist(
             context=context,
-            task_instance=self,
             entry_group_id=self.entry_group_id,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -430,7 +428,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="tag_id", value=tag_id)
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry,
             entry_group_id=self.entry_group,
             location_id=self.location,
@@ -548,7 +545,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="tag_template_id", value=tag_template)
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template_id,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -675,7 +671,6 @@ def execute(self, context: Context):
         self.xcom_push(context, key="tag_template_field_id", value=self.tag_template_field_id)
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -1242,7 +1237,6 @@ def execute(self, context: Context) -> dict:
         )
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry,
             entry_group_id=self.entry_group,
             location_id=self.location,
@@ -1344,7 +1338,6 @@ def execute(self, context: Context) -> dict:
         )
         DataCatalogEntryGroupLink.persist(
             context=context,
-            task_instance=self,
             entry_group_id=self.entry_group,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -1437,7 +1430,6 @@ def execute(self, context: Context) -> dict:
         )
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -1543,7 +1535,6 @@ def execute(self, context: Context) -> list:
         )
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry,
             entry_group_id=self.entry_group,
             location_id=self.location,
@@ -1641,7 +1632,6 @@ def execute(self, context: Context) -> dict:
         project_id, location_id, entry_group_id, entry_id = result.name.split("/")[1::2]
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=entry_id,
             entry_group_id=entry_group_id,
             location_id=location_id,
@@ -1747,7 +1737,6 @@ def execute(self, context: Context) -> None:
         )
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template,
             location_id=self.location,
             project_id=self.project_id or hook.project_id,
@@ -1980,7 +1969,6 @@ def execute(self, context: Context) -> None:
         location_id, entry_group_id, entry_id = result.name.split("/")[3::2]
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry_id or entry_id,
             entry_group_id=self.entry_group or entry_group_id,
             location_id=self.location or location_id,
@@ -2101,7 +2089,6 @@ def execute(self, context: Context) -> None:
         location_id, entry_group_id, entry_id = result.name.split("/")[3:8:2]
         DataCatalogEntryLink.persist(
             context=context,
-            task_instance=self,
             entry_id=self.entry or entry_id,
             entry_group_id=self.entry_group or entry_group_id,
             location_id=self.location or location_id,
@@ -2218,7 +2205,6 @@ def execute(self, context: Context) -> None:
         location_id, tag_template_id = result.name.split("/")[3::2]
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template_id or tag_template_id,
             location_id=self.location or location_id,
             project_id=self.project_id or hook.project_id,
@@ -2346,7 +2332,6 @@ def execute(self, context: Context) -> None:
         location_id, tag_template_id = result.name.split("/")[3:6:2]
         DataCatalogTagTemplateLink.persist(
             context=context,
-            task_instance=self,
             tag_template_id=self.tag_template or tag_template_id,
             location_id=self.location or location_id,
             project_id=self.project_id or hook.project_id,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
index c881853374ead..c61a19b6a3495 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
@@ -383,7 +383,12 @@ def hook(self) -> DataflowHook:
     def execute(self, context: Context):
         def set_current_job(current_job):
             self.job = current_job
-            DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id"))
+            DataflowJobLink.persist(
+                context=context,
+                project_id=self.project_id,
+                region=self.location,
+                job_id=self.job.get("id"),
+            )
 
         options = self.dataflow_default_options
         options.update(self.options)
@@ -418,7 +423,9 @@ def set_current_job(current_job):
                 environment=self.environment,
             )
             job_id = self.hook.extract_job_id(self.job)
-            DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
+            DataflowJobLink.persist(
+                context=context, project_id=self.project_id, region=self.location, job_id=job_id
+            )
             self.defer(
                 trigger=TemplateJobStartTrigger(
                     project_id=self.project_id,
@@ -590,7 +597,9 @@ def execute(self, context: Context):
 
         def set_current_job(current_job):
             self.job = current_job
-            DataflowJobLink.persist(self, context, self.project_id, self.location, self.job.get("id"))
+            DataflowJobLink.persist(
+                context=context, project_id=self.project_id, region=self.location, job_id=self.job.get("id")
+            )
 
         if not self.deferrable:
             self.job = self.hook.start_flex_template(
@@ -609,7 +618,9 @@ def set_current_job(current_job):
                 project_id=self.project_id,
             )
             job_id = self.hook.extract_job_id(self.job)
-            DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
+            DataflowJobLink.persist(
+                context=context, project_id=self.project_id, region=self.location, job_id=job_id
+            )
             self.defer(
                 trigger=TemplateJobStartTrigger(
                     project_id=self.project_id,
@@ -764,7 +775,9 @@ def execute(self, context: Context) -> dict[str, Any]:
             location=self.region,
         )
 
-        DataflowJobLink.persist(self, context, self.project_id, self.region, self.job_id)
+        DataflowJobLink.persist(
+            context=context, project_id=self.project_id, region=self.region, job_id=self.job_id
+        )
 
         if self.deferrable:
             self.defer(
@@ -971,6 +984,14 @@ def __init__(
 
         self.pipeline_name = self.body["name"].split("/")[-1] if self.body else None
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "location": self.location,
+            "pipeline_name": self.pipeline_name,
+        }
+
     def execute(self, context: Context):
         if self.body is None:
             raise AirflowException(
@@ -1003,7 +1024,7 @@ def execute(self, context: Context):
             pipeline_name=self.pipeline_name,
             location=self.location,
         )
-        DataflowPipelineLink.persist(self, context, self.project_id, self.location, self.pipeline_name)
+        DataflowPipelineLink.persist(context=context)
         self.xcom_push(context, key="pipeline_name", value=self.pipeline_name)
         if self.pipeline:
             if "error" in self.pipeline:
@@ -1076,7 +1097,9 @@ def execute(self, context: Context):
             )["job"]
             job_id = self.dataflow_hook.extract_job_id(self.job)
             self.xcom_push(context, key="job_id", value=job_id)
-            DataflowJobLink.persist(self, context, self.project_id, self.location, job_id)
+            DataflowJobLink.persist(
+                context=context, project_id=self.project_id, region=self.location, job_id=job_id
+            )
         except HttpError as e:
             if e.resp.status == 404:
                 raise AirflowException("Pipeline with given name was not found.")
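The Dataflow call sites switch from positional arguments to keywords, and `location` is now passed as `region=`. A keyword-only signature, sketched below with an assumed shape, turns the old positional form into an immediate `TypeError` instead of a silent argument mix-up:

```python
# Hypothetical keyword-only persist; the real method's signature may differ.
from typing import Any


def persist(*, context: dict[str, Any], project_id: str, region: str, job_id: str) -> None:
    # Record the link parameters somewhere inspectable for this sketch.
    context.setdefault("links", []).append(
        {"project_id": project_id, "region": region, "job_id": job_id}
    )


ctx: dict[str, Any] = {}
persist(context=ctx, project_id="my-project", region="us-central1", job_id="job-123")
assert ctx["links"][0]["region"] == "us-central1"
# persist(ctx, "my-project", "us-central1", "job-123")  # would raise TypeError
```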
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataform.py b/providers/google/src/airflow/providers/google/cloud/operators/dataform.py
index 4212cd22a12ec..df21e1b72dd41 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataform.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataform.py
@@ -258,7 +258,6 @@ def execute(self, context: Context):
         )
         workflow_invocation_id = result.name.split("/")[-1]
         DataformWorkflowInvocationLink.persist(
-            operator_instance=self,
             context=context,
             project_id=self.project_id,
             region=self.region,
@@ -347,6 +346,13 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
+        DataformWorkflowInvocationLink.persist(
+            context=context,
+            project_id=self.project_id,
+            region=self.region,
+            repository_id=self.repository_id,
+            workflow_invocation_id=self.workflow_invocation_id,
+        )
 
         return WorkflowInvocation.to_dict(result)
 
@@ -412,7 +418,6 @@ def execute(self, context: Context):
             impersonation_chain=self.impersonation_chain,
         )
         DataformWorkflowInvocationLink.persist(
-            operator_instance=self,
             context=context,
             project_id=self.project_id,
             region=self.region,
@@ -494,6 +499,13 @@ def execute(self, context: Context):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+        DataformWorkflowInvocationLink.persist(
+            context=context,
+            project_id=self.project_id,
+            region=self.region,
+            repository_id=self.repository_id,
+            workflow_invocation_id=self.workflow_invocation_id,
+        )
         hook.cancel_workflow_invocation(
             project_id=self.project_id,
             region=self.region,
@@ -576,7 +588,6 @@ def execute(self, context: Context) -> dict:
         )
 
         DataformRepositoryLink.persist(
-            operator_instance=self,
             context=context,
             project_id=self.project_id,
             region=self.region,
@@ -735,7 +746,6 @@ def execute(self, context: Context) -> dict:
         )
 
         DataformWorkspaceLink.persist(
-            operator_instance=self,
             context=context,
             project_id=self.project_id,
             region=self.region,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/datafusion.py b/providers/google/src/airflow/providers/google/cloud/operators/datafusion.py
index 6c6d42fb5cde5..572faa35facf3 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/datafusion.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/datafusion.py
@@ -111,10 +111,9 @@ def execute(self, context: Context) -> None:
         project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
         DataFusionInstanceLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
             instance_name=self.instance_name,
-            location=self.location,
+            region=self.location,
         )
 
@@ -269,10 +268,9 @@ def execute(self, context: Context) -> dict:
         project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
         DataFusionInstanceLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
             instance_name=self.instance_name,
-            location=self.location,
+            region=self.location,
         )
         return instance
@@ -358,10 +356,9 @@ def execute(self, context: Context) -> None:
         project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
         DataFusionInstanceLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
             instance_name=self.instance_name,
-            location=self.location,
+            region=self.location,
         )
 
@@ -429,10 +426,9 @@ def execute(self, context: Context) -> dict:
         project_id = resource_path_to_dict(resource_name=instance["name"])["projects"]
         DataFusionInstanceLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
             instance_name=self.instance_name,
-            location=self.location,
+            region=self.location,
         )
         return instance
@@ -519,7 +515,6 @@ def execute(self, context: Context) -> None:
         )
         DataFusionPipelineLink.persist(
             context=context,
-            task_instance=self,
             uri=instance["serviceEndpoint"],
             pipeline_name=self.pipeline_name,
             namespace=self.namespace,
@@ -693,7 +688,6 @@ def execute(self, context: Context) -> dict:
 
         DataFusionPipelinesLink.persist(
             context=context,
-            task_instance=self,
             uri=service_endpoint,
             namespace=self.namespace,
         )
@@ -813,7 +807,6 @@ def execute(self, context: Context) -> str:
 
         DataFusionPipelineLink.persist(
             context=context,
-            task_instance=self,
             uri=instance["serviceEndpoint"],
             pipeline_name=self.pipeline_name,
             namespace=self.namespace,
@@ -943,7 +936,6 @@ def execute(self, context: Context) -> None:
 
         DataFusionPipelineLink.persist(
             context=context,
-            task_instance=self,
             uri=instance["serviceEndpoint"],
             pipeline_name=self.pipeline_name,
             namespace=self.namespace,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
index 2c2f7833c9762..ad6a7217e2cc9 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
@@ -150,6 +150,15 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.asynchronous = asynchronous
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "lake_id": self.lake_id,
+            "task_id": self.dataplex_task_id,
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -157,7 +166,7 @@ def execute(self, context: Context) -> dict:
             impersonation_chain=self.impersonation_chain,
         )
         self.log.info("Creating Dataplex task %s", self.dataplex_task_id)
-        DataplexTaskLink.persist(context=context, task_instance=self)
+        DataplexTaskLink.persist(context=context)
 
         try:
             operation = hook.create_task(
@@ -351,6 +360,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "lake_id": self.lake_id,
+            "region": self.region,
+        }
+
     def execute(self, context: Context) -> list[dict]:
         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -358,7 +375,7 @@ def execute(self, context: Context) -> list[dict]:
             impersonation_chain=self.impersonation_chain,
         )
         self.log.info("Listing Dataplex tasks from lake %s", self.lake_id)
-        DataplexTasksLink.persist(context=context, task_instance=self)
+        DataplexTasksLink.persist(context=context)
 
         tasks = hook.list_tasks(
             project_id=self.project_id,
@@ -430,6 +447,15 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "lake_id": self.lake_id,
+            "task_id": self.dataplex_task_id,
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -437,7 +463,7 @@ def execute(self, context: Context) -> dict:
             impersonation_chain=self.impersonation_chain,
         )
         self.log.info("Retrieving Dataplex task %s", self.dataplex_task_id)
-        DataplexTaskLink.persist(context=context, task_instance=self)
+        DataplexTaskLink.persist(context=context)
 
         task = hook.get_task(
             project_id=self.project_id,
@@ -448,7 +474,7 @@ def execute(self, context: Context) -> dict:
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataplexTasksLink.persist(context=context, task_instance=self)
+        DataplexTasksLink.persist(context=context)
         return Task.to_dict(task)
@@ -522,6 +548,14 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.asynchronous = asynchronous
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "lake_id": self.lake_id,
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -566,10 +600,7 @@ def execute(self, context: Context) -> dict:
                 if lake["state"] != "CREATING":
                     break
                 time.sleep(time_to_wait)
-        DataplexLakeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexLakeLink.persist(context=context)
         return Lake.to_dict(lake)
@@ -625,6 +656,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "lake_id": self.lake_id,
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> None:
         hook = DataplexHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -642,7 +681,7 @@ def execute(self, context: Context) -> None:
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataplexLakeLink.persist(context=context, task_instance=self)
+        DataplexLakeLink.persist(context=context)
         hook.wait_for_operation(timeout=self.timeout, operation=operation)
         self.log.info("Dataplex lake %s deleted successfully!", self.lake_id)
@@ -2179,6 +2218,13 @@ def hook(self) -> DataplexHook:
             impersonation_chain=self.impersonation_chain,
         )
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "project_id": self.project_id,
+        }
+
 
 class DataplexCatalogCreateEntryGroupOperator(DataplexCatalogBaseOperator):
     """
@@ -2230,12 +2276,15 @@ def __init__(
         self.entry_group_configuration = entry_group_configuration
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryGroupLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_group_id": self.entry_group_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryGroupLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating a Create Dataplex Catalog EntryGroup request.")
         else:
@@ -2316,11 +2365,15 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.entry_group_id = entry_group_id
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_group_id": self.entry_group_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogEntryGroupLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryGroupLink.persist(context=context)
         self.log.info(
             "Retrieving Dataplex Catalog EntryGroup %s.",
             self.entry_group_id,
@@ -2462,10 +2515,7 @@ def __init__(
         self.order_by = order_by
 
     def execute(self, context: Context):
-        DataplexCatalogEntryGroupsLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryGroupsLink.persist(context=context)
         self.log.info(
             "Listing Dataplex Catalog EntryGroup from location %s.",
             self.location,
@@ -2554,12 +2604,15 @@ def __init__(
         self.update_mask = update_mask
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryGroupLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_group_id": self.entry_group_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryGroupLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating an Update Dataplex Catalog EntryGroup request.")
         else:
@@ -2644,12 +2697,15 @@ def __init__(
         self.entry_type_configuration = entry_type_configuration
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_type_id": self.entry_type_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryTypeLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating a Create Dataplex Catalog EntryType request.")
         else:
@@ -2730,11 +2786,15 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.entry_type_id = entry_type_id
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_type_id": self.entry_type_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogEntryTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryTypeLink.persist(context=context)
         self.log.info(
             "Retrieving Dataplex Catalog EntryType %s.",
             self.entry_type_id,
@@ -2876,10 +2936,7 @@ def __init__(
         self.order_by = order_by
 
     def execute(self, context: Context):
-        DataplexCatalogEntryTypesLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryTypesLink.persist(context=context)
         self.log.info(
             "Listing Dataplex Catalog EntryType from location %s.",
             self.location,
@@ -2968,12 +3025,15 @@ def __init__(
         self.update_mask = update_mask
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_type_id": self.entry_type_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryTypeLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating an Update Dataplex Catalog EntryType request.")
         else:
@@ -3058,12 +3118,15 @@ def __init__(
         self.aspect_type_configuration = aspect_type_configuration
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogAspectTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "aspect_type_id": self.aspect_type_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogAspectTypeLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating a Create Dataplex Catalog AspectType request.")
         else:
@@ -3144,11 +3207,15 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.aspect_type_id = aspect_type_id
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "aspect_type_id": self.aspect_type_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogAspectTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogAspectTypeLink.persist(context=context)
         self.log.info(
             "Retrieving Dataplex Catalog AspectType %s.",
             self.aspect_type_id,
@@ -3223,10 +3290,7 @@ def __init__(
         self.order_by = order_by
 
     def execute(self, context: Context):
-        DataplexCatalogAspectTypesLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogAspectTypesLink.persist(context=context)
         self.log.info(
             "Listing Dataplex Catalog AspectType from location %s.",
             self.location,
@@ -3315,12 +3379,15 @@ def __init__(
         self.update_mask = update_mask
         self.validate_request = validate_request
 
-    def execute(self, context: Context):
-        DataplexCatalogAspectTypeLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "aspect_type_id": self.aspect_type_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogAspectTypeLink.persist(context=context)
         if self.validate_request:
             self.log.info("Validating an Update Dataplex Catalog AspectType request.")
         else:
@@ -3493,12 +3560,16 @@ def _validate_fields(self, entry_configuration):
                 f"Missing required fields in Entry configuration: {', '.join(missing_fields)}. "
             )
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_id": self.entry_id,
+            "entry_group_id": self.entry_group_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryLink.persist(context=context)
         self._validate_fields(self.entry_configuration)
         try:
             entry = self.hook.create_entry(
@@ -3599,11 +3670,16 @@ def __init__(
         self.aspect_types = aspect_types
         self.paths = paths
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_id": self.entry_id,
+            "entry_group_id": self.entry_group_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogEntryLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryLink.persist(context=context)
         self.log.info(
             "Retrieving Dataplex Catalog Entry %s.",
             self.entry_id,
@@ -3701,11 +3777,15 @@ def __init__(
         self.page_token = page_token
         self.filter_by = filter_by
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_group_id": self.entry_group_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogEntryGroupLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryGroupLink.persist(context=context)
         self.log.info(
             "Listing Dataplex Catalog Entry from location %s.",
             self.location,
@@ -3903,11 +3983,16 @@ def __init__(
         self.aspect_types = aspect_types
         self.paths = paths
 
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_id": self.entry_id,
+            "entry_group_id": self.entry_group_id,
+        }
+
     def execute(self, context: Context):
-        DataplexCatalogEntryLink.persist(
-            context=context,
-            task_instance=self,
-        )
+        DataplexCatalogEntryLink.persist(context=context)
         self.log.info(
             "Looking for Dataplex Catalog Entry %s.",
             self.entry_id,
@@ -4022,12 +4107,16 @@ def __init__(
         self.delete_missing_aspects = delete_missing_aspects
         self.aspect_keys = aspect_keys
 
-    def execute(self, context: Context):
-        DataplexCatalogEntryLink.persist(
-            context=context,
-            task_instance=self,
-        )
+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            **super().extra_links_params,
+            "entry_id": self.entry_id,
+            "entry_group_id": self.entry_group_id,
+        }
 
+    def execute(self, context: Context):
+        DataplexCatalogEntryLink.persist(context=context)
         try:
             entry = self.hook.update_entry(
                 location=self.location,
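The Dataplex Catalog operators extend the parent's parameters rather than redefining them. A toy reduction of that `**super().extra_links_params` pattern (classes invented for illustration):

```python
# Subclass merges its own key into the base mapping instead of duplicating it.
from typing import Any


class CatalogBaseSketch:
    location = "us-central1"
    project_id = "my-project"

    @property
    def extra_links_params(self) -> dict[str, Any]:
        return {"location": self.location, "project_id": self.project_id}


class EntryGroupSketch(CatalogBaseSketch):
    entry_group_id = "my-entry-group"

    @property
    def extra_links_params(self) -> dict[str, Any]:
        return {**super().extra_links_params, "entry_group_id": self.entry_group_id}


assert EntryGroupSketch().extra_links_params == {
    "location": "us-central1",
    "project_id": "my-project",
    "entry_group_id": "my-entry-group",
}
```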
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataprep.py b/providers/google/src/airflow/providers/google/cloud/operators/dataprep.py
index f25a4cc783b3b..1d501a0b6630a 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataprep.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataprep.py
@@ -113,7 +113,6 @@ def execute(self, context: Context) -> dict:
         if self.project_id:
             DataprepJobGroupLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=self.project_id,
                 job_group_id=int(self.job_group_id),
             )
@@ -170,7 +169,6 @@ def execute(self, context: Context) -> dict:
         if self.project_id and job_group_id:
             DataprepJobGroupLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=self.project_id,
                 job_group_id=int(job_group_id),
             )
@@ -230,7 +228,6 @@ def execute(self, context: Context) -> dict:
         if self.project_id and copied_flow_id:
             DataprepFlowLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=self.project_id,
                 flow_id=int(copied_flow_id),
             )
@@ -303,7 +300,6 @@ def execute(self, context: Context) -> dict:
         job_group_id = response["data"][0]["id"]
         DataprepJobGroupLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id,
             job_group_id=int(job_group_id),
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 4b47aef0e2d62..3246c5bb6c356 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -807,7 +806,6 @@ def execute(self, context: Context) -> dict:
         if project_id:
             DataprocClusterLink.persist(
                 context=context,
-                operator=self,
                 cluster_id=self.cluster_name,
                 project_id=project_id,
                 region=self.region,
@@ -1174,7 +1173,6 @@ def execute(self, context: Context) -> dict | None:
         cluster = super().execute(context)
         DataprocClusterLink.persist(
             context=context,
-            operator=self,
             cluster_id=self.cluster_name,
             project_id=self._get_project_id(),
             region=self.region,
@@ -1459,7 +1457,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocWorkflowTemplateLink.persist(
                 context=context,
-                operator=self,
                 workflow_template_id=self.template["id"],
                 region=self.region,
                 project_id=project_id,
@@ -1571,7 +1568,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
-                operator=self,
                 workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
@@ -1727,7 +1723,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocWorkflowLink.persist(
                 context=context,
-                operator=self,
                 workflow_id=workflow_id,
                 region=self.region,
                 project_id=project_id,
@@ -1901,7 +1896,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocJobLink.persist(
                 context=context,
-                operator=self,
                 job_id=new_job_id,
                 region=self.region,
                 project_id=project_id,
@@ -2074,7 +2068,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocClusterLink.persist(
                 context=context,
-                operator=self,
                 cluster_id=self.cluster_name,
                 project_id=project_id,
                 region=self.region,
@@ -2373,7 +2366,6 @@ def execute(self, context: Context):
         # Persist the link earlier so users can observe the progress
         DataprocBatchLink.persist(
             context=context,
-            operator=self,
             project_id=self.project_id,
             region=self.region,
             batch_id=self.batch_id,
@@ -2410,7 +2402,6 @@ def execute(self, context: Context):

         DataprocBatchLink.persist(
             context=context,
-            operator=self,
             project_id=self.project_id,
             region=self.region,
             batch_id=batch_id,
@@ -2723,7 +2714,6 @@ def execute(self, context: Context):
         if project_id:
             DataprocBatchLink.persist(
                 context=context,
-                operator=self,
                 project_id=project_id,
                 region=self.region,
                 batch_id=self.batch_id,
@@ -2806,7 +2796,7 @@ def execute(self, context: Context):
         )
         project_id = self.project_id or hook.project_id
         if project_id:
-            DataprocBatchesListLink.persist(context=context, operator=self, project_id=project_id)
+            DataprocBatchesListLink.persist(context=context, project_id=project_id)
         return [Batch.to_dict(result) for result in results]
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
index d99ba3d160a4e..01340af99c612 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
@@ -21,7 +21,7 @@
 import time
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from google.api_core.exceptions import AlreadyExists
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -32,6 +32,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook
+from airflow.providers.google.cloud.links.base import BaseGoogleLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.common.links.storage import StorageLink
@@ -42,16 +43,6 @@
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context

-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseOperatorLink
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models import XCom  # type: ignore[no-redef]
-    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
-
-
 BASE_LINK = "https://console.cloud.google.com"
 METASTORE_BASE_LINK = BASE_LINK + "/dataproc/metastore/services/{region}/{service_id}"
 METASTORE_BACKUP_LINK = METASTORE_BASE_LINK + "/backups/{resource}?project={project_id}"
@@ -61,97 +52,50 @@
 METASTORE_SERVICE_LINK = METASTORE_BASE_LINK + "/config?project={project_id}"


-class DataprocMetastoreLink(BaseOperatorLink):
+class DataprocMetastoreLink(BaseGoogleLink):
     """Helper class for constructing Dataproc Metastore resource link."""

     name = "Dataproc Metastore"
     key = "conf"

-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: (
-            DataprocMetastoreCreateServiceOperator
-            | DataprocMetastoreGetServiceOperator
-            | DataprocMetastoreRestoreServiceOperator
-            | DataprocMetastoreUpdateServiceOperator
-            | DataprocMetastoreListBackupsOperator
-            | DataprocMetastoreExportMetadataOperator
-        ),
-        url: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=DataprocMetastoreLink.key,
-            value={
-                "region": task_instance.region,
-                "service_id": task_instance.service_id,
-                "project_id": task_instance.project_id,
-                "url": url,
-            },
-        )
-
     def get_link(
         self,
         operator: BaseOperator,
         *,
         ti_key: TaskInstanceKey,
     ) -> str:
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
-        return (
-            conf["url"].format(
-                region=conf["region"],
-                service_id=conf["service_id"],
-                project_id=conf["project_id"],
-            )
-            if conf
-            else ""
+        conf = self.get_config(operator, ti_key)
+        if not conf:
+            return ""
+
+        return conf["url"].format(
+            region=conf["region"],
+            service_id=conf["service_id"],
+            project_id=conf["project_id"],
         )


-class DataprocMetastoreDetailedLink(BaseOperatorLink):
+class DataprocMetastoreDetailedLink(BaseGoogleLink):
     """Helper class for constructing Dataproc Metastore detailed resource link."""

     name = "Dataproc Metastore resource"
     key = "config"

-    @staticmethod
-    def persist(
-        context: Context,
-        task_instance: (
-            DataprocMetastoreCreateBackupOperator | DataprocMetastoreCreateMetadataImportOperator
-        ),
-        url: str,
-        resource: str,
-    ):
-        task_instance.xcom_push(
-            context=context,
-            key=DataprocMetastoreDetailedLink.key,
-            value={
-                "region": task_instance.region,
-                "service_id": task_instance.service_id,
-                "project_id": task_instance.project_id,
-                "url": url,
-                "resource": resource,
-            },
-        )
-
     def get_link(
         self,
         operator: BaseOperator,
         *,
         ti_key: TaskInstanceKey,
     ) -> str:
-        conf = XCom.get_value(key=self.key, ti_key=ti_key)
-        return (
-            conf["url"].format(
-                region=conf["region"],
-                service_id=conf["service_id"],
-                project_id=conf["project_id"],
-                resource=conf["resource"],
-            )
-            if conf
-            else ""
+        conf = self.get_config(operator, ti_key)
+        if not conf:
+            return ""
+
+        return conf["url"].format(
+            region=conf["region"],
+            service_id=conf["service_id"],
+            project_id=conf["project_id"],
+            resource=conf["resource"],
         )
@@ -231,6 +175,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -263,7 +215,7 @@ def execute(self, context: Context) -> dict:
             metadata=self.metadata,
         )
         DataprocMetastoreDetailedLink.persist(
-            context=context, task_instance=self, url=METASTORE_BACKUP_LINK, resource=self.backup_id
+            context=context, url=METASTORE_BACKUP_LINK, resource=self.backup_id
         )
         return Backup.to_dict(backup)
@@ -344,6 +296,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -364,7 +324,7 @@ def execute(self, context: Context):
         self.log.info("Metadata import %s created successfully", self.metadata_import_id)

         DataprocMetastoreDetailedLink.persist(
-            context=context, task_instance=self, url=METASTORE_IMPORT_LINK, resource=self.metadata_import_id
+            context=context, url=METASTORE_IMPORT_LINK, resource=self.metadata_import_id
         )
         return MetadataImport.to_dict(metadata_import)
@@ -437,6 +397,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -465,7 +433,7 @@ def execute(self, context: Context) -> dict:
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_SERVICE_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_SERVICE_LINK)
         return Service.to_dict(service)
@@ -689,6 +657,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -708,9 +684,9 @@ def execute(self, context: Context):
         metadata_export = self._wait_for_export_metadata(hook)
         self.log.info("Metadata from service %s exported successfully", self.service_id)

-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_EXPORT_LINK)
         uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"])
-        StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id)
+        StorageLink.persist(context=context, uri=uri, project_id=self.project_id)
         return MetadataExport.to_dict(metadata_export)

     def _get_uri_from_destination(self, destination_uri: str):
@@ -799,6 +775,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -812,7 +796,7 @@ def execute(self, context: Context) -> dict:
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_SERVICE_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_SERVICE_LINK)
         return Service.to_dict(result)
@@ -880,6 +864,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> list[dict]:
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -897,7 +889,7 @@ def execute(self, context: Context) -> list[dict]:
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_BACKUPS_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_BACKUPS_LINK)
         return [Backup.to_dict(backup) for backup in backups]
@@ -981,6 +973,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -1004,7 +1004,7 @@ def execute(self, context: Context):
         )
         self._wait_for_restore_service(hook)
         self.log.info("Service %s restored from backup %s", self.service_id, self.backup_id)
-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_SERVICE_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_SERVICE_LINK)

     def _wait_for_restore_service(self, hook: DataprocMetastoreHook):
         """
@@ -1107,6 +1107,14 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "service_id": self.service_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DataprocMetastoreHook(
             gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
@@ -1126,4 +1134,4 @@ def execute(self, context: Context):
         )
         hook.wait_for_operation(self.timeout, operation)
         self.log.info("Service %s updated successfully", self.service.get("name"))
-        DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_SERVICE_LINK)
+        DataprocMetastoreLink.persist(context=context, url=METASTORE_SERVICE_LINK)
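Reviewer note: both metastore links above now delegate XCom access to an inherited `get_config()` helper instead of calling `XCom.get_value` directly, which is what lets the Airflow-2/Airflow-3 import shims be deleted. The helper's real implementation is outside this diff; a rough standalone sketch of the contract the `get_link()` overrides appear to rely on (the in-memory store and all names here are hypothetical):

from __future__ import annotations

from typing import Any

# Hypothetical in-memory XCom store keyed by (task_id, key); the real provider
# resolves the value through Airflow's XCom machinery, which is not shown here.
_XCOM_STORE: dict[tuple[str, str], dict[str, Any]] = {}


class MetastoreLinkSketch:
    key = "conf"

    def get_config(self, task_id: str) -> dict[str, Any] | None:
        # Return the persisted link config for this task, or None when the
        # task never pushed one (so the UI can fall back to an empty link).
        return _XCOM_STORE.get((task_id, self.key))

    def get_link(self, task_id: str) -> str:
        conf = self.get_config(task_id)
        if not conf:
            return ""
        return conf["url"].format(**{k: v for k, v in conf.items() if k != "url"})


_XCOM_STORE[("export_metadata", "conf")] = {
    "url": "https://console.cloud.google.com/dataproc/metastore/services/{region}/{service_id}/config?project={project_id}",
    "region": "us-central1",
    "service_id": "my-service",
    "project_id": "my-project",
}
link = MetastoreLinkSketch().get_link("export_metadata")
assert "my-service" in link

The empty-string fallback matters: a link button can be rendered before the task has ever run, so `get_link()` must tolerate a missing config rather than raise.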
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/datastore.py b/providers/google/src/airflow/providers/google/cloud/operators/datastore.py
index 7738910ea69b5..e598b1de2498c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/datastore.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/datastore.py
@@ -136,7 +136,6 @@ def execute(self, context: Context) -> dict:
             raise AirflowException(f"Operation failed: result={result}")
         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}",
             project_id=self.project_id or ds_hook.project_id,
         )
@@ -211,6 +210,12 @@ def __init__(
         self.project_id = project_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         self.log.info("Importing data from Cloud Storage bucket %s", self.bucket)
         ds_hook = DatastoreHook(
@@ -231,8 +236,7 @@ def execute(self, context: Context):
         state = result["metadata"]["common"]["state"]
         if state != "SUCCESSFUL":
             raise AirflowException(f"Operation failed: result={result}")
-
-        CloudDatastoreImportExportLink.persist(context=context, task_instance=self)
+        CloudDatastoreImportExportLink.persist(context=context)

         return result
@@ -282,6 +286,12 @@ def __init__(
         self.project_id = project_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> list:
         hook = DatastoreHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -291,7 +301,7 @@ def execute(self, context: Context) -> list:
             partial_keys=self.partial_keys,
             project_id=self.project_id,
         )
-        CloudDatastoreEntitiesLink.persist(context=context, task_instance=self)
+        CloudDatastoreEntitiesLink.persist(context=context)

         return keys
@@ -398,6 +408,12 @@ def __init__(
         self.project_id = project_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         hook = DatastoreHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -407,7 +423,7 @@ def execute(self, context: Context) -> dict:
             body=self.body,
             project_id=self.project_id,
         )
-        CloudDatastoreEntitiesLink.persist(context=context, task_instance=self)
+        CloudDatastoreEntitiesLink.persist(context=context)

         return response
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dlp.py b/providers/google/src/airflow/providers/google/cloud/operators/dlp.py
index eca84824f0543..c7f689181a41b 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dlp.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dlp.py
@@ -144,7 +144,6 @@ def execute(self, context: Context) -> None:
         if project_id:
             CloudDLPJobDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 job_name=self.dlp_job_id,
             )
@@ -251,7 +250,6 @@ def execute(self, context: Context):
         if project_id and template_id:
             CloudDLPDeidentifyTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 template_name=template_id,
             )
@@ -363,7 +361,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 job_name=result["name"].split("/")[-1] if result["name"] else None,
             )
@@ -473,7 +470,6 @@ def execute(self, context: Context):
         if project_id and template_id:
             CloudDLPInspectTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 template_name=template_id,
             )
@@ -578,7 +574,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 trigger_name=trigger_name,
             )
@@ -692,7 +687,6 @@ def execute(self, context: Context):
         if project_id and stored_info_type_id:
             CloudDLPInfoTypeDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 info_type_name=stored_info_type_id,
             )
@@ -880,7 +874,6 @@ def execute(self, context: Context) -> None:
             if project_id:
                 CloudDLPDeidentifyTemplatesListLink.persist(
                     context=context,
-                    task_instance=self,
                     project_id=project_id,
                 )
         except NotFound:
@@ -966,7 +959,6 @@ def execute(self, context: Context) -> None:
         if project_id:
             CloudDLPJobsListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -1056,7 +1048,6 @@ def execute(self, context: Context) -> None:
         if project_id:
             CloudDLPInspectTemplatesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -1140,7 +1131,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobTriggersListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -1232,7 +1222,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInfoTypesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -1318,7 +1307,9 @@ def execute(self, context: Context):
         project_id = self.project_id or hook.project_id
         if project_id:
             CloudDLPDeidentifyTemplateDetailsLink.persist(
-                context=context, task_instance=self, project_id=project_id, template_name=self.template_id
+                context=context,
+                project_id=project_id,
+                template_name=self.template_id,
             )
         return DeidentifyTemplate.to_dict(template)
@@ -1400,7 +1391,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 job_name=self.dlp_job_id,
             )
@@ -1490,7 +1480,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInspectTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 template_name=self.template_id,
             )
@@ -1574,7 +1563,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 trigger_name=self.job_trigger_id,
             )
@@ -1664,7 +1652,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInfoTypeDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 info_type_name=self.stored_info_type_id,
             )
@@ -1844,7 +1831,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPDeidentifyTemplatesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -1940,7 +1926,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobsListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -2025,7 +2010,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPPossibleInfoTypesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -2119,7 +2103,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInspectTemplatesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -2211,7 +2194,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobTriggersListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -2305,7 +2287,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInfoTypesListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
@@ -2592,7 +2573,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPDeidentifyTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 template_name=self.template_id,
             )
@@ -2692,7 +2672,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInspectTemplateDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 template_name=self.template_id,
             )
@@ -2786,7 +2765,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPJobTriggerDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 trigger_name=self.job_trigger_id,
             )
@@ -2887,7 +2865,6 @@ def execute(self, context: Context):
         if project_id:
             CloudDLPInfoTypeDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 info_type_name=self.stored_info_type_id,
             )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/functions.py b/providers/google/src/airflow/providers/google/cloud/operators/functions.py
index 26782eb3f79aa..77dee5089f2c9 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/functions.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/functions.py
@@ -219,6 +219,13 @@ def _set_airflow_version_label(self) -> None:
             self.body["labels"] = {}
         self.body["labels"].update({"airflow-version": "v" + version.replace(".", "-").replace("+", "-")})

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "function_name": self.body["name"].split("/")[-1],
+        }
+
     def execute(self, context: Context):
         hook = CloudFunctionsHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -237,7 +244,6 @@ def execute(self, context: Context):
         if project_id:
             CloudFunctionsDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 location=self.location,
                 project_id=project_id,
                 function_name=self.body["name"].split("/")[-1],
@@ -394,7 +400,6 @@ def execute(self, context: Context):
         if project_id:
             CloudFunctionsListLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
         return hook.delete_function(self.name)
@@ -462,6 +467,13 @@ def __init__(
         self.api_version = api_version
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "function_name": self.function_id,
+        }
+
     def execute(self, context: Context):
         hook = CloudFunctionsHook(
             api_version=self.api_version,
@@ -482,10 +494,7 @@ def execute(self, context: Context):
         if project_id:
             CloudFunctionsDetailsLink.persist(
                 context=context,
-                task_instance=self,
-                location=self.location,
                 project_id=project_id,
-                function_name=self.function_id,
             )
         return result
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
index d2c42c0d7878e..54c01b0d470f0 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/gcs.py
@@ -145,7 +145,6 @@ def execute(self, context: Context) -> None:
         )
         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=self.bucket_name,
             project_id=self.project_id or hook.project_id,
         )
@@ -260,7 +259,6 @@ def execute(self, context: Context) -> list:

         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=self.bucket,
             project_id=hook.project_id,
         )
@@ -439,7 +437,6 @@ def execute(self, context: Context) -> None:
         )
         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=self.bucket,
             project_id=hook.project_id,
         )
@@ -522,7 +519,6 @@ def execute(self, context: Context) -> None:
         )
         FileDetailsLink.persist(
             context=context,
-            task_instance=self,
             uri=f"{self.bucket}/{self.object_name}",
             project_id=hook.project_id,
         )
@@ -631,7 +627,6 @@ def execute(self, context: Context) -> None:
         self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object)
         FileDetailsLink.persist(
             context=context,
-            task_instance=self,
             uri=f"{self.destination_bucket}/{self.destination_object}",
             project_id=hook.project_id,
         )
@@ -829,7 +824,6 @@ def execute(self, context: Context) -> list[str]:
         )
         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=self.destination_bucket,
             project_id=destination_hook.project_id,
         )
@@ -1080,7 +1074,6 @@ def execute(self, context: Context) -> None:
         )
         StorageLink.persist(
             context=context,
-            task_instance=self,
             uri=self._get_uri(self.destination_bucket, self.destination_object),
             project_id=hook.project_id,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
index ef692eba1bc59..bb9b530e34b8c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -453,8 +453,15 @@ def _alert_deprecated_body_fields(self) -> None:
                 stacklevel=2,
             )

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "location": self.location,
+        }
+
     def execute(self, context: Context) -> str:
-        KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=self.body)
+        KubernetesEngineClusterLink.persist(context=context, cluster=self.body)

         try:
             operation = self.cluster_hook.create_cluster(
@@ -553,9 +560,16 @@ def __init__(
         self.use_dns_endpoint = use_dns_endpoint
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+            "location": self.location,
+        }
+
     def execute(self, context: Context):
         cluster = self.cluster_hook.get_cluster(name=self.cluster_name, project_id=self.project_id)
-        KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=cluster)
+        KubernetesEngineClusterLink.persist(context=context, cluster=cluster)

         if self.cluster_hook.check_cluster_autoscaling_ability(cluster=cluster):
             super().execute(context)
@@ -854,7 +868,14 @@ def execute(self, context: Context) -> None:
                 self.cluster_name,
                 self.job,
             )
-            KubernetesEngineJobLink.persist(context=context, task_instance=self)
+            KubernetesEngineJobLink.persist(
+                context=context,
+                location=self.location,
+                cluster_name=self.cluster_name,
+                namespace=self.job.metadata.namespace,
+                job_name=self.job.metadata.name,
+                project_id=self.project_id,
+            )
         return None
@@ -918,6 +939,15 @@ def __init__(
         self.namespace = namespace
         self.do_xcom_push = do_xcom_push

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_name": self.cluster_name,
+            "namespace": self.namespace,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context) -> dict:
         if self.namespace:
             jobs = self.hook.list_jobs_from_namespace(namespace=self.namespace)
@@ -928,7 +958,7 @@ def execute(self, context: Context) -> dict:
         if self.do_xcom_push:
             ti = context["ti"]
             ti.xcom_push(key="jobs_list", value=V1JobList.to_dict(jobs))
-        KubernetesEngineWorkloadsLink.persist(context=context, task_instance=self)
+        KubernetesEngineWorkloadsLink.persist(context=context)
         return V1JobList.to_dict(jobs)
@@ -1270,8 +1300,14 @@ def execute(self, context: Context) -> None:
             self.name,
             self.cluster_name,
         )
-        KubernetesEngineJobLink.persist(context=context, task_instance=self)
-
+        KubernetesEngineJobLink.persist(
+            context=context,
+            location=self.location,
+            cluster_name=self.cluster_name,
+            namespace=self.job.metadata.namespace,
+            job_name=self.job.metadata.name,
+            project_id=self.project_id,
+        )
         return k8s.V1Job.to_dict(self.job)
@@ -1344,6 +1380,13 @@ def execute(self, context: Context) -> None:
             self.name,
             self.cluster_name,
         )
-        KubernetesEngineJobLink.persist(context=context, task_instance=self)
+        KubernetesEngineJobLink.persist(
+            context=context,
+            location=self.location,
+            cluster_name=self.cluster_name,
+            namespace=self.job.metadata.namespace,
+            job_name=self.job.metadata.name,
+            project_id=self.project_id,
+        )

         return k8s.V1Job.to_dict(self.job)
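Reviewer note: unlike most call sites in this diff, the GKE job hunks above keep explicit keyword arguments on `persist()` instead of moving everything into `extra_links_params`, presumably because `self.job.metadata.name` and `.namespace` only exist once the job object has been built inside `execute()`. A tiny sketch of the merge order this implies (the helper name is invented; this is not provider code):

from __future__ import annotations

from typing import Any


def resolve_link_params(defaults: dict[str, Any], **explicit: Any) -> dict[str, Any]:
    """Merge operator-level defaults with values only known at execute time.

    Explicit keyword arguments win, mirroring how a call such as
    persist(context, job_name=...) would override anything coming from the
    operator's extra_links_params property.
    """
    return {**defaults, **explicit}


defaults = {"project_id": "my-project", "location": "us-central1", "cluster_name": "demo"}
params = resolve_link_params(defaults, namespace="default", job_name="pi-job")
assert params["job_name"] == "pi-job" and params["project_id"] == "my-project"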
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/life_sciences.py b/providers/google/src/airflow/providers/google/cloud/operators/life_sciences.py
index 3389c1bdb8bb7..3897c00253cf2 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/life_sciences.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/life_sciences.py
@@ -113,7 +113,6 @@ def execute(self, context: Context) -> dict:
         if project_id:
             LifeSciencesLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
             )
         return hook.run_pipeline(body=self.body, location=self.location, project_id=self.project_id)
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
index b649149ccc0da..6c42eb8be1d01 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
@@ -21,7 +21,7 @@
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -150,9 +150,17 @@ def __init__(
         self.cluster_id = cluster_id
         self.request_id = request_id

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         self.log.info("Creating an Apache Kafka cluster.")
-        ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id)
+        ApacheKafkaClusterLink.persist(context=context)
         try:
             operation = self.hook.create_cluster(
                 project_id=self.project_id,
@@ -227,8 +235,14 @@ def __init__(
         self.filter = filter
         self.order_by = order_by

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaClusterListLink.persist(context=context, task_instance=self)
+        ApacheKafkaClusterListLink.persist(context=context)
         self.log.info("Listing Clusters from location %s.", self.location)
         try:
             cluster_list_pager = self.hook.list_clusters(
@@ -285,12 +299,16 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.cluster_id = cluster_id

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaClusterLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-        )
+        ApacheKafkaClusterLink.persist(context=context)
         self.log.info("Getting Cluster: %s", self.cluster_id)
         try:
             cluster = self.hook.get_cluster(
@@ -362,12 +380,16 @@ def __init__(
         self.update_mask = update_mask
         self.request_id = request_id

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaClusterLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-        )
+        ApacheKafkaClusterLink.persist(context=context)
         self.log.info("Updating an Apache Kafka cluster.")
         try:
             operation = self.hook.update_cluster(
@@ -497,14 +519,18 @@ def __init__(
         self.topic_id = topic_id
         self.topic = topic

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "topic_id": self.topic_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         self.log.info("Creating an Apache Kafka topic.")
-        ApacheKafkaTopicLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-            topic_id=self.topic_id,
-        )
+        ApacheKafkaTopicLink.persist(context=context)
         try:
             topic_obj = self.hook.create_topic(
                 project_id=self.project_id,
@@ -574,8 +600,16 @@ def __init__(
         self.page_size = page_size
         self.page_token = page_token

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id)
+        ApacheKafkaClusterLink.persist(context=context)
         self.log.info("Listing Topics for cluster %s.", self.cluster_id)
         try:
             topic_list_pager = self.hook.list_topics(
@@ -636,13 +670,17 @@ def __init__(
         self.cluster_id = cluster_id
         self.topic_id = topic_id

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "topic_id": self.topic_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaTopicLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-            topic_id=self.topic_id,
-        )
+        ApacheKafkaTopicLink.persist(context=context)
         self.log.info("Getting Topic: %s", self.topic_id)
         try:
             topic = self.hook.get_topic(
@@ -707,13 +745,17 @@ def __init__(
         self.topic = topic
         self.update_mask = update_mask

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "topic_id": self.topic_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaTopicLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-            topic_id=self.topic_id,
-        )
+        ApacheKafkaTopicLink.persist(context=context)
         self.log.info("Updating an Apache Kafka topic.")
         try:
             topic_obj = self.hook.update_topic(
@@ -833,8 +875,16 @@ def __init__(
         self.page_size = page_size
         self.page_token = page_token

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id)
+        ApacheKafkaClusterLink.persist(context=context)
         self.log.info("Listing Consumer Groups for cluster %s.", self.cluster_id)
         try:
             consumer_group_list_pager = self.hook.list_consumer_groups(
@@ -895,13 +945,17 @@ def __init__(
         self.cluster_id = cluster_id
         self.consumer_group_id = consumer_group_id

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "consumer_group_id": self.consumer_group_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaConsumerGroupLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-            consumer_group_id=self.consumer_group_id,
-        )
+        ApacheKafkaConsumerGroupLink.persist(context=context)
         self.log.info("Getting Consumer Group: %s", self.consumer_group_id)
         try:
             consumer_group = self.hook.get_consumer_group(
@@ -971,13 +1025,17 @@ def __init__(
         self.consumer_group = consumer_group
         self.update_mask = update_mask

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "location": self.location,
+            "cluster_id": self.cluster_id,
+            "consumer_group_id": self.consumer_group_id,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
-        ApacheKafkaConsumerGroupLink.persist(
-            context=context,
-            task_instance=self,
-            cluster_id=self.cluster_id,
-            consumer_group_id=self.consumer_group_id,
-        )
+        ApacheKafkaConsumerGroupLink.persist(context=context)
         self.log.info("Updating an Apache Kafka consumer group.")
         try:
             consumer_group_obj = self.hook.update_consumer_group(
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py b/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py
index cf432491395df..96c15be645873 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py
@@ -104,7 +104,6 @@ def execute(self, context: Context):
         if project_id:
             MLEngineModelLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=project_id,
                 model_id=self.model["name"],
             )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
index ba2703dce4803..4df7586c5e915 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py
@@ -183,7 +183,6 @@ def execute(self, context: Context) -> None:
         self.log.info("Created topic %s", self.topic)
         PubSubTopicLink.persist(
             context=context,
-            task_instance=self,
             topic_id=self.topic,
             project_id=self.project_id or hook.project_id,
         )
@@ -392,7 +391,6 @@ def execute(self, context: Context) -> str:
         self.log.info("Created subscription for topic %s", self.topic)
         PubSubSubscriptionLink.persist(
             context=context,
-            task_instance=self,
             subscription_id=self.subscription or result,  # result returns subscription name
             project_id=self.project_id or hook.project_id,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
index 5dced11d3b4f4..51c4f61f20841 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
@@ -122,7 +122,6 @@ def execute(self, context: Context) -> None:
         )
         SpannerInstanceLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id,
             project_id=self.project_id or hook.project_id,
         )
@@ -290,7 +289,6 @@ def execute(self, context: Context):
         )
         SpannerDatabaseLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id,
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
@@ -380,7 +378,6 @@ def execute(self, context: Context) -> bool | None:
         )
         SpannerDatabaseLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id,
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
@@ -496,7 +493,6 @@ def execute(self, context: Context) -> None:
         )
         SpannerDatabaseLink.persist(
             context=context,
-            task_instance=self,
             instance_id=self.instance_id,
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/speech_to_text.py b/providers/google/src/airflow/providers/google/cloud/operators/speech_to_text.py
index 752e7d3454be0..5f21db6806b64 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/speech_to_text.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/speech_to_text.py
@@ -117,7 +117,6 @@ def execute(self, context: Context):
         if self.audio.uri:
             FileDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
                 uri=self.audio.uri[5:],
                 project_id=self.project_id or hook.project_id,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/stackdriver.py b/providers/google/src/airflow/providers/google/cloud/operators/stackdriver.py
index 2edc126479e12..d61d8d40f6d6c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/stackdriver.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/stackdriver.py
@@ -145,7 +145,6 @@ def execute(self, context: Context):
         )
         StackdriverPoliciesLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
         return [AlertPolicy.to_dict(policy) for policy in result]
@@ -228,7 +227,6 @@ def execute(self, context: Context):
         )
         StackdriverPoliciesLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
@@ -311,7 +309,6 @@ def execute(self, context: Context):
         )
         StackdriverPoliciesLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
@@ -394,7 +391,6 @@ def execute(self, context: Context):
         )
         StackdriverPoliciesLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
@@ -580,7 +576,6 @@ def execute(self, context: Context):
         )
         StackdriverNotificationsLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
         return [NotificationChannel.to_dict(channel) for channel in channels]
@@ -666,7 +661,6 @@ def execute(self, context: Context):
         )
         StackdriverNotificationsLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
@@ -751,7 +745,6 @@ def execute(self, context: Context):
         )
         StackdriverNotificationsLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
@@ -838,7 +831,6 @@ def execute(self, context: Context):
         )
         StackdriverNotificationsLink.persist(
             context=context,
-            operator_instance=self,
             project_id=self.project_id or self.hook.project_id,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/tasks.py b/providers/google/src/airflow/providers/google/cloud/operators/tasks.py
index 735f8fbde9fa9..4cb1f9ff05c0a 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/tasks.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/tasks.py
@@ -137,7 +137,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -236,7 +235,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -319,7 +317,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -406,7 +403,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksLink.persist(
-            operator_instance=self,
             context=context,
             project_id=self.project_id or hook.project_id,
         )
@@ -564,7 +560,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -647,7 +642,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -730,7 +724,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=queue.name,
         )
@@ -830,7 +823,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=task.name,
         )
@@ -923,7 +915,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=task.name,
         )
@@ -1016,7 +1007,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=f"projects/{self.project_id or hook.project_id}/"
             f"locations/{self.location}/queues/{self.queue_name}",
@@ -1190,7 +1180,6 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         CloudTasksQueueLink.persist(
-            operator_instance=self,
             context=context,
             queue_name=task.name,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/text_to_speech.py b/providers/google/src/airflow/providers/google/cloud/operators/text_to_speech.py
index adb4bb6a9d376..35ed1f5d34f76 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/text_to_speech.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/text_to_speech.py
@@ -150,7 +150,6 @@ def execute(self, context: Context) -> None:
         )
         FileDetailsLink.persist(
             context=context,
-            task_instance=self,
             uri=f"{self.target_bucket_name}/{self.target_filename}",
             project_id=cloud_storage_hook.project_id,
         )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/translate.py b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
index 6f6100076ea6d..dd30b4536ae5e 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
@@ -394,7 +393,6 @@ def execute(self, context: Context) -> dict:
         self.log.info("Translate text batch job started.")
         TranslateTextBatchLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
             output_config=self.output_config,
         )
@@ -486,9 +485,9 @@ def execute(self, context: Context) -> str:
         project_id = self.project_id or hook.project_id
         TranslationNativeDatasetLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=dataset_id,
             project_id=project_id,
+            location=self.location,
         )
         return result
@@ -556,7 +555,6 @@ def execute(self, context: Context):
         project_id = self.project_id or hook.project_id
         TranslationDatasetsListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting datasets list")
@@ -657,9 +655,9 @@ def execute(self, context: Context):
         project_id = self.project_id or hook.project_id
         TranslationNativeDatasetLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=self.dataset_id,
             project_id=project_id,
+            location=self.location,
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
         self.log.info("Importing data finished!")
@@ -827,10 +825,10 @@ def execute(self, context: Context) -> str:
         project_id = self.project_id or hook.project_id
         TranslationModelLink.persist(
             context=context,
-            task_instance=self,
             dataset_id=self.dataset_id,
             model_id=model_id,
             project_id=project_id,
+            location=self.location,
         )
         return result
@@ -898,7 +896,6 @@ def execute(self, context: Context):
         project_id = self.project_id or hook.project_id
         TranslationModelsListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting models list")
@@ -1141,7 +1138,6 @@ def execute(self, context: Context) -> dict:
         if self.document_output_config:
             TranslateResultByOutputConfigLink.persist(
                 context=context,
-                task_instance=self,
                 project_id=self.project_id or hook.project_id,
                 output_config=self.document_output_config,
             )
@@ -1304,7 +1300,6 @@ def execute(self, context: Context) -> dict:
         self.log.info("Batch document translation job started.")
         TranslateResultByOutputConfigLink.persist(
             context=context,
-            task_instance=self,
             project_id=self.project_id or hook.project_id,
             output_config=self.output_config,
         )
@@ -1610,7 +1605,6 @@ def execute(self, context: Context) -> Sequence[str]:
         project_id = self.project_id or hook.project_id
         TranslationGlossariesListLink.persist(
             context=context,
-            task_instance=self,
             project_id=project_id,
         )
         self.log.info("Requesting glossaries list")
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/translate_speech.py b/providers/google/src/airflow/providers/google/cloud/operators/translate_speech.py
index e92865beda17f..01003e070b9e4 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/translate_speech.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/translate_speech.py
@@ -173,7 +173,6 @@ def execute(self, context: Context) -> dict:
         if self.audio.uri:
             FileDetailsLink.persist(
                 context=context,
-                task_instance=self,
                 # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
                 uri=self.audio.uri[5:],
                 project_id=self.project_id or translate_hook.project_id,
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 363b04d5d9b80..e1a1f50575ae6 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -21,7 +21,7 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -91,6 +91,13 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.hook: AutoMLHook | None = None

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def on_kill(self) -> None:
         """Act as a callback called when the operator is killed; cancel any running job."""
         if self.hook:
@@ -243,11 +250,11 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)

         return result
@@ -335,11 +342,11 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)

         return result
@@ -458,11 +465,11 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)

         return result
@@ -532,11 +539,11 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)

         return result
@@ -640,6 +647,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = AutoMLHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -656,5 +669,5 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
+        VertexAITrainingPipelinesLink.persist(context=context)
         return [TrainingPipeline.to_dict(result) for result in results]
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
index 705f6e31154fd..156aefdc6fa6b 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
@@ -231,6 +231,13 @@ def hook(self) -> BatchPredictionJobHook:
             impersonation_chain=self.impersonation_chain,
         )

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         self.log.info("Creating Batch prediction job")
         batch_prediction_job: BatchPredictionJobObject = self.hook.submit_batch_prediction_job(
@@ -264,7 +271,8 @@ def execute(self, context: Context):
         self.xcom_push(context, key="batch_prediction_job_id", value=batch_prediction_job_id)
         VertexAIBatchPredictionJobLink.persist(
-            context=context, task_instance=self, batch_prediction_job_id=batch_prediction_job_id
+            context=context,
+            batch_prediction_job_id=batch_prediction_job_id,
         )

         if self.deferrable:
@@ -427,6 +435,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = BatchPredictionJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -445,7 +460,8 @@ def execute(self, context: Context):
             )
             self.log.info("Batch prediction job was gotten.")
             VertexAIBatchPredictionJobLink.persist(
-                context=context, task_instance=self, batch_prediction_job_id=self.batch_prediction_job
+                context=context,
+                batch_prediction_job_id=self.batch_prediction_job,
             )
             return BatchPredictionJob.to_dict(result)
         except NotFound:
@@ -517,6 +533,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = BatchPredictionJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -533,5 +555,5 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAIBatchPredictionJobListLink.persist(context=context, task_instance=self)
+        VertexAIBatchPredictionJobListLink.persist(context=context)
         return [BatchPredictionJob.to_dict(result) for result in results]
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 4fdac797dca8e..9543f27c71924 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -170,6 +170,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None:
         if event["status"] == "error":
             raise AirflowException(event["message"])
@@ -180,7 +187,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str
             model = training_pipeline["model_to_upload"]
             model_id = self.hook.extract_model_id(model)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
             return model
         except KeyError:
             self.log.warning(
@@ -585,12 +592,12 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
         self.xcom_push(context, key="custom_job_id", value=custom_job_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result

     def invoke_defer(self, context: Context) -> None:
@@ -649,7 +656,7 @@ def invoke_defer(self, context: Context) -> None:
         custom_container_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_container_training_job_obj.name
         self.xcom_push(context, key="training_id", value=training_pipeline_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomContainerTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1042,12 +1049,12 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
         self.xcom_push(context, key="custom_job_id", value=custom_job_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result

     def invoke_defer(self, context: Context) -> None:
@@ -1107,7 +1114,7 @@ def invoke_defer(self, context: Context) -> None:
         custom_python_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_python_training_job_obj.name
         self.xcom_push(context, key="training_id", value=training_pipeline_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomPythonPackageTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1505,12 +1512,12 @@ def execute(self, context: Context):
             result = Model.to_dict(model)
             model_id = self.hook.extract_model_id(result)
             self.xcom_push(context, key="model_id", value=model_id)
-            VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
+            VertexAIModelLink.persist(context=context, model_id=model_id)
         else:
             result = model  # type: ignore
         self.xcom_push(context, key="training_id", value=training_id)
         self.xcom_push(context, key="custom_job_id", value=custom_job_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_id)
         return result

     def invoke_defer(self, context: Context) -> None:
@@ -1570,7 +1577,7 @@ def invoke_defer(self, context: Context) -> None:
         custom_training_job_obj.wait_for_resource_creation()
         training_pipeline_id: str = custom_training_job_obj.name
         self.xcom_push(context, key="training_id", value=training_pipeline_id)
-        VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id)
+        VertexAITrainingLink.persist(context=context, training_id=training_pipeline_id)
         self.defer(
             trigger=CustomTrainingJobTrigger(
                 conn_id=self.gcp_conn_id,
@@ -1748,6 +1755,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = CustomJobHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -1764,5 +1777,5 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAITrainingPipelinesLink.persist(context=context, task_instance=self)
+        VertexAITrainingPipelinesLink.persist(context=context)
         return [TrainingPipeline.to_dict(result) for result in results]
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index 0d5754fc3f9e0..594b8e4fcdab8 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -20,7 +20,7 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -85,6 +93,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -107,7 +114,7 @@ def execute(self, context: Context):
         self.log.info("Dataset was created. Dataset id: %s", dataset_id)

         self.xcom_push(context, key="dataset_id", value=dataset_id)
-        VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=dataset_id)
+        VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id)
         return dataset
@@ -160,6 +167,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -177,7 +191,7 @@ def execute(self, context: Context):
                 timeout=self.timeout,
                 metadata=self.metadata,
             )
-            VertexAIDatasetLink.persist(context=context, task_instance=self, dataset_id=self.dataset_id)
+            VertexAIDatasetLink.persist(context=context, dataset_id=self.dataset_id)
             self.log.info("Dataset was gotten.")
             return Dataset.to_dict(dataset_obj)
         except NotFound:
@@ -451,6 +465,12 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = DatasetHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -468,7 +488,7 @@ def execute(self, context: Context):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        VertexAIDatasetListLink.persist(context=context, task_instance=self)
+        VertexAIDatasetListLink.persist(context=context)

         return [Dataset.to_dict(result) for result in results]
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
index db485d5cfcc77..32a64e17bcabf 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
@@ -21,7 +21,7 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -93,6 +93,13 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain

+    @property
+    def extra_links_params(self) -> dict[str, Any]:
+        return {
+            "region": self.region,
+            "project_id": self.project_id,
+        }
+
     def execute(self, context: Context):
         hook = EndpointServiceHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -116,7 +123,7 @@ def execute(self, context: Context):
         self.log.info("Endpoint was created.
Endpoint ID: %s", endpoint_id) self.xcom_push(context, key="endpoint_id", value=endpoint_id) - VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=endpoint_id) + VertexAIEndpointLink.persist(context=context, endpoint_id=endpoint_id) return endpoint @@ -255,6 +262,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -279,7 +293,7 @@ def execute(self, context: Context): self.log.info("Model was deployed. Deployed Model ID: %s", deployed_model_id) self.xcom_push(context, key="deployed_model_id", value=deployed_model_id) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=deployed_model_id) + VertexAIModelLink.persist(context=context, model_id=deployed_model_id) return deploy_model @@ -330,6 +344,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -346,7 +367,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=self.endpoint_id) + VertexAIEndpointLink.persist(context=context, endpoint_id=self.endpoint_id) self.log.info("Endpoint was gotten.") return Endpoint.to_dict(endpoint_obj) except NotFound: @@ -429,6 +450,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + } + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -446,7 +473,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIEndpointListLink.persist(context=context, task_instance=self) + VertexAIEndpointListLink.persist(context=context) return [Endpoint.to_dict(result) for result in results] @@ -582,6 +609,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = EndpointServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -599,5 +633,5 @@ def execute(self, context: Context): metadata=self.metadata, ) self.log.info("Endpoint was updated") - VertexAIEndpointLink.persist(context=context, task_instance=self, endpoint_id=self.endpoint_id) + VertexAIEndpointLink.persist(context=context, endpoint_id=self.endpoint_id) return Endpoint.to_dict(result) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py index b0928f67ca9b0..86d278c9ab1ba 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py @@ -258,9 +258,7 @@ def execute(self, context: Context): 
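# Editor's note: hyperparameter tuning jobs reuse VertexAITrainingLink, so in
# the simplified calls below the tuning job id is persisted under training_id.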
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id) self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id) - VertexAITrainingLink.persist( - context=context, task_instance=self, training_id=hyperparameter_tuning_job_id - ) + VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id) if self.deferrable: self.defer( @@ -355,9 +353,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAITrainingLink.persist( - context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id - ) + VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id) self.log.info("Hyperparameter tuning job was gotten.") return types.HyperparameterTuningJob.to_dict(result) except NotFound: @@ -487,6 +483,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + } + def execute(self, context: Context): hook = HyperparameterTuningJobHook( gcp_conn_id=self.gcp_conn_id, @@ -503,5 +505,5 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self) + VertexAIHyperparameterTuningJobListLink.persist(context=context) return [types.HyperparameterTuningJob.to_dict(result) for result in results] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py index 79980685c0b14..d5f9c26e5a078 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py @@ -20,7 +20,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -161,6 +161,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -180,7 +187,7 @@ def execute(self, context: Context): self.log.info("Model found. 
Model ID: %s", self.model_id) self.xcom_push(context, key="model_id", value=self.model_id) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id) + VertexAIModelLink.persist(context=context, model_id=self.model_id) return Model.to_dict(model) except NotFound: self.log.info("The Model ID %s does not exist.", self.model_id) @@ -257,7 +264,12 @@ def execute(self, context: Context): metadata=self.metadata, ) hook.wait_for_operation(timeout=self.timeout, operation=operation) - VertexAIModelExportLink.persist(context=context, task_instance=self) + VertexAIModelExportLink.persist( + context=context, + output_config=self.output_config, + model_id=self.model_id, + project_id=self.project_id, + ) self.log.info("Model was exported.") except NotFound: self.log.info("The Model ID %s does not exist.", self.model_id) @@ -335,6 +347,12 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -352,7 +370,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIModelListLink.persist(context=context, task_instance=self) + VertexAIModelListLink.persist(context=context) return [Model.to_dict(result) for result in results] @@ -407,6 +425,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -429,7 +454,7 @@ def execute(self, context: Context): self.log.info("Model was uploaded. 
Model ID: %s", model_id) self.xcom_push(context, key="model_id", value=model_id) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + VertexAIModelLink.persist(context=context, model_id=model_id) return model_resp @@ -553,6 +578,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -571,7 +603,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id) + VertexAIModelLink.persist(context=context, model_id=self.model_id) return Model.to_dict(updated_model) @@ -627,6 +659,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -645,7 +684,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id) + VertexAIModelLink.persist(context=context, model_id=self.model_id) return Model.to_dict(updated_model) @@ -701,6 +740,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = ModelServiceHook( gcp_conn_id=self.gcp_conn_id, @@ -721,7 +767,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id) + VertexAIModelLink.persist(context=context, model_id=self.model_id) return Model.to_dict(updated_model) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py index 30352c541712c..d12adaaf27222 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py @@ -166,6 +166,13 @@ def __init__( self.deferrable = deferrable self.poll_interval = poll_interval + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): self.log.info("Running Pipeline job") pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job( @@ -189,7 +196,7 @@ def execute(self, context: Context): pipeline_job_id = pipeline_job_obj.job_id self.log.info("Pipeline job was created. 
Job id: %s", pipeline_job_id) self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id) - VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id) + VertexAIPipelineJobLink.persist(context=context, pipeline_id=pipeline_job_id) if self.deferrable: pipeline_job_obj.wait_for_resource_creation() @@ -280,6 +287,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = PipelineJobHook( gcp_conn_id=self.gcp_conn_id, @@ -296,9 +310,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIPipelineJobLink.persist( - context=context, task_instance=self, pipeline_id=self.pipeline_job_id - ) + VertexAIPipelineJobLink.persist(context=context, pipeline_id=self.pipeline_job_id) self.log.info("Pipeline job was gotten.") return types.PipelineJob.to_dict(result) except NotFound: @@ -412,6 +424,13 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "region": self.region, + "project_id": self.project_id, + } + def execute(self, context: Context): hook = PipelineJobHook( gcp_conn_id=self.gcp_conn_id, @@ -428,7 +447,7 @@ def execute(self, context: Context): timeout=self.timeout, metadata=self.metadata, ) - VertexAIPipelineJobListLink.persist(context=context, task_instance=self) + VertexAIPipelineJobListLink.persist(context=context) return [types.PipelineJob.to_dict(result) for result in results] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py index 8dfdc03e7c9eb..6368e81eb0170 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py @@ -193,7 +193,9 @@ def execute(self, context: Context): key="cluster_id", value=cluster_id, ) - VertexAIRayClusterLink.persist(context=context, task_instance=self, cluster_id=cluster_id) + VertexAIRayClusterLink.persist( + context=context, location=self.location, cluster_id=cluster_id, project_id=self.project_id + ) self.log.info("Ray cluster was created.") except Exception as error: raise AirflowException(error) @@ -220,7 +222,7 @@ class ListRayClustersOperator(RayBaseOperator): operator_extra_links = (VertexAIRayClusterListLink(),) def execute(self, context: Context): - VertexAIRayClusterListLink.persist(context=context, task_instance=self) + VertexAIRayClusterListLink.persist(context=context, project_id=self.project_id) self.log.info("Listing Clusters from location %s.", self.location) try: ray_cluster_list = self.hook.list_ray_clusters( @@ -268,8 +270,9 @@ def __init__( def execute(self, context: Context): VertexAIRayClusterLink.persist( context=context, - task_instance=self, + location=self.location, cluster_id=self.cluster_id, + project_id=self.project_id, ) self.log.info("Getting Cluster: %s", self.cluster_id) try: @@ -325,8 +328,9 @@ def __init__( def execute(self, context: Context): VertexAIRayClusterLink.persist( context=context, - task_instance=self, + location=self.location, cluster_id=self.cluster_id, + project_id=self.project_id, ) self.log.info("Updating a Ray cluster.") try: diff --git 
a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py index 70b8eaef3950c..81f97d18a6d97 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py @@ -147,7 +147,6 @@ def execute(self, context: Context): WorkflowsWorkflowDetailsLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, project_id=self.project_id or hook.project_id, @@ -235,7 +234,6 @@ def execute(self, context: Context): WorkflowsWorkflowDetailsLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, project_id=self.project_id or hook.project_id, @@ -368,7 +366,6 @@ def execute(self, context: Context): WorkflowsListOfWorkflowsLink.persist( context=context, - task_instance=self, project_id=self.project_id or hook.project_id, ) @@ -434,7 +431,6 @@ def execute(self, context: Context): WorkflowsWorkflowDetailsLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, project_id=self.project_id or hook.project_id, @@ -509,7 +505,6 @@ def execute(self, context: Context): WorkflowsExecutionLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, execution_id=execution_id, @@ -582,7 +577,6 @@ def execute(self, context: Context): WorkflowsExecutionLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, execution_id=self.execution_id, @@ -661,7 +655,6 @@ def execute(self, context: Context): WorkflowsWorkflowDetailsLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, project_id=self.project_id or hook.project_id, @@ -737,7 +730,6 @@ def execute(self, context: Context): WorkflowsExecutionLink.persist( context=context, - task_instance=self, location_id=self.location, workflow_id=self.workflow_id, execution_id=self.execution_id, diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/bigtable.py b/providers/google/src/airflow/providers/google/cloud/sensors/bigtable.py index f67ebfc59a68c..5196b77b34f89 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/bigtable.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/bigtable.py @@ -20,7 +20,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import google.api_core.exceptions from google.cloud.bigtable import enums @@ -89,6 +89,13 @@ def __init__( self.impersonation_chain = impersonation_chain super().__init__(**kwargs) + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "instance_id": self.instance_id, + "project_id": self.project_id, + } + def poke(self, context: Context) -> bool: hook = BigtableHook( gcp_conn_id=self.gcp_conn_id, @@ -119,5 +126,5 @@ def poke(self, context: Context) -> bool: return False self.log.info("Table '%s' is replicated.", self.table_id) - BigtableTablesLink.persist(context=context, task_instance=self) + BigtableTablesLink.persist(context=context) return True diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 
abe87bf75e3e1..603fe3929dae3 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -183,7 +183,6 @@ def execute(self, context: Context) -> None: dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"] BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=dest_table_info["datasetId"], project_id=dest_table_info["projectId"], table_id=dest_table_info["tableId"], diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index 046cfc7a9dcee..1531807884673 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -254,7 +254,6 @@ def execute(self, context: Context): dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"] BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=dataset_id, project_id=project_id, table_id=table_id, diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py index 56b58523eac2b..ec63aeee5f6cd 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py @@ -101,7 +101,6 @@ def persist_links(self, context: Context) -> None: project_id, dataset_id, table_id = self.source_project_dataset_table.split(".") BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=dataset_id, project_id=project_id, table_id=table_id, diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 46c0cc92b628e..08a119a4d3a53 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -373,7 +373,6 @@ def execute(self, context: Context): BigQueryTableLink.persist( context=context, - task_instance=self, dataset_id=table_obj_api_repr["tableReference"]["datasetId"], project_id=table_obj_api_repr["tableReference"]["projectId"], table_id=table_obj_api_repr["tableReference"]["tableId"], diff --git a/providers/google/src/airflow/providers/google/common/links/storage.py b/providers/google/src/airflow/providers/google/common/links/storage.py index 562a79355c7bd..8041ae3a265e3 100644 --- a/providers/google/src/airflow/providers/google/common/links/storage.py +++ b/providers/google/src/airflow/providers/google/common/links/storage.py @@ -18,18 +18,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.providers.google.cloud.links.base import BaseGoogleLink BASE_LINK = "https://console.cloud.google.com" GCS_STORAGE_LINK = BASE_LINK + "/storage/browser/{uri};tab=objects?project={project_id}" GCS_FILE_DETAILS_LINK = BASE_LINK + "/storage/browser/_details/{uri};tab=live_object?project={project_id}" -if TYPE_CHECKING: - from airflow.models import BaseOperator - from airflow.utils.context import Context - class StorageLink(BaseGoogleLink): """Helper class for constructing GCS Storage link.""" @@ 
-38,14 +32,6 @@ class StorageLink(BaseGoogleLink): key = "storage_conf" format_str = GCS_STORAGE_LINK - @staticmethod - def persist(context: Context, task_instance, uri: str, project_id: str | None): - task_instance.xcom_push( - context=context, - key=StorageLink.key, - value={"uri": uri, "project_id": project_id}, - ) - class FileDetailsLink(BaseGoogleLink): """Helper class for constructing GCS file details link.""" @@ -53,11 +39,3 @@ class FileDetailsLink(BaseGoogleLink): name = "GCS File Details" key = "file_details" format_str = GCS_FILE_DETAILS_LINK - - @staticmethod - def persist(context: Context, task_instance: BaseOperator, uri: str, project_id: str | None): - task_instance.xcom_push( - context=context, - key=FileDetailsLink.key, - value={"uri": uri, "project_id": project_id}, - ) diff --git a/providers/google/tests/unit/google/cloud/links/test_alloy_db.py b/providers/google/tests/unit/google/cloud/links/test_alloy_db.py index be4c6d447163a..41c7f22d86d7a 100644 --- a/providers/google/tests/unit/google/cloud/links/test_alloy_db.py +++ b/providers/google/tests/unit/google/cloud/links/test_alloy_db.py @@ -17,9 +17,8 @@ # under the License. from __future__ import annotations -from unittest import mock - from airflow.providers.google.cloud.links.alloy_db import ( + AlloyDBBackupsLink, AlloyDBClusterLink, AlloyDBUsersLink, ) @@ -37,6 +36,9 @@ EXPECTED_ALLOY_DB_USERS_LINK_FORMAT_STR = ( "/alloydb/locations/{location_id}/clusters/{cluster_id}/users?project={project_id}" ) +EXPECTED_ALLOY_DB_BACKUP_LINK_NAME = "AlloyDB Backups" +EXPECTED_ALLOY_DB_BACKUP_LINK_KEY = "alloy_db_backups" +EXPECTED_ALLOY_DB_BACKUP_LINK_FORMAT_STR = "/alloydb/backups?project={project_id}" class TestAlloyDBClusterLink: @@ -45,27 +47,6 @@ def test_class_attributes(self): assert AlloyDBClusterLink.name == EXPECTED_ALLOY_DB_CLUSTER_LINK_NAME assert AlloyDBClusterLink.format_str == EXPECTED_ALLOY_DB_CLUSTER_LINK_FORMAT_STR - def test_persist(self): - mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock() - - AlloyDBClusterLink.persist( - context=mock_context, - task_instance=mock_task_instance, - location_id=TEST_LOCATION, - cluster_id=TEST_CLUSTER_ID, - project_id=TEST_PROJECT_ID, - ) - - mock_task_instance.xcom_push.assert_called_once_with( - mock_context, - key=EXPECTED_ALLOY_DB_CLUSTER_LINK_KEY, - value={ - "location_id": TEST_LOCATION, - "cluster_id": TEST_CLUSTER_ID, - "project_id": TEST_PROJECT_ID, - }, - ) - class TestAlloyDBUsersLink: def test_class_attributes(self): @@ -73,23 +54,9 @@ def test_class_attributes(self): assert AlloyDBUsersLink.name == EXPECTED_ALLOY_DB_USERS_LINK_NAME assert AlloyDBUsersLink.format_str == EXPECTED_ALLOY_DB_USERS_LINK_FORMAT_STR - def test_persist(self): - mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock() - - AlloyDBUsersLink.persist( - context=mock_context, - task_instance=mock_task_instance, - location_id=TEST_LOCATION, - cluster_id=TEST_CLUSTER_ID, - project_id=TEST_PROJECT_ID, - ) - mock_task_instance.xcom_push.assert_called_once_with( - mock_context, - key=EXPECTED_ALLOY_DB_USERS_LINK_KEY, - value={ - "location_id": TEST_LOCATION, - "cluster_id": TEST_CLUSTER_ID, - "project_id": TEST_PROJECT_ID, - }, - ) +class TestAlloyDBBackupsLink: + def test_class_attributes(self): + assert AlloyDBBackupsLink.key == EXPECTED_ALLOY_DB_BACKUP_LINK_KEY + assert AlloyDBBackupsLink.name == EXPECTED_ALLOY_DB_BACKUP_LINK_NAME + assert AlloyDBBackupsLink.format_str == EXPECTED_ALLOY_DB_BACKUP_LINK_FORMAT_STR diff --git 
a/providers/google/tests/unit/google/cloud/links/test_base_link.py b/providers/google/tests/unit/google/cloud/links/test_base_link.py new file mode 100644 index 0000000000000..f0af84dc2ae2a --- /dev/null +++ b/providers/google/tests/unit/google/cloud/links/test_base_link.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any +from unittest import mock + +import pytest + +from airflow.providers.google.cloud.links.base import BaseGoogleLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.execution_time.comms import XComResult + +TEST_LOCATION = "test-location" +TEST_CLUSTER_ID = "test-cluster-id" +TEST_PROJECT_ID = "test-project-id" +EXPECTED_GOOGLE_LINK_KEY = "google_link_for_test" +EXPECTED_GOOGLE_LINK_NAME = "Google Link for Test" +EXPECTED_GOOGLE_LINK_FORMAT = "/services/locations/{location}/clusters/{cluster_id}?project={project_id}" +EXPECTED_GOOGLE_LINK = "https://console.cloud.google.com" + EXPECTED_GOOGLE_LINK_FORMAT.format( + location=TEST_LOCATION, cluster_id=TEST_CLUSTER_ID, project_id=TEST_PROJECT_ID +) + + +class GoogleLink(BaseGoogleLink): + key = EXPECTED_GOOGLE_LINK_KEY + name = EXPECTED_GOOGLE_LINK_NAME + format_str = EXPECTED_GOOGLE_LINK_FORMAT + + +class TestBaseGoogleLink: + def test_class_attributes(self): + assert GoogleLink.key == EXPECTED_GOOGLE_LINK_KEY + assert GoogleLink.name == EXPECTED_GOOGLE_LINK_NAME + assert GoogleLink.format_str == EXPECTED_GOOGLE_LINK_FORMAT + + def test_persist(self): + mock_context = mock.MagicMock() + + if AIRFLOW_V_3_0_PLUS: + GoogleLink.persist( + context=mock_context, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_PROJECT_ID, + ) + mock_context["ti"].xcom_push.assert_called_once_with( + key=EXPECTED_GOOGLE_LINK_KEY, + value={ + "location": TEST_LOCATION, + "cluster_id": TEST_CLUSTER_ID, + "project_id": TEST_PROJECT_ID, + }, + ) + else: + GoogleLink.persist( + context=mock_context, + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_PROJECT_ID, + ) + + +class MyOperator(GoogleCloudBaseOperator): + operator_extra_links = (GoogleLink(),) + + def __init__(self, project_id: str, location: str, cluster_id: str, **kwargs): + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.cluster_id = cluster_id + + @property + def extra_links_params(self) -> dict[str, Any]: + return { + "project_id": self.project_id, + "cluster_id": self.cluster_id, + "location": self.location, + } + + def execute(self, context) -> Any: + GoogleLink.persist(context=context) + + +class TestOperatorWithBaseGoogleLink: + 
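+    # End-to-end check of the new link contract: the operator declares only
+    # extra_links_params, execute() calls GoogleLink.persist(context=context),
+    # and get_link() renders format_str from the value stored in XCom. On
+    # Airflow 3 the XCom read goes through the task SDK supervisor, hence the
+    # mocked XComResult below.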
@pytest.mark.db_test + def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms): + expected_url = EXPECTED_GOOGLE_LINK + link = GoogleLink() + ti = create_task_instance_of_operator( + MyOperator, + dag_id="test_link_dag", + task_id="test_link_task", + location=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_PROJECT_ID, + ) + session.add(ti) + session.commit() + + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + mock_supervisor_comms.send.return_value = XComResult( + key="key", + value={ + "cluster_id": ti.task.cluster_id, + "location": ti.task.location, + "project_id": ti.task.project_id, + }, + ) + actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + assert actual_url == expected_url diff --git a/providers/google/tests/unit/google/cloud/links/test_cloud_run.py b/providers/google/tests/unit/google/cloud/links/test_cloud_run.py index 5f3c348698c19..5e156af8d8515 100644 --- a/providers/google/tests/unit/google/cloud/links/test_cloud_run.py +++ b/providers/google/tests/unit/google/cloud/links/test_cloud_run.py @@ -37,20 +37,21 @@ class TestCloudRunJobLoggingLink: def test_class_attributes(self): assert CloudRunJobLoggingLink.key == "log_uri" assert CloudRunJobLoggingLink.name == "Cloud Run Job Logging" + assert CloudRunJobLoggingLink.format_str == "{log_uri}" def test_persist(self): - mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock() + mock_context = mock.MagicMock() + mock_context["ti"] = mock.MagicMock() + mock_context["task"] = mock.MagicMock() CloudRunJobLoggingLink.persist( context=mock_context, - task_instance=mock_task_instance, log_uri=TEST_LOG_URI, ) - mock_task_instance.xcom_push.assert_called_once_with( - mock_context, + mock_context["ti"].xcom_push.assert_called_once_with( key=CloudRunJobLoggingLink.key, - value=TEST_LOG_URI, + value={"log_uri": TEST_LOG_URI}, ) @pytest.mark.db_test @@ -66,11 +67,13 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task, log_uri=TEST_LOG_URI) - if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + + link.persist(context={"ti": ti, "task": ti.task}, log_uri=TEST_LOG_URI) + + if mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", - value=TEST_LOG_URI, + value={"log_uri": TEST_LOG_URI}, ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == TEST_LOG_URI diff --git a/providers/google/tests/unit/google/cloud/links/test_dataplex.py b/providers/google/tests/unit/google/cloud/links/test_dataplex.py index babdf9127c8e1..3b4f750dabb8d 100644 --- a/providers/google/tests/unit/google/cloud/links/test_dataplex.py +++ b/providers/google/tests/unit/google/cloud/links/test_dataplex.py @@ -120,7 +120,6 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( @@ -151,7 +150,9 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + + link.persist(context={"ti": ti, "task": ti.task}) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -181,7 +182,7 @@ def test_get_link(self, 
create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -210,7 +211,9 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + + link.persist(context={"ti": ti, "task": ti.task}) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -240,7 +243,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -268,7 +271,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -298,7 +301,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -326,7 +329,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -356,7 +359,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -385,7 +388,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", diff --git a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py index 8867b8c0616d2..307be52adfa4b 100644 --- a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py +++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py @@ -17,8 +17,6 @@ # under the License. 
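# Editor's note: the four per-link test_persist methods removed below became
# redundant once persist() was centralized on BaseGoogleLink; that shared path
# is now covered by TestBaseGoogleLink.test_persist in test_base_link.py above.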
from __future__ import annotations -from unittest import mock - from airflow.providers.google.cloud.links.managed_kafka import ( ApacheKafkaClusterLink, ApacheKafkaClusterListLink, @@ -57,28 +55,6 @@ def test_class_attributes(self): assert ApacheKafkaClusterLink.name == EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME assert ApacheKafkaClusterLink.format_str == EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR - def test_persist(self): - mock_context, mock_task_instance = ( - mock.MagicMock(), - mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), - ) - - ApacheKafkaClusterLink.persist( - context=mock_context, - task_instance=mock_task_instance, - cluster_id=TEST_CLUSTER_ID, - ) - - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, - key=EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY, - value={ - "location": TEST_LOCATION, - "cluster_id": TEST_CLUSTER_ID, - "project_id": TEST_PROJECT_ID, - }, - ) - class TestApacheKafkaClusterListLink: def test_class_attributes(self): @@ -86,22 +62,6 @@ def test_class_attributes(self): assert ApacheKafkaClusterListLink.name == EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_NAME assert ApacheKafkaClusterListLink.format_str == EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_FORMAT_STR - def test_persist(self): - mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock(project_id=TEST_PROJECT_ID) - - ApacheKafkaClusterListLink.persist( - context=mock_context, - task_instance=mock_task_instance, - ) - - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, - key=EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY, - value={ - "project_id": TEST_PROJECT_ID, - }, - ) - class TestApacheKafkaTopicLink: def test_class_attributes(self): @@ -109,30 +69,6 @@ def test_class_attributes(self): assert ApacheKafkaTopicLink.name == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_NAME assert ApacheKafkaTopicLink.format_str == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR - def test_persist(self): - mock_context, mock_task_instance = ( - mock.MagicMock(), - mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), - ) - - ApacheKafkaTopicLink.persist( - context=mock_context, - task_instance=mock_task_instance, - cluster_id=TEST_CLUSTER_ID, - topic_id=TEST_TOPIC_ID, - ) - - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, - key=EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY, - value={ - "location": TEST_LOCATION, - "cluster_id": TEST_CLUSTER_ID, - "topic_id": TEST_TOPIC_ID, - "project_id": TEST_PROJECT_ID, - }, - ) - class TestApacheKafkaConsumerGroupLink: def test_class_attributes(self): @@ -141,27 +77,3 @@ def test_class_attributes(self): assert ( ApacheKafkaConsumerGroupLink.format_str == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_FORMAT_STR ) - - def test_persist(self): - mock_context, mock_task_instance = ( - mock.MagicMock(), - mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), - ) - - ApacheKafkaConsumerGroupLink.persist( - context=mock_context, - task_instance=mock_task_instance, - cluster_id=TEST_CLUSTER_ID, - consumer_group_id=TEST_CONSUMER_GROUP_ID, - ) - - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, - key=EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY, - value={ - "location": TEST_LOCATION, - "cluster_id": TEST_CLUSTER_ID, - "consumer_group_id": TEST_CONSUMER_GROUP_ID, - "project_id": TEST_PROJECT_ID, - }, - ) diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py index 
640b958a3841e..2f127f9440dc9 100644 --- a/providers/google/tests/unit/google/cloud/links/test_translate.py +++ b/providers/google/tests/unit/google/cloud/links/test_translate.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS @@ -60,11 +62,22 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task, dataset_id=DATASET, project_id=GCP_PROJECT_ID) + + link.persist( + context={"ti": ti, "task": ti.task}, + dataset_id=DATASET, + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + ) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", - value={"location": ti.task.location, "dataset_id": DATASET, "project_id": GCP_PROJECT_ID}, + value={ + "location": ti.task.location, + "dataset_id": DATASET, + "project_id": GCP_PROJECT_ID, + }, ) actual_url = link.get_link(operator=ti.task, ti_key=ti.key) assert actual_url == expected_url @@ -83,7 +96,8 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() - link.persist(context={"ti": ti}, task_instance=ti.task, project_id=GCP_PROJECT_ID) + link.persist(context={"ti": ti, "task": ti.task}, project_id=GCP_PROJECT_ID) + if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", @@ -113,14 +127,20 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() + task = mock.MagicMock() + task.extra_links_params = { + "dataset_id": DATASET, + "model_id": MODEL, + "project_id": GCP_PROJECT_ID, + "location": GCP_LOCATION, + } link.persist( - context={"ti": ti}, - task_instance=ti.task, + context={"ti": ti, "task": task}, dataset_id=DATASET, model_id=MODEL, project_id=GCP_PROJECT_ID, ) - if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + if mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", value={ @@ -130,7 +150,7 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis "project_id": GCP_PROJECT_ID, }, ) - actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + actual_url = link.get_link(operator=task, ti_key=ti.key) assert actual_url == expected_url @@ -152,12 +172,20 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis ) session.add(ti) session.commit() + task = mock.MagicMock() + task.extra_links_params = { + "dataset_id": DATASET, + "project_id": GCP_PROJECT_ID, + "location": GCP_LOCATION, + } link.persist( - context={"ti": ti}, - task_instance=ti.task, + context={"ti": ti, "task": task}, + dataset_id=DATASET, project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, ) - if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms: + + if mock_supervisor_comms: mock_supervisor_comms.send.return_value = XComResult( key="key", value={ @@ -166,5 +194,5 @@ def test_get_link(self, create_task_instance_of_operator, session, mock_supervis "project_id": GCP_PROJECT_ID, }, ) - actual_url = link.get_link(operator=ti.task, ti_key=ti.key) + actual_url = link.get_link(operator=task, ti_key=ti.key) assert actual_url == expected_url diff --git a/providers/google/tests/unit/google/cloud/links/test_vertex_ai.py b/providers/google/tests/unit/google/cloud/links/test_vertex_ai.py index 
0551a1cd2a4b1..87a0becc0afa6 100644 --- a/providers/google/tests/unit/google/cloud/links/test_vertex_ai.py +++ b/providers/google/tests/unit/google/cloud/links/test_vertex_ai.py @@ -44,19 +44,18 @@ def test_class_attributes(self): assert VertexAIRayClusterLink.format_str == EXPECTED_VERTEX_AI_RAY_CLUSTER_LINK_FORMAT_STR def test_persist(self): - mock_context, mock_task_instance = ( - mock.MagicMock(), - mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID), - ) + mock_context = mock.MagicMock() + mock_context["ti"] = mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID) + mock_context["task"] = mock.MagicMock() VertexAIRayClusterLink.persist( context=mock_context, - task_instance=mock_task_instance, + location=TEST_LOCATION, cluster_id=TEST_CLUSTER_ID, + project_id=TEST_PROJECT_ID, ) - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, + mock_context["ti"].xcom_push.assert_called_once_with( key=EXPECTED_VERTEX_AI_RAY_CLUSTER_LINK_KEY, value={ "location": TEST_LOCATION, @@ -73,15 +72,16 @@ def test_class_attributes(self): assert VertexAIRayClusterListLink.format_str == EXPECTED_VERTEX_AI_RAY_CLUSTER_LIST_LINK_FORMAT_STR def test_persist(self): - mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock(project_id=TEST_PROJECT_ID) + mock_context = mock.MagicMock() + mock_context["ti"] = mock.MagicMock(project_id=TEST_PROJECT_ID) + mock_context["task"] = mock.MagicMock() VertexAIRayClusterListLink.persist( context=mock_context, - task_instance=mock_task_instance, + project_id=TEST_PROJECT_ID, ) - mock_task_instance.xcom_push.assert_called_once_with( - context=mock_context, + mock_context["ti"].xcom_push.assert_called_once_with( key=EXPECTED_VERTEX_AI_RAY_CLUSTER_LIST_LINK_KEY, value={ "project_id": TEST_PROJECT_ID, diff --git a/providers/google/tests/unit/google/cloud/operators/test_alloy_db.py b/providers/google/tests/unit/google/cloud/operators/test_alloy_db.py index 9cc9e5e04ace7..ad21fb5addad1 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_alloy_db.py +++ b/providers/google/tests/unit/google/cloud/operators/test_alloy_db.py @@ -275,15 +275,12 @@ def test_get_cluster(self, mock_hook, mock_log, mock_to_dict): mock_to_dict.assert_called_once_with(mock_cluster) assert result == expected_result - @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster")) @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result")) @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log")) @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock) - def test_execute( - self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict, mock_link - ): + def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict): mock_get_cluster.return_value = None mock_create_cluster = mock_hook.return_value.create_cluster mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster @@ -295,14 +292,6 @@ def test_execute( result = self.operator.execute(context=mock_context) - mock_link.persist.assert_called_once_with( - context=mock_context, - task_instance=self.operator, - location_id=TEST_GCP_REGION, - cluster_id=TEST_CLUSTER_ID, - project_id=TEST_GCP_PROJECT, - ) - mock_log.info.assert_called_once_with("Creating an AlloyDB cluster.") mock_get_cluster.assert_called_once() 
         mock_create_cluster.assert_called_once_with(
@@ -322,14 +311,18 @@ def test_execute(
 
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_is_secondary(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_cluster,
+        mock_to_dict,
     ):
         mock_get_cluster.return_value = None
         mock_create_cluster = mock_hook.return_value.create_cluster
@@ -343,14 +336,6 @@ def test_execute_is_secondary(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB cluster.")
         mock_get_cluster.assert_called_once()
         assert not mock_create_cluster.called
@@ -370,14 +355,18 @@ def test_execute_is_secondary(
 
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_cluster,
+        mock_to_dict,
     ):
         mock_get_cluster.return_value = None
         mock_create_cluster = mock_hook.return_value.create_cluster
@@ -390,14 +379,6 @@ def test_execute_validate_request(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB cluster request.")
         mock_get_cluster.assert_called_once()
         mock_create_cluster.assert_called_once_with(
@@ -416,14 +397,18 @@ def test_execute_validate_request(
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request_is_secondary(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_cluster,
+        mock_to_dict,
     ):
         mock_get_cluster.return_value = None
         mock_create_cluster = mock_hook.return_value.create_cluster
@@ -437,14 +422,6 @@ def test_execute_validate_request_is_secondary(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB cluster request.")
         mock_get_cluster.assert_called_once()
         mock_create_secondary_cluster.assert_called_once_with(
@@ -463,13 +440,16 @@ def test_execute_validate_request_is_secondary(
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_already_exists(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_cluster,
     ):
         expected_result = mock_get_cluster.return_value
         mock_create_cluster = mock_hook.return_value.create_cluster
@@ -479,14 +459,6 @@ def test_execute_already_exists(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         assert not mock_log.info.called
         mock_get_cluster.assert_called_once()
         assert not mock_create_cluster.called
@@ -494,14 +466,18 @@
         assert not mock_get_operation_result.called
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("_get_cluster"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_CLUSTER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_exception(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_cluster, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_cluster,
+        mock_to_dict,
     ):
         mock_get_cluster.return_value = None
         mock_create_cluster = mock_hook.return_value.create_cluster
@@ -512,14 +488,6 @@ def test_execute_exception(
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB cluster.")
         mock_get_cluster.assert_called_once()
         mock_create_cluster.assert_called_once_with(
@@ -569,12 +537,11 @@ def test_template_fields(self):
         )
         assert set(AlloyDBUpdateClusterOperator.template_fields) == expected_template_fields
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_cluster = mock_hook.return_value.update_cluster
         mock_operation = mock_update_cluster.return_value
         mock_operation_result = mock_get_operation_result.return_value
@@ -584,13 +551,6 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_cluster.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             project_id=TEST_GCP_PROJECT,
@@ -614,13 +574,16 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
             ]
         )
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_to_dict,
     ):
         mock_update_cluster = mock_hook.return_value.update_cluster
         mock_operation = mock_update_cluster.return_value
@@ -648,21 +611,13 @@ def test_execute_validate_request(
         )
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert not mock_to_dict.called
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result"))
     @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_cluster = mock_hook.return_value.update_cluster
         mock_update_cluster.side_effect = Exception
 
@@ -671,13 +626,6 @@ def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result,
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_cluster.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             project_id=TEST_GCP_PROJECT,
@@ -912,14 +860,18 @@ def test_get_instance(self, mock_hook, mock_log, mock_to_dict):
         mock_to_dict.assert_called_once_with(mock_instance)
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
+        mock_to_dict,
     ):
         mock_get_instance.return_value = None
         mock_create_instance = mock_hook.return_value.create_instance
@@ -932,14 +884,6 @@ def test_execute(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB instance.")
         mock_get_instance.assert_called_once()
         mock_create_instance.assert_called_once_with(
@@ -960,14 +904,18 @@ def test_execute(
 
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_is_secondary(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
+        mock_to_dict,
     ):
         mock_get_instance.return_value = None
         mock_create_instance = mock_hook.return_value.create_instance
@@ -981,14 +929,6 @@ def test_execute_is_secondary(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB instance.")
         mock_get_instance.assert_called_once()
         assert not mock_create_instance.called
@@ -1009,14 +949,18 @@ def test_execute_is_secondary(
 
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
+        mock_to_dict,
     ):
         mock_get_instance.return_value = None
         mock_create_instance = mock_hook.return_value.create_instance
@@ -1029,14 +973,6 @@ def test_execute_validate_request(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB instance request.")
         mock_get_instance.assert_called_once()
         mock_create_instance.assert_called_once_with(
@@ -1056,14 +992,18 @@ def test_execute_validate_request(
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request_is_secondary(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
+        mock_to_dict,
     ):
         mock_get_instance.return_value = None
         mock_create_instance = mock_hook.return_value.create_instance
@@ -1077,14 +1017,6 @@ def test_execute_validate_request_is_secondary(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB instance request.")
         mock_get_instance.assert_called_once()
         mock_create_secondary_instance.assert_called_once_with(
@@ -1104,13 +1036,16 @@ def test_execute_validate_request_is_secondary(
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_already_exists(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
     ):
         expected_result = mock_get_instance.return_value
         mock_create_instance = mock_hook.return_value.create_instance
@@ -1120,14 +1055,6 @@ def test_execute_already_exists(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         assert not mock_log.info.called
         mock_get_instance.assert_called_once()
         assert not mock_create_instance.called
@@ -1135,14 +1062,18 @@
         assert not mock_get_operation_result.called
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("_get_instance"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_exception(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_instance, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_instance,
+        mock_to_dict,
     ):
         mock_get_instance.return_value = None
         mock_create_instance = mock_hook.return_value.create_instance
@@ -1153,14 +1084,6 @@ def test_execute_exception(
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB instance.")
         mock_get_instance.assert_called_once()
         mock_create_instance.assert_called_once_with(
@@ -1217,12 +1140,11 @@ def test_template_fields(self):
         } | set(AlloyDBWriteBaseOperator.template_fields)
         assert set(AlloyDBUpdateInstanceOperator.template_fields) == expected_template_fields
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Instance.to_dict"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
        mock_update_instance = mock_hook.return_value.update_instance
        mock_operation = mock_update_instance.return_value
        mock_operation_result = mock_get_operation_result.return_value
@@ -1232,13 +1154,6 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_instance.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             instance_id=TEST_INSTANCE_ID,
@@ -1263,13 +1178,16 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
             ]
         )
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_to_dict,
     ):
         mock_update_instance = mock_hook.return_value.update_instance
         mock_operation = mock_update_instance.return_value
@@ -1298,21 +1216,13 @@ def test_execute_validate_request(
         )
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert not mock_to_dict.called
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_INSTANCE_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_instance = mock_hook.return_value.update_instance
         mock_update_instance.side_effect = Exception
 
@@ -1321,13 +1231,6 @@ def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result,
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_instance.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             instance_id=TEST_INSTANCE_ID,
@@ -1556,12 +1459,11 @@ def test_get_user(self, mock_hook, mock_log, mock_to_dict):
         mock_to_dict.assert_called_once_with(mock_user)
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("_get_user"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute(self, mock_hook, mock_log, mock_get_user, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, mock_log, mock_get_user, mock_to_dict):
         mock_get_user.return_value = None
         mock_create_user = mock_hook.return_value.create_user
         mock_user = mock_create_user.return_value
@@ -1571,13 +1473,6 @@ def test_execute(self, mock_hook, mock_log, mock_get_user, mock_to_dict, mock_li
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_log.info.assert_has_calls(
             [
                 call("Creating an AlloyDB user."),
@@ -1600,12 +1495,11 @@ def test_execute(self, mock_hook, mock_log, mock_get_user, mock_to_dict, mock_li
         mock_to_dict.assert_called_once_with(mock_user)
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("_get_user"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_validate_request(self, mock_hook, mock_log, mock_get_user, mock_to_dict, mock_link):
+    def test_execute_validate_request(self, mock_hook, mock_log, mock_get_user, mock_to_dict):
         mock_get_user.return_value = None
         mock_create_user = mock_hook.return_value.create_user
 
@@ -1614,14 +1508,6 @@ def test_execute_validate_request(self, mock_hook, mock_log, mock_get_user, mock
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB user request.")
         mock_get_user.assert_called_once()
         mock_create_user.assert_called_once_with(
@@ -1639,36 +1525,26 @@ def test_execute_validate_request(self, mock_hook, mock_log, mock_get_user, mock
         assert not mock_to_dict.called
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("_get_user"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_already_exists(self, mock_hook, mock_log, mock_get_user, mock_link):
+    def test_execute_already_exists(self, mock_hook, mock_log, mock_get_user):
         expected_result = mock_get_user.return_value
         mock_create_user = mock_hook.return_value.create_user
         mock_context = mock.MagicMock()
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         assert not mock_log.info.called
         mock_get_user.assert_called_once()
         assert not mock_create_user.called
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("_get_user"))
     @mock.patch(CREATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_exception(self, mock_hook, mock_log, mock_get_user, mock_to_dict, mock_link):
+    def test_execute_exception(self, mock_hook, mock_log, mock_get_user, mock_to_dict):
         mock_get_user.return_value = None
         mock_create_user = mock_hook.return_value.create_user
         mock_create_user.side_effect = Exception()
@@ -1677,14 +1553,6 @@ def test_execute_exception(self, mock_hook, mock_log, mock_get_user, mock_to_dic
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB user.")
         mock_get_user.assert_called_once()
         mock_create_user.assert_called_once_with(
@@ -1739,11 +1607,10 @@ def test_template_fields(self):
         } | set(AlloyDBWriteBaseOperator.template_fields)
         assert set(AlloyDBUpdateUserOperator.template_fields) == expected_template_fields
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(UPDATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute(self, mock_hook, mock_log, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, mock_log, mock_to_dict):
         mock_update_user = mock_hook.return_value.update_user
         mock_user = mock_update_user.return_value
         expected_result = mock_to_dict.return_value
@@ -1751,13 +1618,6 @@ def test_execute(self, mock_hook, mock_log, mock_to_dict, mock_link):
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_user.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             user_id=TEST_USER_ID,
@@ -1781,11 +1641,10 @@ def test_execute(self, mock_hook, mock_log, mock_to_dict, mock_link):
             ]
         )
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(UPDATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_validate_request(self, mock_hook, mock_log, mock_to_dict, mock_link):
+    def test_execute_validate_request(self, mock_hook, mock_log, mock_to_dict):
         mock_update_user = mock_hook.return_value.update_user
         expected_message = "Validating an Update AlloyDB user request."
@@ -1810,20 +1669,12 @@ def test_execute_validate_request(self, mock_hook, mock_log, mock_to_dict, mock_
             metadata=TEST_METADATA,
         )
         assert not mock_to_dict.called
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUsersLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.User.to_dict"))
     @mock.patch(UPDATE_USER_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_exception(self, mock_hook, mock_log, mock_to_dict, mock_link):
+    def test_execute_exception(self, mock_hook, mock_log, mock_to_dict):
         mock_update_user = mock_hook.return_value.update_user
         mock_update_user.side_effect = Exception
 
@@ -1832,13 +1683,6 @@ def test_execute_exception(self, mock_hook, mock_log, mock_to_dict, mock_link):
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            location_id=TEST_GCP_REGION,
-            cluster_id=TEST_CLUSTER_ID,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_user.assert_called_once_with(
             cluster_id=TEST_CLUSTER_ID,
             user_id=TEST_USER_ID,
@@ -2045,14 +1889,18 @@ def test_get_backup(self, mock_hook, mock_log, mock_to_dict):
         mock_to_dict.assert_called_once_with(mock_instance)
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("_get_backup"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_backup, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_backup,
+        mock_to_dict,
     ):
         mock_get_backup.return_value = None
         mock_create_backup = mock_hook.return_value.create_backup
@@ -2064,12 +1912,6 @@ def test_execute(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB backup.")
         mock_get_backup.assert_called_once()
         mock_create_backup.assert_called_once_with(
@@ -2088,14 +1930,18 @@ def test_execute(
 
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("_get_backup"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_backup, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_backup,
+        mock_to_dict,
     ):
         mock_get_backup.return_value = None
         mock_create_backup = mock_hook.return_value.create_backup
@@ -2107,12 +1953,6 @@ def test_execute_validate_request(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Validating a Create AlloyDB backup request.")
         mock_get_backup.assert_called_once()
         mock_create_backup.assert_called_once_with(
@@ -2130,13 +1970,16 @@ def test_execute_validate_request(
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("_get_backup"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_already_exists(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_backup, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_backup,
     ):
         expected_result = mock_get_backup.return_value
         mock_create_instance = mock_hook.return_value.create_instance
@@ -2145,26 +1988,24 @@ def test_execute_already_exists(
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         assert not mock_log.info.called
         mock_get_backup.assert_called_once()
         assert not mock_create_instance.called
         assert not mock_get_operation_result.called
         assert result == expected_result
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("_get_backup"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(CREATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
     def test_execute_exception(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_get_backup, mock_to_dict, mock_link
+        self,
+        mock_hook,
+        mock_log,
+        mock_get_operation_result,
+        mock_get_backup,
+        mock_to_dict,
     ):
         mock_get_backup.return_value = None
         mock_create_backup = mock_hook.return_value.create_backup
@@ -2174,12 +2015,6 @@ def test_execute_exception(
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
-
         mock_log.info.assert_called_once_with("Creating an AlloyDB backup.")
         mock_get_backup.assert_called_once()
         mock_create_backup.assert_called_once_with(
@@ -2231,12 +2066,11 @@ def test_template_fields(self):
         } | set(AlloyDBWriteBaseOperator.template_fields)
         assert set(AlloyDBUpdateBackupOperator.template_fields) == expected_template_fields
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_backup = mock_hook.return_value.update_backup
         mock_operation = mock_update_backup.return_value
         mock_operation_result = mock_get_operation_result.return_value
@@ -2246,11 +2080,6 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
 
         result = self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_backup.assert_called_once_with(
             backup_id=TEST_BACKUP_ID,
             project_id=TEST_GCP_PROJECT,
@@ -2274,14 +2103,11 @@ def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_d
             ]
         )
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_validate_request(
-        self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link
-    ):
+    def test_execute_validate_request(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_ackup = mock_hook.return_value.update_backup
         mock_operation = mock_update_ackup.return_value
         mock_get_operation_result.return_value = None
@@ -2308,19 +2134,13 @@ def test_execute_validate_request(
         )
         mock_get_operation_result.assert_called_once_with(mock_operation)
         assert not mock_to_dict.called
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
         assert result is None
 
-    @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBBackupsLink"))
     @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Backup.to_dict"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("get_operation_result"))
     @mock.patch(UPDATE_BACKUP_OPERATOR_PATH.format("log"))
     @mock.patch(ALLOY_DB_HOOK_PATH, new_callable=mock.PropertyMock)
-    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link):
+    def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict):
         mock_update_backup = mock_hook.return_value.update_backup
         mock_update_backup.side_effect = Exception
 
@@ -2329,11 +2149,6 @@ def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result,
         with pytest.raises(AirflowException):
             self.operator.execute(context=mock_context)
 
-        mock_link.persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=self.operator,
-            project_id=TEST_GCP_PROJECT,
-        )
         mock_update_backup.assert_called_once_with(
             backup_id=TEST_BACKUP_ID,
             backup=TEST_BACKUP,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_automl.py b/providers/google/tests/unit/google/cloud/operators/test_automl.py
index 7ae70c83c9ed3..34c0397a43906 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_automl.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_automl.py
@@ -91,8 +91,8 @@ def test_execute(self, mock_hook):
             project_id=GCP_PROJECT_ID,
             task_id=TASK_ID,
         )
+        op.execute(context=mock.MagicMock())
 
-        op.execute(context=mock.MagicMock())
         mock_hook.return_value.create_model.assert_called_once_with(
             model=MODEL,
             location=GCP_LOCATION,
@@ -153,9 +153,9 @@ def test_execute(self, mock_hook, mock_link_persist):
         )
         mock_link_persist.assert_called_once_with(
             context=mock_context,
-            task_instance=op,
             model_id=MODEL_ID,
             project_id=GCP_PROJECT_ID,
+            location=GCP_LOCATION,
             dataset_id=DATASET_ID,
         )
 
@@ -218,7 +218,8 @@ def test_execute(self, mock_hook):
             project_id=GCP_PROJECT_ID,
             task_id=TASK_ID,
         )
-        op.execute(context=mock.MagicMock())
+        op.execute(context=mock.MagicMock())
+
         mock_hook.return_value.create_dataset.assert_called_once_with(
             dataset=DATASET,
             location=GCP_LOCATION,
@@ -335,7 +336,9 @@ def test_execute(self, mock_hook):
             project_id=GCP_PROJECT_ID,
             task_id=TASK_ID,
         )
+
         op.execute(context=mock.MagicMock())
+
         mock_hook.return_value.get_model.assert_called_once_with(
             location=GCP_LOCATION,
             metadata=(),
@@ -456,7 +459,9 @@ def test_execute(self, mock_hook):
             input_config=INPUT_CONFIG,
             task_id=TASK_ID,
         )
+
         op.execute(context=mock.MagicMock())
+
         mock_hook.return_value.import_data.assert_called_once_with(
             input_config=INPUT_CONFIG,
             location=GCP_LOCATION,
@@ -532,7 +537,9 @@ class TestAutoMLDatasetListOperator:
     def test_execute(self, mock_hook):
         with pytest.warns(AirflowProviderDeprecationWarning):
             op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
+
         op.execute(context=mock.MagicMock())
+
         mock_hook.return_value.list_datasets.assert_called_once_with(
             location=GCP_LOCATION,
             metadata=(),
diff --git a/providers/google/tests/unit/google/cloud/operators/test_bigquery.py b/providers/google/tests/unit/google/cloud/operators/test_bigquery.py
index 9fbf7536837b4..6d1007fb79115 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_bigquery.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_bigquery.py
@@ -141,7 +141,9 @@ def test_execute(self, mock_hook):
             table_id=TEST_TABLE_ID,
             table_resource={},
         )
+
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
@@ -171,16 +173,6 @@ def test_create_view(self, mock_hook):
             table_resource=body,
         )
         operator.execute(context=MagicMock())
-        mock_hook.return_value.create_table.assert_called_once_with(
-            dataset_id=TEST_DATASET,
-            project_id=TEST_GCP_PROJECT_ID,
-            table_id=TEST_TABLE_ID,
-            schema_fields=None,
-            table_resource=body,
-            exists_ok=False,
-            location=None,
-            timeout=None,
-        )
 
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_create_materialized_view(self, mock_hook):
@@ -201,6 +193,7 @@ def test_create_materialized_view(self, mock_hook):
         )
 
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
@@ -240,6 +233,7 @@ def test_create_clustered_table(self, mock_hook):
         )
 
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
@@ -285,13 +279,13 @@ def test_create_existing_table(self, mock_hook, caplog, if_exists, is_conflict,
             mock_hook.return_value.create_table.side_effect = Conflict("any")
         else:
             mock_hook.return_value.create_table.side_effect = None
-        if expected_error is not None:
-            with pytest.raises(expected_error):
+            if expected_error is not None:
+                with pytest.raises(expected_error):
+                    operator.execute(context=MagicMock())
+            else:
                 operator.execute(context=MagicMock())
-        else:
-            operator.execute(context=MagicMock())
-        if log_msg is not None:
-            assert log_msg in caplog.text
+            if log_msg is not None:
+                assert log_msg in caplog.text
 
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_get_openlineage_facets_on_complete(self, mock_hook):
@@ -316,6 +310,7 @@ def test_get_openlineage_facets_on_complete(self, mock_hook):
             table_id=TEST_TABLE_ID,
             table_resource=table_resource,
         )
+
         operator.execute(context=MagicMock())
 
         mock_hook.return_value.create_table.assert_called_once_with(
@@ -367,7 +362,9 @@ def test_execute(self, mock_hook):
             project_id=TEST_GCP_PROJECT_ID,
             table_id=TEST_TABLE_ID,
         )
+
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_empty_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
@@ -393,7 +390,9 @@ def test_create_view(self, mock_hook):
             table_id=TEST_TABLE_ID,
             view=VIEW_DEFINITION,
         )
+
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_empty_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
@@ -421,6 +420,7 @@ def test_create_materialized_view(self, mock_hook):
         )
 
         operator.execute(context=MagicMock())
+
         mock_hook.return_value.create_empty_table.assert_called_once_with(
             dataset_id=TEST_DATASET,
             project_id=TEST_GCP_PROJECT_ID,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_bigquery_dts.py b/providers/google/tests/unit/google/cloud/operators/test_bigquery_dts.py
index f9188dadd7acb..a4df65dc3c416 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_bigquery_dts.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_bigquery_dts.py
@@ -63,7 +63,7 @@ def test_execute(self, mock_hook):
         )
 
         ti = mock.MagicMock()
-        return_value = op.execute({"ti": ti})
+        return_value = op.execute({"ti": ti, "task": ti.task})
 
         mock_hook.return_value.create_transfer_config.assert_called_once_with(
             authorization_code=None,
@@ -120,7 +120,7 @@ def test_execute(self, mock_hook):
         )
 
         ti = mock.MagicMock()
-        op.execute({"ti": ti})
+        op.execute({"ti": ti, "task": ti.task})
 
         mock_hook.return_value.start_manual_transfer_runs.assert_called_once_with(
             transfer_config_id=TRANSFER_CONFIG_ID,
@@ -154,7 +154,7 @@ def test_defer_mode(self, _, defer_method):
         )
 
         ti = mock.MagicMock()
-        op.execute({"ti": ti})
+        op.execute({"ti": ti, "task": ti.task})
 
         defer_method.assert_called_once()
 
@@ -213,7 +213,7 @@ def test_get_openlineage_facets_on_complete_with_blob_storage_sources(
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute({"ti": mock.MagicMock()})
+        op.execute({"ti": mock.MagicMock(), "task": mock.MagicMock()})
         result = op.get_openlineage_facets_on_complete(None)
         assert not result.run_facets
         assert not result.job_facets
@@ -289,7 +289,7 @@ def test_get_openlineage_facets_on_complete_with_sql_sources(
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute({"ti": mock.MagicMock()})
+        op.execute({"ti": mock.MagicMock(), "task": mock.MagicMock()})
         result = op.get_openlineage_facets_on_complete(None)
         assert not result.run_facets
         assert not result.job_facets
@@ -323,7 +323,7 @@ def test_get_openlineage_facets_on_complete_with_scheduled_query(self, mock_hook
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute({"ti": mock.MagicMock()})
+        op.execute({"ti": mock.MagicMock(), "task": mock.MagicMock()})
         result = op.get_openlineage_facets_on_complete(None)
         assert len(result.job_facets) == 1
         assert result.job_facets["sql"].query == "SELECT a,b,c from x.y.z"
@@ -361,7 +361,7 @@ def test_get_openlineage_facets_on_complete_with_error(self, mock_hook, mock_wai
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute({"ti": mock.MagicMock()})
+        op.execute({"ti": mock.MagicMock(), "task": mock.MagicMock()})
         result = op.get_openlineage_facets_on_complete(None)
         assert not result.job_facets
         assert len(result.run_facets) == 1
@@ -398,7 +398,7 @@ def test_get_openlineage_facets_on_complete_deferred(self, mock_defer, mock_hook
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
             transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID, deferrable=True
         )
-        op.execute({"ti": mock.MagicMock()})
+        op.execute({"ti": mock.MagicMock(), "task": mock.MagicMock()})
         # `defer` is mocked so it will not call the `execute_completed`, so we do it manually.
         op.execute_completed(
             mock.MagicMock(), {"status": "done", "run_id": 123, "config_id": 321, "message": "msg"}
diff --git a/providers/google/tests/unit/google/cloud/operators/test_bigtable.py b/providers/google/tests/unit/google/cloud/operators/test_bigtable.py
index 64a8c74c4cba0..7bb54da1f60b5 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_bigtable.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_bigtable.py
@@ -100,7 +100,7 @@ def test_create_instance_that_exists(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
 
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
@@ -120,7 +120,7 @@ def test_create_instance_that_exists_empty_project_id(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
 
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
@@ -178,7 +178,7 @@ def test_create_instance_that_doesnt_exists(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -210,7 +210,7 @@ def test_create_instance_with_replicas_that_doesnt_exists(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -243,7 +243,7 @@ def test_delete_execute(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -268,7 +268,7 @@ def test_update_execute_empty_project_id(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
@@ -797,7 +797,7 @@ def test_create_execute(self, mock_hook):
             impersonation_chain=IMPERSONATION_CHAIN,
         )
         instance = mock_hook.return_value.get_instance.return_value = mock.Mock(Instance)
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN, diff --git a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py index 36631f923a8b2..b7ca33cae727b 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py +++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py @@ -472,7 +472,7 @@ def test_async_create_build_correct_logging_should_execute_successfully( with mock.patch.object(ti.task.log, "info") as mock_log_info: ti.task.execute_complete( - context={"ti": ti}, + context={"ti": ti, "task": ti.task}, event={ "instance": TEST_BUILD_INSTANCE, "status": "success", diff --git a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py index 07d9a610ca40f..0026dd1f50111 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py +++ b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py @@ -179,13 +179,8 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_entry", - value={ - "entry_id": TEST_ENTRY_ID, - "entry_group_id": TEST_ENTRY_GROUP_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="entry_id", + value=TEST_ENTRY_ID, ) assert result == TEST_ENTRY_DICT @@ -236,13 +231,8 @@ def test_assert_valid_hook_call_when_exists(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_entry", - value={ - "entry_id": TEST_ENTRY_ID, - "entry_group_id": TEST_ENTRY_GROUP_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="entry_id", + value=TEST_ENTRY_ID, ) assert result == TEST_ENTRY_DICT @@ -284,12 +274,8 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_entry_group", - value={ - "entry_group_id": TEST_ENTRY_GROUP_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="entry_group_id", + value=TEST_ENTRY_GROUP_ID, ) assert result == TEST_ENTRY_GROUP_DICT @@ -335,13 +321,8 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_entry", - value={ - "entry_id": TEST_ENTRY_ID, - "entry_group_id": TEST_ENTRY_GROUP_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="tag_id", + value=TEST_TAG_ID, ) assert result == TEST_TAG_DICT @@ -383,12 +364,8 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_tag_template", - value={ - "tag_template_id": TEST_TAG_TEMPLATE_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="tag_template_id", + value=TEST_TAG_TEMPLATE_ID, ) assert result == {**result, **TEST_TAG_TEMPLATE_DICT} @@ -432,12 +409,8 @@ def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None: ) mock_xcom.assert_called_with( context, - key="data_catalog_tag_template", - value={ - "tag_template_id": TEST_TAG_TEMPLATE_ID, - "location_id": TEST_LOCATION, - "project_id": TEST_PROJECT_ID, - }, + key="tag_template_field_id", + value=TEST_TAG_TEMPLATE_FIELD_ID, ) assert result == {**result, **TEST_TAG_TEMPLATE_FIELD_DICT} diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataprep.py 
b/providers/google/tests/unit/google/cloud/operators/test_dataprep.py index 2377184f6acf2..2cdbd8bed490d 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataprep.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataprep.py @@ -120,7 +120,6 @@ def test_execute_with_project_id_will_persist_link_to_job_group( if provide_project_id: link_mock.persist.assert_called_with( context=context, - task_instance=op, project_id=project_id, job_group_id=JOB_ID, ) @@ -228,7 +227,6 @@ def test_execute_with_project_id_will_persist_link_to_flow( if provide_project_id: link_mock.persist.assert_called_with( context=context, - task_instance=op, project_id=project_id, flow_id=NEW_FLOW_ID, ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 986cfba4113f5..00923cb590a1f 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -406,7 +406,8 @@ map_index=1, logical_date=dt.datetime(2024, 11, 11), dag_run=MagicMock(logical_date=dt.datetime(2024, 11, 11), clear_number=0), - ) + ), + "task": mock.MagicMock(), } OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG = { "url": "https://some-custom.url", @@ -459,13 +460,13 @@ def setup_class(cls): def setup_method(self): self.mock_ti = MagicMock() - self.mock_context = {"ti": self.mock_ti} + self.mock_context = {"ti": self.mock_ti, "task": self.mock_ti.task} self.extra_links_manager_mock = Mock() self.extra_links_manager_mock.attach_mock(self.mock_ti, "ti") def tearDown(self): self.mock_ti = MagicMock() - self.mock_context = {"ti": self.mock_ti} + self.mock_context = {"ti": self.mock_ti, "task": self.mock_ti.task} self.extra_links_manager_mock = Mock() self.extra_links_manager_mock.attach_mock(self.mock_ti, "ti") @@ -496,13 +497,14 @@ def setup_class(cls): super().setup_class() if AIRFLOW_V_3_0_PLUS: cls.extra_links_expected_calls_base = [ - call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED) + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED), ] else: cls.extra_links_expected_calls_base = [ - call.ti.xcom_push( - key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, execution_date=None - ) + call.ti.task.extra_links_params.__bool__(), + call.ti.task.extra_links_params.keys(), + call.ti.task.extra_links_params.keys().__iter__(), + call.ti.xcom_push(key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED), ] @@ -818,7 +820,6 @@ def test_execute(self, mock_hook, to_dict_mock): self.mock_ti.xcom_push.assert_called_once_with( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -874,7 +875,6 @@ def test_execute_in_gke(self, mock_hook, to_dict_mock): self.mock_ti.xcom_push.assert_called_once_with( key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -1271,12 +1271,7 @@ def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): - if AIRFLOW_V_3_0_PLUS: - xcom_push_call = call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) - else: - xcom_push_call = call.ti.xcom_push( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + xcom_push_call 
= call.ti.xcom_push(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) wait_for_job_call = call.hook().wait_for_job( job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None ) @@ -1322,12 +1317,7 @@ def test_execute(self, mock_hook): job_id=TEST_JOB_ID, project_id=GCP_PROJECT, region=GCP_REGION, timeout=None ) - if AIRFLOW_V_3_0_PLUS: - self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) - else: - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): @@ -1365,12 +1355,7 @@ def test_execute_async(self, mock_hook): ) mock_hook.return_value.wait_for_job.assert_not_called() - if AIRFLOW_V_3_0_PLUS: - self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) - else: - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_job", value=DATAPROC_JOB_EXPECTED, execution_date=None - ) + self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job", value=DATAPROC_JOB_EXPECTED) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook")) @@ -1481,7 +1466,8 @@ def test_execute_openlineage_parent_job_info_injection( try_number=1, map_index=1, logical_date=dt.datetime(2024, 11, 11), - ) + ), + "task": MagicMock(), } mock_ol_accessible.return_value = True @@ -2089,17 +2075,10 @@ def test_execute(self, mock_hook): # Test whether the xcom push happens before updating the cluster self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - if AIRFLOW_V_3_0_PLUS: - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - ) - else: - self.mock_ti.xcom_push.assert_called_once_with( - key="dataproc_cluster", - value=DATAPROC_CLUSTER_EXPECTED, - execution_date=None, - ) + self.mock_ti.xcom_push.assert_called_once_with( + key="dataproc_cluster", + value=DATAPROC_CLUSTER_EXPECTED, + ) def test_missing_region_parameter(self): with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'region'"): diff --git a/providers/google/tests/unit/google/cloud/operators/test_datastore.py b/providers/google/tests/unit/google/cloud/operators/test_datastore.py index c4197a2a69cb8..67f0e55a64a5b 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_datastore.py +++ b/providers/google/tests/unit/google/cloud/operators/test_datastore.py @@ -58,7 +58,7 @@ def test_execute(self, mock_hook): project_id=PROJECT_ID, bucket=BUCKET, ) - op.execute(context={"ti": mock.MagicMock()}) + op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) mock_hook.return_value.export_to_storage_bucket.assert_called_once_with( @@ -87,7 +87,7 @@ def test_execute(self, mock_hook): bucket=BUCKET, file=FILE, ) - op.execute(context={"ti": mock.MagicMock()}) + op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()}) mock_hook.assert_called_once_with(CONN_ID, impersonation_chain=None) mock_hook.return_value.import_from_storage_bucket.assert_called_once_with( @@ -112,7 +112,7 @@ def test_execute(self, mock_hook): project_id=PROJECT_ID, partial_keys=partial_keys, ) - op.execute(context={"ti": mock.MagicMock()}) + 
op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) mock_hook.return_value.allocate_ids.assert_called_once_with( @@ -143,7 +143,7 @@ def test_execute(self, mock_hook): op = CloudDatastoreCommitOperator( task_id="test_task", gcp_conn_id=CONN_ID, project_id=PROJECT_ID, body=BODY ) - op.execute(context={"ti": mock.MagicMock()}) + op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=CONN_ID, impersonation_chain=None) mock_hook.return_value.commit.assert_called_once_with(project_id=PROJECT_ID, body=BODY) diff --git a/providers/google/tests/unit/google/cloud/operators/test_functions.py b/providers/google/tests/unit/google/cloud/operators/test_functions.py index 679f69b1c92be..fa859894b746e 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_functions.py +++ b/providers/google/tests/unit/google/cloud/operators/test_functions.py @@ -730,11 +730,7 @@ def test_execute(self, mock_gcf_hook, mock_xcom): ) mock_xcom.assert_called_with( - context, - key="cloud_functions_details", - value={ - "location": GCP_LOCATION, - "function_name": function_id, - "project_id": GCP_PROJECT_ID, - }, + context=context, + key="execution_id", + value=exec_id, ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_gcs.py b/providers/google/tests/unit/google/cloud/operators/test_gcs.py index a5eb7ddafe4fe..48e9561ac537f 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_gcs.py +++ b/providers/google/tests/unit/google/cloud/operators/test_gcs.py @@ -401,6 +401,7 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir): data_interval_start=timespan_start, data_interval_end=timespan_end, ti=mock_ti, + task=mock.MagicMock(), ) mock_tempdir.return_value.__enter__.side_effect = [source, destination] @@ -581,6 +582,7 @@ def test_get_openlineage_facets_on_complete( data_interval_start=timespan_start, data_interval_end=timespan_end, ti=mock.Mock(), + task=mock.MagicMock(), ) mock_tempdir.return_value.__enter__.side_effect = ["source", destination] diff --git a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py index 510db752b3d60..a8fafd6af5dc1 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py @@ -483,9 +483,7 @@ def test_execute(self, mock_cluster_hook, mock_link): result = self.operator.execute(context=mock_context) - mock_link.persist.assert_called_once_with( - context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT - ) + mock_link.persist.assert_called_once_with(context=mock_context, cluster=GKE_CLUSTER_CREATE_BODY_DICT) mock_create_cluster.assert_called_once_with( cluster=GKE_CLUSTER_CREATE_BODY_DICT, project_id=TEST_PROJECT_ID, @@ -506,9 +504,7 @@ def test_execute_error(self, mock_cluster_hook, mock_link, mock_log): result = self.operator.execute(context=mock_context) - mock_link.persist.assert_called_once_with( - context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT - ) + mock_link.persist.assert_called_once_with(context=mock_context, cluster=GKE_CLUSTER_CREATE_BODY_DICT) mock_create_cluster.assert_called_once_with( cluster=GKE_CLUSTER_CREATE_BODY_DICT, project_id=TEST_PROJECT_ID, @@ -535,9 +531,7 @@ def 
test_deferrable(self, mock_cluster_hook, mock_defer, mock_link, mock_trigger
         self.operator.execute(context=mock_context)
-        mock_link.persist.assert_called_once_with(
-            context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT
-        )
+        mock_link.persist.assert_called_once_with(context=mock_context, cluster=GKE_CLUSTER_CREATE_BODY_DICT)
         mock_create_cluster.assert_called_once_with(
             cluster=GKE_CLUSTER_CREATE_BODY_DICT,
             project_id=TEST_PROJECT_ID,
@@ -621,7 +615,6 @@ def test_execute(self, mock_hook, mock_link, mock_super):
         )
         mock_link.persist.assert_called_once_with(
             context=mock_context,
-            task_instance=self.operator,
             cluster=mock_cluster,
         )
         mock_check_cluster_autoscaling_ability.assert_called_once_with(cluster=mock_cluster)
@@ -647,7 +640,6 @@ def test_execute_not_scalable(self, mock_hook, mock_link, mock_super, mock_log):
         )
         mock_link.persist.assert_called_once_with(
             context=mock_context,
-            task_instance=self.operator,
             cluster=mock_cluster,
         )
         mock_check_cluster_autoscaling_ability.assert_called_once_with(cluster=mock_cluster)
@@ -933,7 +925,14 @@ def test_execute(self, mock_cluster_hook, mock_hook, mock_fetch_cluster_info, mo
             GKE_CLUSTER_NAME,
             mock_job,
         )
-        mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator)
+        mock_link.persist.assert_called_once_with(
+            context=mock_context,
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            namespace=mock_job.metadata.namespace,
+            job_name=mock_job.metadata.name,
+        )
 
 
 class TestGKEListJobsOperator:
@@ -949,12 +948,11 @@ def test_template_fields(self):
         expected_template_fields = {"namespace"} | set(GKEOperatorMixin.template_fields)
         assert set(GKEListJobsOperator.template_fields) == expected_template_fields
 
-    @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink"))
     @mock.patch(GKE_OPERATORS_PATH.format("V1JobList.to_dict"))
     @mock.patch(GKE_OPERATORS_PATH.format("GKEListJobsOperator.log"))
     @mock.patch(GKE_OPERATORS_PATH.format("GKEHook"))
     @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook"))
-    def test_execute(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_link):
+    def test_execute(self, mock_hook, cluster_hook, mock_log, mock_to_dict):
         mock_list_jobs_from_namespace = mock_hook.return_value.list_jobs_from_namespace
         mock_list_jobs_all_namespaces = mock_hook.return_value.list_jobs_all_namespaces
         mock_job_1, mock_job_2 = mock.MagicMock(), mock.MagicMock()
@@ -963,7 +961,7 @@ def test_execute(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_lin
         mock_to_dict_value = mock_to_dict.return_value
 
         mock_ti = mock.MagicMock()
-        context = {"ti": mock_ti}
+        context = {"ti": mock_ti, "task": mock.MagicMock()}
 
         result = self.operator.execute(context=context)
 
@@ -976,8 +974,9 @@ def test_execute(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_lin
             ]
         )
         mock_to_dict.assert_has_calls([call(mock_jobs), call(mock_jobs)])
-        mock_ti.xcom_push.assert_called_once_with(key="jobs_list", value=mock_to_dict_value)
-        mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator)
+        mock_ti.xcom_push.assert_has_calls(
+            [call(key="jobs_list", value=mock_to_dict_value), call(key="kubernetes_workloads_conf", value={})]
+        )
         assert result == mock_to_dict_value
 
     @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink"))
@@ -994,7 +993,7 @@ def test_execute_namespaced(self, mock_hook, cluster_hook, mock_log, mock_to_dic
         mock_to_dict_value = mock_to_dict.return_value
 
         mock_ti = mock.MagicMock()
-        context = {"ti": mock_ti}
+        context = {"ti": mock_ti, "task": mock.MagicMock()}
         self.operator.namespace = K8S_NAMESPACE
 
         result = self.operator.execute(context=context)
@@ -1009,7 +1008,7 @@ def test_execute_namespaced(self, mock_hook, cluster_hook, mock_log, mock_to_dic
         )
         mock_to_dict.assert_has_calls([call(mock_jobs), call(mock_jobs)])
         mock_ti.xcom_push.assert_called_once_with(key="jobs_list", value=mock_to_dict_value)
-        mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator)
+        mock_link.persist.assert_called_once_with(context=context)
         assert result == mock_to_dict_value
 
     @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink"))
@@ -1026,7 +1025,7 @@ def test_execute_not_do_xcom_push(self, mock_hook, cluster_hook, mock_log, mock_
         mock_to_dict_value = mock_to_dict.return_value
 
         mock_ti = mock.MagicMock()
-        context = {"ti": mock_ti}
+        context = {"ti": mock_ti, "task": mock.MagicMock()}
         self.operator.do_xcom_push = False
 
         result = self.operator.execute(context=context)
@@ -1040,8 +1039,7 @@ def test_execute_not_do_xcom_push(self, mock_hook, cluster_hook, mock_log, mock_
             ]
         )
         mock_to_dict.assert_called_once_with(mock_jobs)
-        mock_ti.xcom_push.assert_not_called()
-        mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator)
+        mock_link.persist.assert_called_once_with(context=context)
         assert result == mock_to_dict_value
 
 
@@ -1188,7 +1186,14 @@ def test_execute(self, mock_cluster_hook, mock_hook, mock_log, mock_link, mock_t
             K8S_JOB_NAME,
             GKE_CLUSTER_NAME,
         )
-        mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator)
+        mock_link.persist.assert_called_once_with(
+            context=mock_context,
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            namespace=mock_job.metadata.namespace,
+            job_name=mock_job.metadata.name,
+        )
         mock_to_dict.assert_called_once_with(mock_job)
         assert result == expected_result
 
@@ -1231,6 +1236,13 @@ def test_execute(self, mock_cluster_hook, mock_hook, mock_log, mock_link, mock_t
             K8S_JOB_NAME,
             GKE_CLUSTER_NAME,
         )
-        mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator)
+        mock_link.persist.assert_called_once_with(
+            context=mock_context,
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            namespace=mock_job.metadata.namespace,
+            job_name=mock_job.metadata.name,
+        )
         mock_to_dict.assert_called_once_with(mock_job)
         assert result == expected_result
diff --git a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
index fd41068201439..5e99b0f0633fc 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
@@ -104,7 +104,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_cluster.assert_called_once_with(
             location=GCP_LOCATION,
@@ -142,7 +142,7 @@ def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_clusters.assert_called_once_with(
             location=GCP_LOCATION,
@@ -172,7 +172,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_cluster.assert_called_once_with(
             location=GCP_LOCATION,
@@ -202,7 +202,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.update_cluster.assert_called_once_with(
             project_id=GCP_PROJECT,
@@ -262,7 +262,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_topic.assert_called_once_with(
             location=GCP_LOCATION,
@@ -297,7 +297,7 @@ def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_topics.assert_called_once_with(
             location=GCP_LOCATION,
@@ -327,7 +327,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_topic.assert_called_once_with(
             location=GCP_LOCATION,
@@ -358,7 +358,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.update_topic.assert_called_once_with(
             project_id=GCP_PROJECT,
@@ -422,7 +422,7 @@ def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_consumer_groups.assert_called_once_with(
             location=GCP_LOCATION,
@@ -452,7 +452,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_consumer_group.assert_called_once_with(
             location=GCP_LOCATION,
@@ -483,7 +483,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.update_consumer_group.assert_called_once_with(
             project_id=GCP_PROJECT,
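Every `execute`/`poke` call site in these test files now builds the same two-key context by hand. A tiny helper like the one below, which is not part of this change and purely a sketch, would keep the call sites uniform:

from unittest import mock


def make_context() -> dict:
    # "ti" is read for XCom pushes, "task" for the operator's link parameters.
    return {"ti": mock.MagicMock(), "task": mock.MagicMock()}


context = make_context()
# op.execute(context=context)  # drop-in for the inline dicts in these hunks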
diff --git a/providers/google/tests/unit/google/cloud/operators/test_translate.py b/providers/google/tests/unit/google/cloud/operators/test_translate.py
index 1d5ef328ae360..957bd3ddd1b5b 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_translate.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_translate.py
@@ -207,7 +207,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
 
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
             output_config=OUTPUT_CONFIG,
         )
@@ -268,7 +267,7 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist):
         mock_link_persist.assert_called_once_with(
             context=context,
             dataset_id=DATASET_ID,
-            task_instance=op,
+            location=LOCATION,
             project_id=PROJECT_ID,
         )
         assert result == DS_CREATION_RESULT_SAMPLE
@@ -319,7 +318,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         )
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
         )
         assert result == [DS_ID_1, DS_ID_2]
@@ -362,7 +360,7 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         mock_link_persist.assert_called_once_with(
             context=context,
             dataset_id=DATASET_ID,
-            task_instance=op,
+            location=LOCATION,
             project_id=PROJECT_ID,
         )
 
@@ -455,10 +453,10 @@ def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist):
         mock_xcom_push.assert_called_once_with(context, key="model_id", value=MODEL_ID)
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             model_id=MODEL_ID,
             project_id=PROJECT_ID,
             dataset_id=DATASET_ID,
+            location=LOCATION,
         )
         assert result == MODEL_CREATION_RESULT_SAMPLE
@@ -515,7 +513,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         assert result == [MODEL_ID_1, MODEL_ID_2]
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
         )
 
@@ -634,7 +631,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         assert result == BATCH_DOC_TRANSLATION_RESULT
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
             output_config=OUTPUT_CONFIG,
         )
@@ -709,7 +705,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         assert result == DOC_TRANSLATION_RESULT
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
             output_config=OUTPUT_CONFIG,
         )
@@ -884,7 +879,6 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist):
         assert result == [GLOSSARY_ID_1, GLOSSARY_ID_2]
         mock_link_persist.assert_called_once_with(
             context=context,
-            task_instance=op,
             project_id=PROJECT_ID,
         )
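In the translate hunks, dropping `task_instance` means every value a link needs to render its console URL must now arrive as an explicit keyword, with `location` replacing what the operator instance used to supply. A link then reduces to a URL template over exactly those keys; the template below is illustrative and not the provider's actual format string:

TRANSLATION_DATASET_URL = (
    "https://console.cloud.google.com/translation/locations/{location}"
    "/datasets/{dataset_id}/sentences?project={project_id}"
)


def dataset_link_url(*, location: str, dataset_id: str, project_id: str) -> str:
    # One keyword per template field - the same keys the assertions above pass.
    return TRANSLATION_DATASET_URL.format(
        location=location, dataset_id=dataset_id, project_id=project_id
    )


assert dataset_link_url(
    location="us-central1", dataset_id="ds-1", project_id="test-project"
).startswith("https://console.cloud.google.com/translation/")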
diff --git a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
index 80f488a334087..67d11428cb76b 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
@@ -264,7 +264,7 @@ def test_execute(self, mock_hook, mock_dataset):
             dataset_id=TEST_DATASET_ID,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
@@ -352,7 +352,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             dataset_id=TEST_DATASET_ID,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_custom_container_training_job.assert_called_once_with(
             staging_bucket=STAGING_BUCKET,
             display_name=DISPLAY_NAME,
@@ -433,7 +433,7 @@ def test_execute_enters_deferred_state(self, mock_hook):
         )
         mock_hook.return_value.exists.return_value = False
         with pytest.raises(TaskDeferred) as exc:
-            task.execute(context={"ti": mock.MagicMock()})
+            task.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         assert isinstance(exc.value.trigger, CustomContainerTrainingJobTrigger), (
             "Trigger is not a CustomContainerTrainingJobTrigger"
         )
@@ -480,7 +480,7 @@ def test_execute_complete_success(
             },
         )
         mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
-        mock_link_persist.assert_called_once_with(context=None, task_instance=task, model_id="test-model")
+        mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
         assert actual_result == expected_result
 
     def test_execute_complete_error_status_raises_exception(self):
@@ -586,7 +586,7 @@ def test_execute(self, mock_hook, mock_dataset):
             dataset_id=TEST_DATASET_ID,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
@@ -676,7 +676,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             dataset_id=TEST_DATASET_ID,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_custom_python_package_training_job.assert_called_once_with(
             staging_bucket=STAGING_BUCKET,
             display_name=DISPLAY_NAME,
@@ -759,7 +759,7 @@ def test_execute_enters_deferred_state(self, mock_hook):
         )
         mock_hook.return_value.exists.return_value = False
         with pytest.raises(TaskDeferred) as exc:
-            task.execute(context={"ti": mock.MagicMock()})
+            task.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         assert isinstance(exc.value.trigger, CustomPythonPackageTrainingJobTrigger), (
             "Trigger is not a CustomPythonPackageTrainingJobTrigger"
         )
@@ -811,7 +811,7 @@ def test_execute_complete_success(
             },
         )
         mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
-        mock_link_persist.assert_called_once_with(context=None, task_instance=task, model_id="test-model")
+        mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
         assert actual_result == expected_result
 
     def test_execute_complete_error_status_raises_exception(self):
@@ -911,7 +911,7 @@ def test_execute(self, mock_hook, mock_dataset):
             dataset_id=TEST_DATASET_ID,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_dataset.assert_called_once_with(name=TEST_DATASET_ID)
         mock_hook.return_value.create_custom_training_job.assert_called_once_with(
@@ -994,7 +994,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             dataset_id=TEST_DATASET_ID,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_custom_training_job.assert_called_once_with(
             staging_bucket=STAGING_BUCKET,
             display_name=DISPLAY_NAME,
@@ -1070,7 +1070,7 @@ def test_execute_enters_deferred_state(self, mock_hook):
         )
         mock_hook.return_value.exists.return_value = False
         with pytest.raises(TaskDeferred) as exc:
-            task.execute(context={"ti": mock.MagicMock()})
+            task.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         assert isinstance(exc.value.trigger, CustomTrainingJobTrigger), (
             "Trigger is not a CustomTrainingJobTrigger"
        )
@@ -1113,7 +1113,7 @@ def test_execute_complete_success(
             },
         )
         mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
-        mock_link_persist.assert_called_once_with(context=None, task_instance=task, model_id="test-model")
+        mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
         assert actual_result == expected_result
 
     def test_execute_complete_error_status_raises_exception(self):
@@ -1255,7 +1255,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_training_pipelines.assert_called_once_with(
             region=GCP_LOCATION,
@@ -1285,7 +1285,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_dataset.assert_called_once_with(
             region=GCP_LOCATION,
@@ -1407,7 +1407,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_datasets.assert_called_once_with(
             region=GCP_LOCATION,
@@ -1479,7 +1479,7 @@ def test_execute(self, mock_hook, mock_dataset):
             parent_model=TEST_PARENT_MODEL,
             holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
         mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with(
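The `test_execute__parent_model_version_index_is_removed` cases above and below keep asserting that a versioned parent model such as `.../models/my-model@1` reaches the hook without the `@1` suffix. The normalization they exercise amounts to something like the sketch below, which is illustrative and not the provider's code:

def strip_version_index(parent_model: str) -> str:
    # "projects/p/locations/l/models/m@2" -> "projects/p/locations/l/models/m"
    return parent_model.rpartition("@")[0] if "@" in parent_model else parent_model


assert strip_version_index("projects/p/locations/l/models/m@2").endswith("/models/m")
assert strip_version_index("projects/p/locations/l/models/m").endswith("/models/m")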
@@ -1550,7 +1550,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             parent_model=VERSIONED_TEST_PARENT_MODEL,
             holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,
@@ -1615,7 +1615,7 @@ def test_execute(self, mock_hook, mock_dataset):
             project_id=GCP_PROJECT,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
         mock_hook.return_value.create_auto_ml_image_training_job.assert_called_once_with(
@@ -1665,7 +1665,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             project_id=GCP_PROJECT,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_auto_ml_image_training_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,
@@ -1719,7 +1719,7 @@ def test_execute(self, mock_hook, mock_dataset):
             project_id=GCP_PROJECT,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_dataset.assert_called_once_with(
             dataset_name=TEST_DATASET_ID, project=GCP_PROJECT, credentials="creds"
         )
@@ -1781,7 +1781,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             project_id=GCP_PROJECT,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_auto_ml_tabular_training_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,
@@ -1836,7 +1836,7 @@ def test_execute(self, mock_hook, mock_dataset):
             project_id=GCP_PROJECT,
             parent_model=TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_dataset.assert_called_once_with(dataset_name=TEST_DATASET_ID)
         mock_hook.return_value.create_auto_ml_video_training_job.assert_called_once_with(
@@ -1879,7 +1879,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da
             project_id=GCP_PROJECT,
             parent_model=VERSIONED_TEST_PARENT_MODEL,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.return_value.create_auto_ml_video_training_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,
@@ -1975,7 +1975,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_training_pipelines.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2010,7 +2010,7 @@ def test_execute(self, mock_hook, mock_link_persist):
             create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
             batch_size=TEST_BATCH_SIZE,
         )
-        context = {"ti": mock.MagicMock()}
+        context = {"ti": mock.MagicMock(), "task": mock.MagicMock()}
         op.execute(context=context)
 
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
@@ -2042,7 +2042,8 @@ def test_execute(self, mock_hook, mock_link_persist):
         mock_job.wait_for_completion.assert_called_once()
         mock_job.to_dict.assert_called_once()
         mock_link_persist.assert_called_once_with(
-            context=context, task_instance=op, batch_prediction_job_id=TEST_BATCH_PREDICTION_JOB_ID
+            context=context,
+            batch_prediction_job_id=TEST_BATCH_PREDICTION_JOB_ID,
         )
 
     @mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIBatchPredictionJobLink.persist"))
@@ -2064,7 +2065,7 @@ def test_execute_deferrable(self, mock_hook, mock_link_persist):
             batch_size=TEST_BATCH_SIZE,
             deferrable=True,
         )
-        context = {"ti": mock.MagicMock()}
+        context = {"ti": mock.MagicMock(), "task": mock.MagicMock()}
         with (
             pytest.raises(TaskDeferred) as exception_info,
             pytest.warns(
@@ -2105,7 +2106,6 @@ def test_execute_deferrable(self, mock_hook, mock_link_persist):
         mock_link_persist.assert_called_once_with(
             batch_prediction_job_id=TEST_BATCH_PREDICTION_JOB_ID,
             context=context,
-            task_instance=op,
         )
         assert hasattr(exception_info.value, "trigger")
         assert exception_info.value.trigger.conn_id == GCP_CONN_ID
@@ -2207,7 +2207,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_batch_prediction_jobs.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2238,7 +2238,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_endpoint.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2294,7 +2294,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.deploy_model.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2332,7 +2332,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_endpoints.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2396,7 +2396,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             max_trial_count=15,
             parallel_trial_count=3,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_hyperparameter_tuning_job.assert_called_once_with(
             project_id=GCP_PROJECT,
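One detail worth noting in the batch-prediction hunks above: the deferrable test asserts `batch_prediction_job_id` before `context`, while the synchronous test uses the opposite order. That is purely cosmetic, since `assert_called_once_with` compares keyword arguments as a dict:

from unittest import mock

persist = mock.MagicMock()
persist(context={"ti": None}, batch_prediction_job_id="job-1")

# Both orderings pass against the same recorded call.
persist.assert_called_once_with(batch_prediction_job_id="job-1", context={"ti": None})
persist.assert_called_once_with(context={"ti": None}, batch_prediction_job_id="job-1")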
@@ -2446,7 +2446,7 @@ def test_deferrable(self, mock_hook, mock_defer):
             parallel_trial_count=3,
             deferrable=True,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_defer.assert_called_once()
 
     @mock.patch(VERTEX_AI_PATH.format("hyperparameter_tuning_job.HyperparameterTuningJobHook"))
@@ -2517,7 +2517,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             project_id=GCP_PROJECT,
             hyperparameter_tuning_job_id=TEST_HYPERPARAMETER_TUNING_JOB_ID,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_hyperparameter_tuning_job.assert_called_once_with(
             project_id=GCP_PROJECT,
@@ -2574,7 +2574,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_hyperparameter_tuning_jobs.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2604,7 +2604,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.export_model.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2668,7 +2668,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_models.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2699,7 +2699,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.upload_model.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2726,7 +2726,7 @@ def test_execute_with_parent_model(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.upload_model.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2754,7 +2754,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_model.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2781,7 +2781,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_model_versions.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2808,7 +2808,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.set_version_as_default.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2836,7 +2836,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.add_version_aliases.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2865,7 +2865,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.delete_version_aliases.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2893,7 +2893,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.delete_model_version.assert_called_once_with(
             region=GCP_LOCATION,
@@ -2930,7 +2930,7 @@ def test_execute(self, to_dict_mock, mock_hook):
             create_request_timeout=None,
             experiment=None,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.submit_pipeline_job.assert_called_once_with(
             project_id=GCP_PROJECT,
@@ -2966,7 +2966,7 @@ def test_execute_enters_deferred_state(self, mock_hook):
         )
         mock_hook.return_value.exists.return_value = False
         with pytest.raises(TaskDeferred) as exc:
-            task.execute(context={"ti": mock.MagicMock()})
+            task.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         assert isinstance(exc.value.trigger, RunPipelineJobTrigger), "Trigger is not a RunPipelineJobTrigger"
 
     @mock.patch(VERTEX_AI_PATH.format("pipeline_job.RunPipelineJobOperator.xcom_push"))
@@ -3023,7 +3023,7 @@ def test_execute(self, mock_hook, to_dict_mock):
             project_id=GCP_PROJECT,
             pipeline_job_id=TEST_PIPELINE_JOB_ID,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_pipeline_job.assert_called_once_with(
             project_id=GCP_PROJECT,
@@ -3080,7 +3080,7 @@ def test_execute(self, mock_hook):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_pipeline_jobs.assert_called_once_with(
             region=GCP_LOCATION,
@@ -3118,7 +3118,7 @@ def test_execute(self, mock_hook):
             reserved_ip_ranges=None,
             labels=None,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_ray_cluster.assert_called_once_with(
             location=GCP_LOCATION,
@@ -3149,7 +3149,7 @@ def test_execute(self, mock_hook):
             location=GCP_LOCATION,
             project_id=GCP_PROJECT,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.list_ray_clusters.assert_called_once_with(
             location=GCP_LOCATION,
@@ -3168,7 +3168,7 @@ def test_execute(self, mock_hook):
             location=GCP_LOCATION,
             project_id=GCP_PROJECT,
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.get_ray_cluster.assert_called_once_with(
             location=GCP_LOCATION,
@@ -3189,7 +3189,7 @@ def test_execute(self, mock_hook):
             cluster_id=TEST_CLUSTER_ID,
             worker_node_types=[TEST_NODE_RESOURCES],
         )
-        op.execute(context={"ti": mock.MagicMock()})
+        op.execute(context={"ti": mock.MagicMock(), "task": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.update_ray_cluster.assert_called_once_with(
             project_id=GCP_PROJECT,
diff --git a/providers/google/tests/unit/google/cloud/operators/test_workflows.py b/providers/google/tests/unit/google/cloud/operators/test_workflows.py
index c0d854c91c759..7a48e2a8acbfd 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_workflows.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_workflows.py
@@ -248,8 +248,8 @@ def test_execute(self, mock_hook, mock_object):
 class TestWorkflowExecutionsCreateExecutionOperator:
     @mock.patch(BASE_PATH.format("Execution"))
     @mock.patch(BASE_PATH.format("WorkflowsHook"))
-    @mock.patch(BASE_PATH.format("WorkflowsCreateExecutionOperator.xcom_push"))
-    def test_execute(self, mock_xcom, mock_hook, mock_object):
+    @mock.patch(BASE_PATH.format("WorkflowsExecutionLink.persist"))
+    def test_execute(self, mock_link_persist, mock_hook, mock_object):
         mock_hook.return_value.create_execution.return_value.name = "name/execution_id"
         op = WorkflowsCreateExecutionOperator(
             task_id="test_task",
@@ -280,15 +280,12 @@ def test_execute(self, mock_xcom, mock_hook, mock_object):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
-        mock_xcom.assert_called_with(
-            context,
-            key="workflow_execution",
-            value={
-                "location_id": LOCATION,
-                "workflow_id": WORKFLOW_ID,
-                "execution_id": EXECUTION_ID,
-                "project_id": PROJECT_ID,
-            },
+        mock_link_persist.assert_called_with(
+            context=context,
+            location_id=LOCATION,
+            workflow_id=WORKFLOW_ID,
+            execution_id=EXECUTION_ID,
+            project_id=PROJECT_ID,
         )
 
         assert result == mock_object.to_dict.return_value
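The workflows change swaps an assertion on a raw `xcom_push` for one on `WorkflowsExecutionLink.persist`, which performs that push internally. The removed lines pin down the payload, so the equivalence looks roughly like the toy below; the class is a stand-in, with the XCom key taken from the removed assertion:

from unittest import mock


class FakeWorkflowsExecutionLink:
    key = "workflow_execution"  # key from the removed xcom_push assertion

    @classmethod
    def persist(cls, context, **coords):
        context["ti"].xcom_push(key=cls.key, value=coords)


ti = mock.MagicMock()
FakeWorkflowsExecutionLink.persist(
    context={"ti": ti, "task": mock.MagicMock()},
    location_id="us-west1",
    workflow_id="workflow-id",
    execution_id="execution-id",
    project_id="test-project",
)
ti.xcom_push.assert_called_once_with(
    key="workflow_execution",
    value={
        "location_id": "us-west1",
        "workflow_id": "workflow-id",
        "execution_id": "execution-id",
        "project_id": "test-project",
    },
)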
diff --git a/providers/google/tests/unit/google/cloud/sensors/test_cloud_storage_transfer_service.py b/providers/google/tests/unit/google/cloud/sensors/test_cloud_storage_transfer_service.py
index 1bc464ee23b75..4312b43df57bb 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_storage_transfer_service.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_storage_transfer_service.py
@@ -62,7 +62,7 @@ def test_wait_for_status_success(self, mock_tool):
             expected_statuses=GcpTransferOperationStatus.SUCCESS,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         result = op.poke(context)
 
         mock_tool.return_value.list_transfer_operations.assert_called_once_with(
@@ -96,7 +96,7 @@ def test_wait_for_status_success_without_project_id(self, mock_tool):
             expected_statuses=GcpTransferOperationStatus.SUCCESS,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         result = op.poke(context)
 
         mock_tool.return_value.list_transfer_operations.assert_called_once_with(
@@ -118,7 +118,7 @@ def test_wait_for_status_success_default_expected_status(self, mock_tool):
             expected_statuses=GcpTransferOperationStatus.SUCCESS,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         result = op.poke(context)
 
@@ -162,7 +162,7 @@ def test_wait_for_status_after_retry(self, mock_tool):
             expected_statuses=GcpTransferOperationStatus.SUCCESS,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         result = op.poke(context)
 
         assert not result
@@ -214,7 +214,7 @@ def test_wait_for_status_normalize_status(self, mock_tool, expected_status, rece
             expected_statuses=expected_status,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         result = op.poke(context)
 
         assert not result
@@ -240,7 +240,7 @@ def test_job_status_sensor_finish_before_deferred(self, mock_defer, mock_hook):
         )
         mock_hook.operations_contain_expected_statuses.return_value = True
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         op.execute(context)
 
         assert not mock_defer.called
@@ -258,7 +258,7 @@ def test_execute_deferred(self, mock_hook):
         )
         mock_hook.operations_contain_expected_statuses.return_value = False
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         with pytest.raises(TaskDeferred) as exc:
             op.execute(context)
 
@@ -273,7 +273,7 @@ def test_execute_deferred_failure(self):
             deferrable=True,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         with pytest.raises(AirflowException):
             op.execute_complete(context=context, event={"status": "error", "message": "test failure message"})
 
@@ -287,6 +287,6 @@ def test_execute_complete(self):
             deferrable=True,
         )
 
-        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None}))}
+        context = {"ti": (mock.Mock(**{"xcom_push.return_value": None})), "task": mock.MagicMock()}
         op.execute_complete(context=context, event={"status": "success", "operations": []})
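The sensor tests build their `ti` with Mock's keyword-configuration shorthand, where dotted names in the kwargs configure child mocks. That is all `{"xcom_push.return_value": None}` does, as this snippet shows:

from unittest import mock

ti = mock.Mock(**{"xcom_push.return_value": None})
assert ti.xcom_push(key="k", value="v") is None  # configured child mock

context = {"ti": ti, "task": mock.MagicMock()}  # shape used throughout the sensor tests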
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
index e4fd89732467c..fb6023e61916b 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -81,7 +81,6 @@ def test_persist_links(self, mock_link):
         mock_link.persist.assert_called_once_with(
             context=mock_context,
-            task_instance=operator,
             dataset_id=TEST_DATASET,
             project_id=TEST_PROJECT,
             table_id=TEST_TABLE_ID,