Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType

XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
trigger_dag_id: str,
trigger_run_id: str | None = None,
conf: dict | None = None,
logical_date: str | datetime.datetime | None = None,
logical_date: str | datetime.datetime | None | ArgNotSet = NOTSET,
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
Expand All @@ -180,19 +180,23 @@ def __init__(
self.failed_states = [DagRunState.FAILED]
self.skip_when_already_exists = skip_when_already_exists
self._defer = deferrable

if logical_date is not None and not isinstance(logical_date, (str, datetime.datetime)):
type_name = type(logical_date).__name__
self.logical_date = logical_date
if logical_date is NOTSET:
self.logical_date = NOTSET
elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)):
self.logical_date = logical_date
else:
raise TypeError(
f"Expected str or datetime.datetime type for parameter 'logical_date'. Got {type_name}"
f"Expected str, datetime.datetime, or None for parameter 'logical_date'. Got {type(logical_date).__name__}"
)

self.logical_date = logical_date

def execute(self, context: Context):
if self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date
else:
if self.logical_date is NOTSET:
# If no logical_date is provided we will set utcnow()
parsed_logical_date = timezone.utcnow()
elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date # type: ignore
elif isinstance(self.logical_date, str):
parsed_logical_date = timezone.parse(self.logical_date)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_trigger_dagrun(self):

assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID
assert exc_info.value.conf == {"foo": "bar"}
assert exc_info.value.logical_date is None
assert exc_info.value.logical_date is not None
assert exc_info.value.reset_dag_run is False
assert exc_info.value.skip_when_already_exists is False
assert exc_info.value.wait_for_completion is False
Expand All @@ -119,7 +119,7 @@ def test_trigger_dagrun(self):
run_type=DagRunType.MANUAL, run_after=timezone.utcnow()
).rsplit("_", 1)[0]
# rsplit because last few characters are random.
assert exc_info.value.dag_run_id.rsplit("_", 1)[0] == expected_run_id
assert exc_info.value.dag_run_id == expected_run_id

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
@mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_one")
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@

if TYPE_CHECKING:
from kgb import SpyAgency
import time_machine


def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG:
Expand Down Expand Up @@ -2225,6 +2226,7 @@ class CustomOperator(BaseOperator):
class TestTriggerDagRunOperator:
"""Tests to verify various aspects of TriggerDagRunOperator"""

@time_machine.travel("2025-01-01 00:00:00", tick=False)
def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
"""Test that TriggerDagRunOperator (with default args) sends the correct message to the Supervisor"""
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
Expand All @@ -2249,6 +2251,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
dag_id="test_dag",
run_id="test_run_id",
reset_dag_run=False,
logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
),
log=mock.ANY,
),
Expand All @@ -2274,6 +2277,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
(False, TaskInstanceState.FAILED),
],
)
@time_machine.travel("2025-01-01 00:00:00", tick=False)
def test_handle_trigger_dag_run_conflict(
self, skip_when_already_exists, expected_state, create_runtime_ti, mock_supervisor_comms
):
Expand All @@ -2299,6 +2303,7 @@ def test_handle_trigger_dag_run_conflict(
mock.call.send_request(
msg=TriggerDagRun(
dag_id="test_dag",
logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
run_id="test_run_id",
reset_dag_run=False,
),
Expand All @@ -2318,6 +2323,7 @@ def test_handle_trigger_dag_run_conflict(
([DagRunState.SUCCESS], None, DagRunState.FAILED, DagRunState.FAILED),
],
)
@time_machine.travel("2025-01-01 00:00:00", tick=False)
def test_handle_trigger_dag_run_wait_for_completion(
self,
allowed_states,
Expand Down Expand Up @@ -2367,6 +2373,7 @@ def test_handle_trigger_dag_run_wait_for_completion(
msg=TriggerDagRun(
dag_id="test_dag",
run_id="test_run_id",
logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
),
log=mock.ANY,
),
Expand Down