Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING

from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.version_compat import BaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING

from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.version_compat import BaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
21 changes: 11 additions & 10 deletions providers/google/src/airflow/providers/google/cloud/links/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,23 @@

from typing import TYPE_CHECKING, ClassVar

from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.sdk import BaseSensorOperator
from airflow.utils.context import Context
from airflow.providers.google.version_compat import (
AIRFLOW_V_3_0_PLUS,
BaseOperator,
BaseOperatorLink,
BaseSensorOperator,
)

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
from airflow.sdk.execution_time.xcom import XCom
else:
from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
from airflow.models.xcom import XCom # type: ignore[no-redef]

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.utils.context import Context

BASE_LINK = "https://console.cloud.google.com"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.google.version_compat import (
AIRFLOW_V_3_0_PLUS,
BaseOperator,
BaseOperatorLink,
)

if TYPE_CHECKING:
from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
from airflow.sdk.execution_time.xcom import XCom
else:
from airflow.models import XCom # type: ignore[no-redef]
from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
from airflow.models.xcom import XCom # type: ignore[no-redef]


def __getattr__(name: str) -> Any:
Expand Down Expand Up @@ -94,16 +95,16 @@ class DataprocLink(BaseOperatorLink):
@staticmethod
def persist(
context: Context,
task_instance,
url: str,
resource: str,
region: str,
project_id: str,
):
task_instance.xcom_push(
context=context,
context["task_instance"].xcom_push(
key=DataprocLink.key,
value={
"region": task_instance.region,
"project_id": task_instance.project_id,
"region": region,
"project_id": project_id,
"url": url,
"resource": resource,
},
Expand Down Expand Up @@ -147,14 +148,13 @@ class DataprocListLink(BaseOperatorLink):
@staticmethod
def persist(
context: Context,
task_instance,
url: str,
project_id: str,
):
task_instance.xcom_push(
context=context,
context["task_instance"].xcom_push(
key=DataprocListLink.key,
value={
"project_id": task_instance.project_id,
"project_id": project_id,
"url": url,
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def execute(self, context: Context):
model_id = hook.extract_object_id(result)
self.log.info("Model is created, model_id: %s", model_id)

self.xcom_push(context, key="model_id", value=model_id)
context["task_instance"].xcom_push(key="model_id", value=model_id)
if project_id:
TranslationLegacyModelLink.persist(
context=context,
Expand Down Expand Up @@ -415,7 +415,7 @@ def execute(self, context: Context):
dataset_id = hook.extract_object_id(result)
self.log.info("Creating completed. Dataset id: %s", dataset_id)

self.xcom_push(context, key="dataset_id", value=dataset_id)
context["task_instance"].xcom_push(key="dataset_id", value=dataset_id)
project_id = self.project_id or hook.project_id
if project_id:
TranslationLegacyDatasetLink.persist(
Expand Down Expand Up @@ -1248,8 +1248,7 @@ def execute(self, context: Context):
result.append(Dataset.to_dict(dataset))
self.log.info("Datasets obtained.")

self.xcom_push(
context,
context["task_instance"].xcom_push(
key="dataset_id_list",
value=[hook.extract_object_id(d) for d in result],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def execute(self, context: Context):

result = TransferConfig.to_dict(response)
self.log.info("Created DTS transfer config %s", get_object_id(result))
self.xcom_push(context, key="transfer_config_id", value=get_object_id(result))
context["ti"].xcom_push(key="transfer_config_id", value=get_object_id(result))
# don't push AWS secret in XCOM
result.get("params", {}).pop("secret_access_key", None)
result.get("params", {}).pop("access_key_id", None)
Expand Down Expand Up @@ -335,7 +335,7 @@ def execute(self, context: Context):

result = StartManualTransferRunsResponse.to_dict(response)
run_id = get_object_id(result["runs"][0])
self.xcom_push(context, key="run_id", value=run_id)
context["ti"].xcom_push(key="run_id", value=run_id)

if not self.deferrable:
# Save as attribute for further use by OpenLineage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from google.api_core.gapic_v1.method import DEFAULT

from airflow.models import BaseOperator
from airflow.providers.google.version_compat import BaseOperator


class GoogleCloudBaseOperator(BaseOperator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def execute(self, context: Context):
location=self.location,
)

self.xcom_push(context, key="id", value=result.id)
context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
Expand Down Expand Up @@ -235,7 +235,7 @@ def execute(self, context: Context):
metadata=self.metadata,
location=self.location,
)
self.xcom_push(context, key="id", value=self.id_)
context["task_instance"].xcom_push(key="id", value=self.id_)
if not self.wait:
return Build.to_dict(
hook.get_build(id_=self.id_, project_id=self.project_id, location=self.location)
Expand Down Expand Up @@ -358,7 +358,7 @@ def execute(self, context: Context):
metadata=self.metadata,
location=self.location,
)
self.xcom_push(context, key="id", value=result.id)
context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildTriggerDetailsLink.persist(
Expand Down Expand Up @@ -854,7 +854,7 @@ def execute(self, context: Context):
location=self.location,
)

self.xcom_push(context, key="id", value=result.id)
context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
Expand Down Expand Up @@ -944,7 +944,7 @@ def execute(self, context: Context):
metadata=self.metadata,
location=self.location,
)
self.xcom_push(context, key="id", value=result.id)
context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def execute(self, context: Context):
metadata=self.metadata,
location=self.location,
)
self.xcom_push(context, key="id", value=result.id)
context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildTriggerDetailsLink.persist(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def execute(self, context: Context):
)
_, _, entry_id = result.name.rpartition("/")
self.log.info("Current entry_id ID: %s", entry_id)
self.xcom_push(context, key="entry_id", value=entry_id)
context["ti"].xcom_push(key="entry_id", value=entry_id)
DataCatalogEntryLink.persist(
context=context,
entry_id=self.entry_id,
Expand Down Expand Up @@ -283,7 +283,7 @@ def execute(self, context: Context):

_, _, entry_group_id = result.name.rpartition("/")
self.log.info("Current entry group ID: %s", entry_group_id)
self.xcom_push(context, key="entry_group_id", value=entry_group_id)
context["ti"].xcom_push(key="entry_group_id", value=entry_group_id)
DataCatalogEntryGroupLink.persist(
context=context,
entry_group_id=self.entry_group_id,
Expand Down Expand Up @@ -425,7 +425,7 @@ def execute(self, context: Context):

_, _, tag_id = tag.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_id)
self.xcom_push(context, key="tag_id", value=tag_id)
context["ti"].xcom_push(key="tag_id", value=tag_id)
DataCatalogEntryLink.persist(
context=context,
entry_id=self.entry,
Expand Down Expand Up @@ -542,7 +542,7 @@ def execute(self, context: Context):
)
_, _, tag_template = result.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_template)
self.xcom_push(context, key="tag_template_id", value=tag_template)
context["ti"].xcom_push(key="tag_template_id", value=tag_template)
DataCatalogTagTemplateLink.persist(
context=context,
tag_template_id=self.tag_template_id,
Expand Down Expand Up @@ -668,7 +668,7 @@ def execute(self, context: Context):
result = tag_template.fields[self.tag_template_field_id]

self.log.info("Current Tag ID: %s", self.tag_template_field_id)
self.xcom_push(context, key="tag_template_field_id", value=self.tag_template_field_id)
context["ti"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id)
DataCatalogTagTemplateLink.persist(
context=context,
tag_template_id=self.tag_template,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def set_current_job(current_job):
append_job_name=self.append_job_name,
)
job_id = self.hook.extract_job_id(self.job)
self.xcom_push(context, key="job_id", value=job_id)
context["task_instance"].xcom_push(key="job_id", value=job_id)
return job_id

self.job = self.hook.launch_job_with_template(
Expand Down Expand Up @@ -446,7 +446,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> str:
raise AirflowException(event["message"])

job_id = event["job_id"]
self.xcom_push(context, key="job_id", value=job_id)
context["task_instance"].xcom_push(key="job_id", value=job_id)
self.log.info("Task %s completed with response %s", self.task_id, event["message"])
return job_id

Expand Down Expand Up @@ -609,7 +609,7 @@ def set_current_job(current_job):
on_new_job_callback=set_current_job,
)
job_id = self.hook.extract_job_id(self.job)
self.xcom_push(context, key="job_id", value=job_id)
context["task_instance"].xcom_push(key="job_id", value=job_id)
return self.job

self.job = self.hook.launch_job_with_flex_template(
Expand Down Expand Up @@ -650,7 +650,7 @@ def execute_complete(self, context: Context, event: dict) -> dict[str, str]:

job_id = event["job_id"]
self.log.info("Task %s completed with response %s", job_id, event["message"])
self.xcom_push(context, key="job_id", value=job_id)
context["task_instance"].xcom_push(key="job_id", value=job_id)
job = self.hook.get_job(job_id=job_id, project_id=self.project_id, location=self.location)
return job

Expand Down Expand Up @@ -807,7 +807,7 @@ def execute_complete(self, context: Context, event: dict) -> dict[str, Any]:
raise AirflowException(event["message"])
job = event["job"]
self.log.info("Job %s completed with response %s", job["id"], event["message"])
self.xcom_push(context, key="job_id", value=job["id"])
context["task_instance"].xcom_push(key="job_id", value=job["id"])

return job

Expand Down Expand Up @@ -1025,7 +1025,7 @@ def execute(self, context: Context):
location=self.location,
)
DataflowPipelineLink.persist(context=context)
self.xcom_push(context, key="pipeline_name", value=self.pipeline_name)
context["task_instance"].xcom_push(key="pipeline_name", value=self.pipeline_name)
if self.pipeline:
if "error" in self.pipeline:
raise AirflowException(self.pipeline.get("error").get("message"))
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def execute(self, context: Context):
location=self.location,
)["job"]
job_id = self.dataflow_hook.extract_job_id(self.job)
self.xcom_push(context, key="job_id", value=job_id)
context["task_instance"].xcom_push(key="job_id", value=job_id)
DataflowJobLink.persist(
context=context, project_id=self.project_id, region=self.location, job_id=job_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2533,8 +2533,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
self.log.info("EntryGroup on page: %s", entry_group_on_page)
self.xcom_push(
context=context,
context["ti"].xcom_push(
key="entry_group_page",
value=ListEntryGroupsResponse.to_dict(entry_group_on_page._response),
)
Expand Down Expand Up @@ -2954,8 +2953,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
self.log.info("EntryType on page: %s", entry_type_on_page)
self.xcom_push(
context=context,
context["ti"].xcom_push(
key="entry_type_page",
value=ListEntryTypesResponse.to_dict(entry_type_on_page._response),
)
Expand Down Expand Up @@ -3308,8 +3306,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
self.log.info("AspectType on page: %s", aspect_type_on_page)
self.xcom_push(
context=context,
context["ti"].xcom_push(
key="aspect_type_page",
value=ListAspectTypesResponse.to_dict(aspect_type_on_page._response),
)
Expand Down Expand Up @@ -3803,8 +3800,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
self.log.info("Entries on page: %s", entries_on_page)
self.xcom_push(
context=context,
context["ti"].xcom_push(
key="entry_page",
value=ListEntriesResponse.to_dict(entries_on_page._response),
)
Expand Down Expand Up @@ -3901,8 +3897,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
self.log.info("Entries on page: %s", entries_on_page)
self.xcom_push(
context=context,
context["ti"].xcom_push(
key="entry_page",
value=SearchEntriesResponse.to_dict(entries_on_page._response),
)
Expand Down
Loading