Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def __init__(

def choose_branch(self, context: Context) -> str | Iterable[str]:
if self.use_task_logical_date:
now = context["logical_date"]
now = context.get("logical_date")
if not now:
dag_run = context.get("dag_run")
now = dag_run.run_after # type: ignore[union-attr]
else:
now = timezone.coerce_datetime(timezone.utcnow())
lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,13 @@ def __init__(

def choose_branch(self, context: Context) -> str | Iterable[str]:
if self.use_task_logical_date:
now = context["logical_date"]
now = context.get("logical_date")
if not now:
dag_run = context.get("dag_run")
now = dag_run.run_after # type: ignore[union-attr]
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)

if now.isoweekday() in self._week_day_num:
if now.isoweekday() in self._week_day_num: # type: ignore[union-attr]
return self.follow_task_ids_if_true
return self.follow_task_ids_if_false
31 changes: 31 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,34 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_
"branch_2": State.SKIPPED,
}
)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
@time_machine.travel("2020-12-01 09:00:00")
def test_choose_branch_should_use_run_after_when_logical_date_none(self, dag_maker):
with dag_maker(
"branch_datetime_operator_uses_run_after",
default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
schedule=INTERVAL,
serialized=True,
):
branch_1 = EmptyOperator(task_id="branch_1")
branch_2 = EmptyOperator(task_id="branch_2")

branch_op = BranchDateTimeOperator(
task_id="datetime_branch",
follow_task_ids_if_true="branch_1",
follow_task_ids_if_false="branch_2",
target_upper=datetime.datetime(2020, 9, 7, 11, 0, 0),
target_lower=datetime.datetime(2020, 6, 7, 10, 0, 0),
use_task_logical_date=True,
)
branch_1.set_upstream(branch_op)
branch_2.set_upstream(branch_op)

dr = dag_maker.create_dagrun(
run_id="manual__run_after",
start_date=DEFAULT_DATE,
state=State.RUNNING,
**{"run_after": timezone.datetime(2020, 8, 7)},
)
assert branch_op.choose_branch(context={"dag_run": dr}) == "branch_1"
26 changes: 26 additions & 0 deletions providers/standard/tests/unit/standard/operators/test_weekday.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,32 @@ def test_branch_follow_true_with_logical_date(self, dag_maker):
},
)

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
@time_machine.travel("2021-01-25") # Monday
def test_choose_branch_should_use_run_after_when_logical_date_none(self, dag_maker):
with dag_maker(
"branch_day_of_week_operator_test", start_date=DEFAULT_DATE, schedule=INTERVAL, serialized=True
):
branch_op = BranchDayOfWeekOperator(
task_id="make_choice",
follow_task_ids_if_true="branch_1",
follow_task_ids_if_false="branch_2",
week_day="Wednesday",
use_task_logical_date=True, # We compare to DEFAULT_DATE which is Wednesday
)
branch_1 = EmptyOperator(task_id="branch_1")
branch_2 = EmptyOperator(task_id="branch_2")
branch_1.set_upstream(branch_op)
branch_2.set_upstream(branch_op)

dr = dag_maker.create_dagrun(
run_id="manual__",
start_date=timezone.utcnow(),
state=State.RUNNING,
**{"run_after": DEFAULT_DATE},
)
assert branch_op.choose_branch(context={"dag_run": dr}) == "branch_1"

@time_machine.travel("2021-01-25") # Monday
def test_branch_follow_false(self, dag_maker):
"""Checks if BranchDayOfWeekOperator follow false branch"""
Expand Down
Loading