Skip to content

Commit

Permalink
Fix dag (un)pausing won't work on environment where dag files are mis…
Browse files Browse the repository at this point in the history
…sing (#40345)

closes: #38834

#38265 added bulk pause and resume of DAGs. However, this PR seems to reuse the cli util method that collects DAGs from the default dag folder but not from the metadata DB. Hence, this would cause the unpause command to fail on environments where the dag folder is missing.
  • Loading branch information
boushphong authored Jul 1, 2024
1 parent dc03889 commit e3d62c3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
32 changes: 16 additions & 16 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,24 @@ def dag_unpause(args) -> None:
def set_is_paused(is_paused: bool, args) -> None:
"""Set is_paused for DAG by a given dag_id."""
should_apply = True
dags = [
dag
for dag in get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_id_as_regex)
if is_paused != dag.get_is_paused()
]
with create_session() as session:
query = select(DagModel)

if args.treat_dag_id_as_regex:
query = query.where(DagModel.dag_id.regexp_match(args.dag_id))
else:
query = query.where(DagModel.dag_id == args.dag_id)

query = query.where(DagModel.is_paused != is_paused)

matched_dags = session.scalars(query).all()

if not dags:
if not matched_dags:
print(f"No {'un' if is_paused else ''}paused DAGs were found")
return

if not args.yes and args.treat_dag_id_as_regex:
dags_ids = [dag.dag_id for dag in dags]
dags_ids = [dag.dag_id for dag in matched_dags]
question = (
f"You are about to {'un' if not is_paused else ''}pause {len(dags_ids)} DAGs:\n"
f"{','.join(dags_ids)}"
Expand All @@ -245,17 +251,11 @@ def set_is_paused(is_paused: bool, args) -> None:
should_apply = ask_yesno(question)

if should_apply:
dags_models = [DagModel.get_dagmodel(dag.dag_id) for dag in dags]
for dag_model in dags_models:
if dag_model is not None:
dag_model.set_is_paused(is_paused=is_paused)
for dag_model in matched_dags:
dag_model.set_is_paused(is_paused=is_paused)

AirflowConsole().print_as(
data=[
{"dag_id": dag.dag_id, "is_paused": dag.get_is_paused()}
for dag in dags_models
if dag is not None
],
data=[{"dag_id": dag.dag_id, "is_paused": not dag.get_is_paused()} for dag in matched_dags],
output=args.output,
)
else:
Expand Down
13 changes: 11 additions & 2 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,10 +717,19 @@ def test_pause_regex_yes(self, mock_yesno):
mock_yesno.assert_not_called()
dag_command.dag_unpause(args)

def test_pause_non_existing_dag_error(self):
def test_pause_non_existing_dag_do_not_error(self):
args = self.parser.parse_args(["dags", "pause", "non_existing_dag"])
with pytest.raises(AirflowException):
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
dag_command.dag_pause(args)
out = temp_stdout.getvalue().strip().splitlines()[-1]
assert out == "No unpaused DAGs were found"

def test_unpause_non_existing_dag_do_not_error(self):
args = self.parser.parse_args(["dags", "unpause", "non_existing_dag"])
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
dag_command.dag_unpause(args)
out = temp_stdout.getvalue().strip().splitlines()[-1]
assert out == "No paused DAGs were found"

def test_unpause_already_unpaused_dag_do_not_error(self):
args = self.parser.parse_args(["dags", "unpause", "example_bash_operator", "--yes"])
Expand Down

0 comments on commit e3d62c3

Please sign in to comment.