airflow-core/src/airflow/dag_processing/collection.py (23 additions, 1 deletion)
@@ -159,10 +159,32 @@ def calculate(cls, dags: dict[str, LazyDeserializedDAG], *, session: Session) ->
 
 def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
     orm_tags = {t.name: t for t in dm.tags}
+    tags_to_delete = []
     for name, orm_tag in orm_tags.items():
         if name not in tag_names:
             session.delete(orm_tag)
-    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))
+            tags_to_delete.append(orm_tag)
+
+    tags_to_add = tag_names.difference(orm_tags)
+    if tags_to_delete:
+        # Remove deleted tags from the collection to keep it in sync
+        for tag in tags_to_delete:
+            dm.tags.remove(tag)
+
+        # Check if there's a potential case-only rename on MySQL (e.g., 'tag' -> 'TAG').
+        # MySQL uses case-insensitive collation for the (name, dag_id) primary key by default,
+        # which can cause duplicate key errors when renaming tags with only case changes.
+        if get_dialect_name(session) == "mysql":
+            orm_tags_lower = {name.lower(): name for name in orm_tags}
+            has_case_only_change = any(tag.lower() in orm_tags_lower for tag in tags_to_add)
+
+            if has_case_only_change:
+                # Force DELETE operations to execute before INSERT operations.
+                session.flush()
+                # Refresh the tags relationship from the database to reflect the deletions.
+                session.expire(dm, ["tags"])
+
+    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tags_to_add)
 
 
 def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:
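Note (not part of the diff): the comments above describe a SQLAlchemy flush-ordering problem. Within one flush, the unit of work emits INSERT statements before DELETE statements for the same table, so a case-only rename such as 'dangerous' -> 'DANGEROUS' tries to insert the new row while the old one still exists; under MySQL's default case-insensitive collation the (name, dag_id) primary key then reports a duplicate entry. The following standalone sketch reproduces that ordering outside Airflow, using SQLite's NOCASE collation as a stand-in for MySQL; the Tag model and table name are hypothetical, not Airflow's DagTag.

# Standalone sketch of the flush-ordering problem; assumes SQLAlchemy 1.4+.
from sqlalchemy import Column, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Tag(Base):  # hypothetical model, not Airflow's DagTag
    __tablename__ = "tag"
    # NOCASE makes the primary key case-insensitive, like MySQL's default *_ci collations.
    name = Column(String(100, collation="NOCASE"), primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Tag(name="dangerous"))
    session.commit()

with Session(engine) as session:
    # Case-only rename staged in a single flush: the INSERT is emitted before
    # the DELETE, and the case-insensitive key rejects it.
    session.delete(session.get(Tag, "dangerous"))
    session.add(Tag(name="DANGEROUS"))
    try:
        session.commit()
    except IntegrityError:
        session.rollback()  # duplicate entry, as the patch comments describe

with Session(engine) as session:
    # The workaround mirrored by the patch: flush the DELETE before staging the INSERT.
    session.delete(session.get(Tag, "dangerous"))
    session.flush()
    session.add(Tag(name="DANGEROUS"))
    session.commit()

The patch also calls get_dialect_name(session), which is defined elsewhere in the module and not shown in this diff; presumably it is a thin wrapper around session.get_bind().dialect.name.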
airflow-core/tests/unit/dag_processing/test_collection.py (32 additions, 0 deletions)
@@ -37,6 +37,7 @@
     AssetModelOperation,
     DagModelOperation,
     _get_latest_runs_stmt,
+    _update_dag_tags,
     update_dag_parsing_results_in_db,
 )
 from airflow.exceptions import SerializationError
@@ -48,6 +49,7 @@
     DagScheduleAssetNameReference,
     DagScheduleAssetUriReference,
 )
+from airflow.models.dag import DagTag
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator
@@ -941,3 +943,33 @@ def test_max_consecutive_failed_dag_runs_defaults_from_conf_when_none(
         update_dag_parsing_results_in_db("testing", None, [dag], {}, 0.1, set(), session)
         orm_dag = session.get(DagModel, "dag_max_failed_runs_default")
         assert orm_dag.max_consecutive_failed_dag_runs == 6
+
+
+@pytest.mark.db_test
+class TestUpdateDagTags:
+    @pytest.fixture(autouse=True)
+    def setup_teardown(self, session):
+        yield
+        session.query(DagModel).filter(DagModel.dag_id == "test_dag").delete()
+        session.commit()
+
+    @pytest.mark.parametrize(
+        ["initial_tags", "new_tags", "expected_tags"],
+        [
+            (["dangerous"], {"DANGEROUS"}, {"DANGEROUS"}),
+            (["existing"], {"existing", "new"}, {"existing", "new"}),
+            (["tag1", "tag2"], {"tag1"}, {"tag1"}),
+            (["keep", "remove", "lowercase"], {"keep", "LOWERCASE", "new"}, {"keep", "LOWERCASE", "new"}),
+            (["tag1", "tag2"], set(), set()),
+        ],
+    )
+    def test_update_dag_tags(self, testing_dag_bundle, session, initial_tags, new_tags, expected_tags):
+        dag_model = DagModel(dag_id="test_dag", bundle_name="testing")
+        dag_model.tags = [DagTag(name=tag, dag_id="test_dag") for tag in initial_tags]
+        session.add(dag_model)
+        session.commit()
+
+        _update_dag_tags(new_tags, dag_model, session=session)
+        session.commit()
+
+        assert {t.name for t in dag_model.tags} == expected_tags
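Note (not part of the diff): the final assertion checks tag names through the dm.tags relationship, which the patched function keeps in sync in memory. A hypothetical extra check could query the dag_tag table directly to confirm the rows themselves were replaced; the sketch below assumes the same session fixture and the DagTag import added earlier in this diff.

        # Hypothetical follow-up assertion for test_update_dag_tags: verify the rows
        # in the dag_tag table directly, not just the in-memory relationship.
        db_tag_names = {
            name for (name,) in session.query(DagTag.name).filter(DagTag.dag_id == "test_dag")
        }
        assert db_tag_names == expected_tags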