diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index 31a664b7e02b4..6634754a58df1 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -159,10 +159,32 @@ def calculate(cls, dags: dict[str, LazyDeserializedDAG], *, session: Session) ->
 
 def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
     orm_tags = {t.name: t for t in dm.tags}
+    tags_to_delete = []
     for name, orm_tag in orm_tags.items():
         if name not in tag_names:
             session.delete(orm_tag)
-    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))
+            tags_to_delete.append(orm_tag)
+
+    tags_to_add = tag_names.difference(orm_tags)
+    if tags_to_delete:
+        # Remove deleted tags from the collection to keep it in sync
+        for tag in tags_to_delete:
+            dm.tags.remove(tag)
+
+        # Check if there's a potential case-only rename on MySQL (e.g., 'tag' -> 'TAG').
+        # MySQL uses case-insensitive collation for the (name, dag_id) primary key by default,
+        # which can cause duplicate key errors when renaming tags with only case changes.
+        if get_dialect_name(session) == "mysql":
+            orm_tags_lower = {name.lower(): name for name in orm_tags}
+            has_case_only_change = any(tag.lower() in orm_tags_lower for tag in tags_to_add)
+
+            if has_case_only_change:
+                # Force DELETE operations to execute before INSERT operations.
+                session.flush()
+                # Refresh the tags relationship from the database to reflect the deletions.
+                session.expire(dm, ["tags"])
+
+    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tags_to_add)
 
 
 def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py
index e7d12f62c32be..92c03e97c52b8 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -37,6 +37,7 @@
     AssetModelOperation,
     DagModelOperation,
     _get_latest_runs_stmt,
+    _update_dag_tags,
     update_dag_parsing_results_in_db,
 )
 from airflow.exceptions import SerializationError
@@ -48,6 +49,7 @@
     DagScheduleAssetNameReference,
     DagScheduleAssetUriReference,
 )
+from airflow.models.dag import DagTag
 from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator
@@ -941,3 +943,33 @@ def test_max_consecutive_failed_dag_runs_defaults_from_conf_when_none(
     update_dag_parsing_results_in_db("testing", None, [dag], {}, 0.1, set(), session)
     orm_dag = session.get(DagModel, "dag_max_failed_runs_default")
     assert orm_dag.max_consecutive_failed_dag_runs == 6
+
+
+@pytest.mark.db_test
+class TestUpdateDagTags:
+    @pytest.fixture(autouse=True)
+    def setup_teardown(self, session):
+        yield
+        session.query(DagModel).filter(DagModel.dag_id == "test_dag").delete()
+        session.commit()
+
+    @pytest.mark.parametrize(
+        ["initial_tags", "new_tags", "expected_tags"],
+        [
+            (["dangerous"], {"DANGEROUS"}, {"DANGEROUS"}),
+            (["existing"], {"existing", "new"}, {"existing", "new"}),
+            (["tag1", "tag2"], {"tag1"}, {"tag1"}),
+            (["keep", "remove", "lowercase"], {"keep", "LOWERCASE", "new"}, {"keep", "LOWERCASE", "new"}),
+            (["tag1", "tag2"], set(), set()),
+        ],
+    )
+    def test_update_dag_tags(self, testing_dag_bundle, session, initial_tags, new_tags, expected_tags):
+        dag_model = DagModel(dag_id="test_dag", bundle_name="testing")
+        dag_model.tags = [DagTag(name=tag, dag_id="test_dag") for tag in initial_tags]
+        session.add(dag_model)
+        session.commit()
+
+        _update_dag_tags(new_tags, dag_model, session=session)
+        session.commit()
+
+        assert {t.name for t in dag_model.tags} == expected_tags