diff --git a/providers/standard/src/airflow/providers/standard/operators/datetime.py b/providers/standard/src/airflow/providers/standard/operators/datetime.py index dd788dd95fab2..18c6b0a7a9aa5 100644 --- a/providers/standard/src/airflow/providers/standard/operators/datetime.py +++ b/providers/standard/src/airflow/providers/standard/operators/datetime.py @@ -77,7 +77,10 @@ def __init__( def choose_branch(self, context: Context) -> str | Iterable[str]: if self.use_task_logical_date: - now = context["logical_date"] + now = context.get("logical_date") + if not now: + dag_run = context.get("dag_run") + now = dag_run.run_after # type: ignore[union-attr] else: now = timezone.coerce_datetime(timezone.utcnow()) lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper) diff --git a/providers/standard/src/airflow/providers/standard/operators/weekday.py b/providers/standard/src/airflow/providers/standard/operators/weekday.py index 89a361385e957..bcae0b746c524 100644 --- a/providers/standard/src/airflow/providers/standard/operators/weekday.py +++ b/providers/standard/src/airflow/providers/standard/operators/weekday.py @@ -116,10 +116,13 @@ def __init__( def choose_branch(self, context: Context) -> str | Iterable[str]: if self.use_task_logical_date: - now = context["logical_date"] + now = context.get("logical_date") + if not now: + dag_run = context.get("dag_run") + now = dag_run.run_after # type: ignore[union-attr] else: now = timezone.make_naive(timezone.utcnow(), self.dag.timezone) - if now.isoweekday() in self._week_day_num: + if now.isoweekday() in self._week_day_num: # type: ignore[union-attr] return self.follow_task_ids_if_true return self.follow_task_ids_if_false diff --git a/providers/standard/tests/unit/standard/operators/test_datetime.py b/providers/standard/tests/unit/standard/operators/test_datetime.py index f23c0f9b757b8..790a4d8c5493d 100644 --- a/providers/standard/tests/unit/standard/operators/test_datetime.py +++ b/providers/standard/tests/unit/standard/operators/test_datetime.py @@ -312,3 +312,34 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_ "branch_2": State.SKIPPED, } ) + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0") + @time_machine.travel("2020-12-01 09:00:00") + def test_choose_branch_should_use_run_after_when_logical_date_none(self, dag_maker): + with dag_maker( + "branch_datetime_operator_uses_run_after", + default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, + schedule=INTERVAL, + serialized=True, + ): + branch_1 = EmptyOperator(task_id="branch_1") + branch_2 = EmptyOperator(task_id="branch_2") + + branch_op = BranchDateTimeOperator( + task_id="datetime_branch", + follow_task_ids_if_true="branch_1", + follow_task_ids_if_false="branch_2", + target_upper=datetime.datetime(2020, 9, 7, 11, 0, 0), + target_lower=datetime.datetime(2020, 6, 7, 10, 0, 0), + use_task_logical_date=True, + ) + branch_1.set_upstream(branch_op) + branch_2.set_upstream(branch_op) + + dr = dag_maker.create_dagrun( + run_id="manual__run_after", + start_date=DEFAULT_DATE, + state=State.RUNNING, + **{"run_after": timezone.datetime(2020, 8, 7)}, + ) + assert branch_op.choose_branch(context={"dag_run": dr}) == "branch_1" diff --git a/providers/standard/tests/unit/standard/operators/test_weekday.py b/providers/standard/tests/unit/standard/operators/test_weekday.py index ddcacddfc4352..583f20fd663be 100644 --- a/providers/standard/tests/unit/standard/operators/test_weekday.py +++ b/providers/standard/tests/unit/standard/operators/test_weekday.py @@ -179,6 +179,32 @@ def test_branch_follow_true_with_logical_date(self, dag_maker): }, ) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0") + @time_machine.travel("2021-01-25") # Monday + def test_choose_branch_should_use_run_after_when_logical_date_none(self, dag_maker): + with dag_maker( + "branch_day_of_week_operator_test", start_date=DEFAULT_DATE, schedule=INTERVAL, serialized=True + ): + branch_op = BranchDayOfWeekOperator( + task_id="make_choice", + follow_task_ids_if_true="branch_1", + follow_task_ids_if_false="branch_2", + week_day="Wednesday", + use_task_logical_date=True, # We compare to DEFAULT_DATE which is Wednesday + ) + branch_1 = EmptyOperator(task_id="branch_1") + branch_2 = EmptyOperator(task_id="branch_2") + branch_1.set_upstream(branch_op) + branch_2.set_upstream(branch_op) + + dr = dag_maker.create_dagrun( + run_id="manual__", + start_date=timezone.utcnow(), + state=State.RUNNING, + **{"run_after": DEFAULT_DATE}, + ) + assert branch_op.choose_branch(context={"dag_run": dr}) == "branch_1" + @time_machine.travel("2021-01-25") # Monday def test_branch_follow_false(self, dag_maker): """Checks if BranchDayOfWeekOperator follow false branch"""