diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 24b9a59b49a45..2057ee178fc2b 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -43,7 +43,7 @@ from airflow.utils import timezone from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState -from airflow.utils.types import DagRunType +from airflow.utils.types import NOTSET, ArgNotSet, DagRunType XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso" XCOM_RUN_ID = "trigger_run_id" @@ -153,7 +153,7 @@ def __init__( trigger_dag_id: str, trigger_run_id: str | None = None, conf: dict | None = None, - logical_date: str | datetime.datetime | None = None, + logical_date: str | datetime.datetime | None | ArgNotSet = NOTSET, reset_dag_run: bool = False, wait_for_completion: bool = False, poke_interval: int = 60, @@ -180,19 +180,23 @@ def __init__( self.failed_states = [DagRunState.FAILED] self.skip_when_already_exists = skip_when_already_exists self._defer = deferrable - - if logical_date is not None and not isinstance(logical_date, (str, datetime.datetime)): - type_name = type(logical_date).__name__ + self.logical_date = logical_date + if logical_date is NOTSET: + self.logical_date = NOTSET + elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)): + self.logical_date = logical_date + else: raise TypeError( - f"Expected str or datetime.datetime type for parameter 'logical_date'. Got {type_name}" + f"Expected str, datetime.datetime, or None for parameter 'logical_date'. Got {type(logical_date).__name__}" ) - self.logical_date = logical_date - def execute(self, context: Context): - if self.logical_date is None or isinstance(self.logical_date, datetime.datetime): - parsed_logical_date = self.logical_date - else: + if self.logical_date is NOTSET: + # If no logical_date is provided we will set utcnow() + parsed_logical_date = timezone.utcnow() + elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime): + parsed_logical_date = self.logical_date # type: ignore + elif isinstance(self.logical_date, str): parsed_logical_date = timezone.parse(self.logical_date) try: diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index 55dc7eccc4dc0..b82b9e0929ee2 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -108,7 +108,7 @@ def test_trigger_dagrun(self): assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID assert exc_info.value.conf == {"foo": "bar"} - assert exc_info.value.logical_date is None + assert exc_info.value.logical_date is not None assert exc_info.value.reset_dag_run is False assert exc_info.value.skip_when_already_exists is False assert exc_info.value.wait_for_completion is False @@ -119,7 +119,7 @@ def test_trigger_dagrun(self): run_type=DagRunType.MANUAL, run_after=timezone.utcnow() ).rsplit("_", 1)[0] # rsplit because last few characters are random. - assert exc_info.value.dag_run_id.rsplit("_", 1)[0] == expected_run_id + assert exc_info.value.dag_run_id == expected_run_id @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_one") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 6b25bb81bf20f..1db6efa1b14cf 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -117,6 +117,7 @@ if TYPE_CHECKING: from kgb import SpyAgency +import time_machine def get_inline_dag(dag_id: str, task: BaseOperator) -> DAG: @@ -2225,6 +2226,7 @@ class CustomOperator(BaseOperator): class TestTriggerDagRunOperator: """Tests to verify various aspects of TriggerDagRunOperator""" + @time_machine.travel("2025-01-01 00:00:00", tick=False) def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): """Test that TriggerDagRunOperator (with default args) sends the correct message to the Supervisor""" from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator @@ -2249,6 +2251,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): dag_id="test_dag", run_id="test_run_id", reset_dag_run=False, + logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), log=mock.ANY, ), @@ -2274,6 +2277,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): (False, TaskInstanceState.FAILED), ], ) + @time_machine.travel("2025-01-01 00:00:00", tick=False) def test_handle_trigger_dag_run_conflict( self, skip_when_already_exists, expected_state, create_runtime_ti, mock_supervisor_comms ): @@ -2299,6 +2303,7 @@ def test_handle_trigger_dag_run_conflict( mock.call.send_request( msg=TriggerDagRun( dag_id="test_dag", + logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), run_id="test_run_id", reset_dag_run=False, ), @@ -2318,6 +2323,7 @@ def test_handle_trigger_dag_run_conflict( ([DagRunState.SUCCESS], None, DagRunState.FAILED, DagRunState.FAILED), ], ) + @time_machine.travel("2025-01-01 00:00:00", tick=False) def test_handle_trigger_dag_run_wait_for_completion( self, allowed_states, @@ -2367,6 +2373,7 @@ def test_handle_trigger_dag_run_wait_for_completion( msg=TriggerDagRun( dag_id="test_dag", run_id="test_run_id", + logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), log=mock.ANY, ),