@@ -143,7 +143,7 @@ def execute(self, context: Context) -> None:
         # set the mapred_job_name if it's not set with dag, task, execution time info
         if not self.mapred_job_name:
             ti = context["ti"]
-            logical_date = context["logical_date"]
+            logical_date = context.get("logical_date", None)
             if logical_date is None:
                 raise RuntimeError("logical_date is None")
             hostname = ti.hostname or ""
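A quick aside on why the one-line change above matters: in Airflow 3 the `logical_date` key can be missing from the task context entirely, so plain subscripting fails with `KeyError` before the `None` guard ever runs. A minimal, self-contained sketch of the pattern (a plain dict stands in for the Airflow context; the function name is illustrative, not from this diff):

```python
# Minimal sketch of the safe-lookup pattern adopted above; a plain dict
# stands in for the Airflow task context (illustrative, not the real type).
from datetime import datetime, timezone


def resolve_logical_date(context: dict) -> datetime:
    # context["logical_date"] would raise KeyError when the key is absent;
    # .get() yields None instead, so we can fail with a clearer message.
    logical_date = context.get("logical_date", None)
    if logical_date is None:
        raise RuntimeError("logical_date is None")
    return logical_date


print(resolve_logical_date({"logical_date": datetime.now(tz=timezone.utc)}))
```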
@@ -28,9 +28,10 @@
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
 from copy import deepcopy
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload

+import pendulum
 from aiohttp import ClientSession as ClientSession
 from gcloud.aio.bigquery import Job, Table as Table_async
 from google.cloud.bigquery import (
@@ -75,6 +76,7 @@
     GoogleBaseHook,
     get_field,
 )
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -86,6 +88,11 @@
     from google.api_core.retry import Retry
     from requests import Session

+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.sdk.definitions.context import Context
+    else:
+        from airflow.utils.context import Context
+
 log = logging.getLogger(__name__)

 BigQueryJob = CopyJob | QueryJob | LoadJob | ExtractJob
@@ -1274,7 +1281,7 @@ def insert_job(
             job_api_repr.result(timeout=timeout, retry=retry)
         return job_api_repr

-    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
+    def generate_job_id(self, job_id, dag_id, task_id, date, configuration, force_rerun=False) -> str:
         if force_rerun:
             hash_base = str(uuid.uuid4())
         else:
@@ -1285,10 +1292,18 @@ def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration,
         if job_id:
             return f"{job_id}_{uniqueness_suffix}"

-        exec_date = logical_date.isoformat()
+        exec_date = date.isoformat()
         job_id = f"airflow_{dag_id}_{task_id}_{exec_date}_{uniqueness_suffix}"
         return re.sub(r"[:\-+.]", "_", job_id)

+    def get_exec_date(self, context: Context) -> datetime:
+        date = context.get("logical_date", None)
+        if AIRFLOW_V_3_0_PLUS and date is None:
+            if dr := context.get("dag_run"):
+                if dr.run_after:
+                    date = pendulum.instance(dr.run_after)
+        return date if date is not None else datetime.now(tz=timezone.utc)
+
     def split_tablename(
         self, table_input: str, default_project_id: str, var_name: str | None = None
     ) -> tuple[str, str, str]:
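To make the new fallback order concrete: `get_exec_date` prefers `logical_date` when the context carries one, then (on Airflow 3+) `dag_run.run_after`, and finally the current UTC time. A standalone sketch of that chain, with `SimpleNamespace` standing in for a `DagRun` and a plain dict for the context (both stand-ins are assumptions for illustration):

```python
# Standalone sketch of the get_exec_date fallback chain; SimpleNamespace and
# the plain dict are stand-ins for DagRun and the Airflow context (assumed).
from datetime import datetime, timezone
from types import SimpleNamespace

import pendulum

AIRFLOW_V_3_0_PLUS = True  # assumed for this sketch


def get_exec_date(context: dict) -> datetime:
    date = context.get("logical_date", None)
    if AIRFLOW_V_3_0_PLUS and date is None:
        if dr := context.get("dag_run"):
            if dr.run_after:
                date = pendulum.instance(dr.run_after)
    return date if date is not None else datetime.now(tz=timezone.utc)


run = SimpleNamespace(run_after=datetime(2025, 1, 1, tzinfo=timezone.utc))
print(get_exec_date({"dag_run": run}))  # 2025-01-01T00:00:00+00:00
print(get_exec_date({}))                # no date anywhere -> "now" in UTC
```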
@@ -2374,7 +2374,7 @@ def execute(self, context: Any):
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
-            logical_date=context["logical_date"],
+            date=hook.get_exec_date(context),
             configuration=self.configuration,
             force_rerun=self.force_rerun,
         )
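For reference, the shape of the id this call site ends up producing: `generate_job_id` formats the resolved date into the job id and then rewrites `:`, `-`, `+`, and `.` to underscores. A hedged reconstruction with invented inputs (the real uniqueness suffix hashes the job configuration):

```python
# Illustrative reconstruction of the job_id shape; dag/task names and the
# uniqueness suffix are invented here (the real suffix hashes the configuration).
import re

exec_date = "2025-01-01T00:00:00+00:00"  # isoformat() of the resolved date
uniqueness_suffix = "d41d8cd9"           # stand-in for the md5-based suffix
job_id = f"airflow_example_dag_example_task_{exec_date}_{uniqueness_suffix}"
print(re.sub(r"[:\-+.]", "_", job_id))
# airflow_example_dag_example_task_2025_01_01T00_00_00_00_00_d41d8cd9
```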
@@ -20,9 +20,10 @@
 import json
 import re
 import uuid
-from collections.abc import Sequence
+from collections.abc import Collection, Sequence
 from typing import TYPE_CHECKING

+import pendulum
 from google.api_core.exceptions import AlreadyExists
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.workflows.executions_v1beta import Execution
@@ -36,12 +37,16 @@
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS

 if TYPE_CHECKING:
     from google.api_core.retry import Retry
     from google.protobuf.field_mask_pb2 import FieldMask

-    from airflow.utils.context import Context
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.sdk.definitions.context import Context
+    else:
+        from airflow.utils.context import Context

 from airflow.utils.hashlib_wrapper import md5

@@ -69,7 +74,7 @@ class WorkflowsCreateWorkflowOperator(GoogleCloudBaseOperator):
     :param metadata: Additional metadata that is provided to the method.
     """

-    template_fields: Sequence[str] = ("location", "workflow", "workflow_id")
+    template_fields: Collection[str] = ("location", "workflow", "workflow_id")
     template_fields_renderers = {"workflow": "json"}
     operator_extra_links = (WorkflowsWorkflowDetailsLink(),)

@@ -101,7 +106,7 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.force_rerun = force_rerun

-    def _workflow_id(self, context):
+    def _workflow_id(self, context: Context):
         if self.workflow_id and not self.force_rerun:
             # If users provide workflow id then assuring the idempotency
             # is on their side
@@ -114,8 +119,14 @@ def _workflow_id(self, context):

         # We are limited by allowed length of workflow_id so
         # we use hash of whole information
-        exec_date = context["logical_date"].isoformat()
-        base = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{hash_base}"
+        date = context.get("logical_date", None)
+        if AIRFLOW_V_3_0_PLUS and date is None:
+            if dr := context.get("dag_run"):
+                if dr.run_after:
+                    date = pendulum.instance(dr.run_after)
+        exec_date = date if date is not None else datetime.datetime.now(tz=datetime.timezone.utc)
+        exec_date_iso = exec_date.isoformat()
+        base = f"airflow_{self.dag_id}_{self.task_id}_{exec_date_iso}_{hash_base}"
         workflow_id = md5(base.encode()).hexdigest()
         return re.sub(r"[:\-+.]", "_", workflow_id)

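One property worth spelling out: with the fallback in place, `_workflow_id` stays deterministic for a given DAG, task, resolved date, and workflow body, which is what keeps retries idempotent when no explicit `workflow_id` is passed. A hedged sketch with invented values (plain `hashlib.md5` stands in for `airflow.utils.hashlib_wrapper.md5`, assumed equivalent for illustration):

```python
# Hedged sketch of the idempotency property: same inputs -> same workflow id.
# hashlib.md5 stands in for airflow.utils.hashlib_wrapper.md5 (assumed
# equivalent here); dag/task/date values are invented for the example.
import json
import re
from hashlib import md5

workflow = {"name": "example"}  # invented workflow body
hash_base = json.dumps(workflow, sort_keys=True)
base = f"airflow_example_dag_example_task_2025-01-01T00:00:00+00:00_{hash_base}"
workflow_id = md5(base.encode()).hexdigest()

# Recomputing from the same inputs yields the same id, so a retried task
# reuses (rather than duplicates) the workflow.
assert workflow_id == md5(base.encode()).hexdigest()
# The operator applies the same sanitization regex, though a hex digest
# contains none of these characters anyway.
print(re.sub(r"[:\-+.]", "_", workflow_id))
```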
@@ -215,7 +215,7 @@ def execute(self, context: Context):
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
-            logical_date=context["logical_date"],
+            date=hook.get_exec_date(context),
             configuration=configuration,
             force_rerun=self.force_rerun,
         )
@@ -337,7 +337,7 @@ def execute(self, context: Context):
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
-            logical_date=context["logical_date"],
+            date=hook.get_exec_date(context),
             configuration=self.configuration,
             force_rerun=self.force_rerun,
         )
@@ -51,6 +51,8 @@
     _validate_value,
 )

+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
 pytestmark = pytest.mark.filterwarnings("error::airflow.exceptions.AirflowProviderDeprecationWarning")

 PROJECT_ID = "bq-project"
@@ -673,11 +675,28 @@ def test_job_id_validity(self, mock_md5, test_dag_id, expected_job_id):
             job_id=None,
             dag_id=test_dag_id,
             task_id="test_job_id",
-            logical_date=datetime(2020, 1, 23),
+            date=datetime(2020, 1, 23),
             configuration=configuration,
         )
         assert job_id == expected_job_id

+    def test_get_exec_date(self):
+        import pendulum
+
+        if AIRFLOW_V_3_0_PLUS:
+            from airflow.models import DagRun
+            from airflow.sdk.definitions.context import Context
+
+            ctx = Context(logical_date=pendulum.datetime(2025, 1, 1))
+            assert self.hook.get_exec_date(ctx) == pendulum.datetime(2025, 1, 1)
+            ctx = Context(dag_run=DagRun(run_after=pendulum.datetime(2025, 1, 1)))
+            assert self.hook.get_exec_date(ctx) == pendulum.datetime(2025, 1, 1)
+        else:
+            from airflow.utils.context import Context
+
+            ctx = Context(logical_date=pendulum.datetime(2025, 1, 1))
+            assert self.hook.get_exec_date(ctx) == pendulum.datetime(2025, 1, 1)
+
     @mock.patch(
         "airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_job",
         return_value=mock.MagicMock(spec=CopyJob),
@@ -1595,7 +1595,7 @@ def test_bigquery_insert_job_operator_with_job_id_generate(
             job_id=job_id,
             dag_id="adhoc_airflow",
             task_id="insert_query_job",
-            logical_date=ANY,
+            date=ANY,
             configuration=configuration,
             force_rerun=True,
         )
@@ -17,6 +17,8 @@
 from __future__ import annotations

 import datetime
+import json
+import re
 from unittest import mock

 from google.protobuf.timestamp_pb2 import Timestamp
@@ -32,6 +34,7 @@
     WorkflowsListWorkflowsOperator,
     WorkflowsUpdateWorkflowOperator,
 )
+from airflow.utils.hashlib_wrapper import md5

 BASE_PATH = "airflow.providers.google.cloud.operators.workflows.{}"
 LOCATION = "europe-west1"
@@ -86,6 +89,39 @@ def test_execute(self, mock_hook, mock_object):

         assert result == mock_object.to_dict.return_value

+    def test_execute_without_workflow_id(self):
+        import pendulum
+
+        from airflow.models.dagrun import DagRun
+
+        from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+        if AIRFLOW_V_3_0_PLUS:
+            from airflow.sdk.definitions.context import Context
+        else:
+            from airflow.utils.context import Context
+
+        op = WorkflowsCreateWorkflowOperator(
+            task_id="test_task",
+            workflow=WORKFLOW,
+            workflow_id="",
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        hash_base = json.dumps(WORKFLOW, sort_keys=True)
+        date = pendulum.datetime(2025, 1, 1)
+        ctx = Context(logical_date=date)
+        expected = md5(f"airflow_{op.dag_id}_test_task_{date.isoformat()}_{hash_base}".encode()).hexdigest()
+        assert op._workflow_id(ctx) == re.sub(r"[:\-+.]", "_", expected)
+
+        if AIRFLOW_V_3_0_PLUS:
+            ctx = Context(dag_run=DagRun(run_after=date))
+            assert op._workflow_id(ctx) == re.sub(r"[:\-+.]", "_", expected)
+

 class TestWorkflowsUpdateWorkflowOperator:
     @mock.patch(BASE_PATH.format("Workflow"))
@@ -43,7 +43,6 @@
 )
 from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
 from airflow.utils.state import TaskInstanceState
-from airflow.utils.timezone import datetime

 TASK_ID = "test-gcs-to-bq-operator"
 TEST_EXPLICIT_DEST = "test-project.dataset.table"
@@ -1746,7 +1745,7 @@ def test_execute_without_external_table_generate_job_id_async_should_execute_suc
             job_id=None,
             dag_id="adhoc_airflow",
             task_id=TASK_ID,
-            logical_date=datetime(2016, 1, 1, 0, 0),
+            date=hook.return_value.get_exec_date(),
             configuration={},
             force_rerun=True,
         )