diff --git a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
index dbde1201e4cd0..2d57686e43f46 100644
--- a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
+++ b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
@@ -16,18 +16,22 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""update trigger kwargs type
+"""update trigger kwargs type and encrypt
 
 Revision ID: 1949afb29106
 Revises: ee1467d4aa35
 Create Date: 2024-03-17 22:09:09.406395
 
 """
+import json
+from textwrap import dedent
+
+from alembic import context, op
 import sqlalchemy as sa
+from sqlalchemy.orm import lazyload
 
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.models.trigger import Trigger
-from alembic import op
-
 from airflow.utils.sqlalchemy import ExtendedJSON
 
 # revision identifiers, used by Alembic.
@@ -38,13 +42,43 @@
 airflow_version = "2.9.0"
 
 
+def get_session() -> sa.orm.Session:
+    conn = op.get_bind()
+    sessionmaker = sa.orm.sessionmaker()
+    return sessionmaker(bind=conn)
+
+
 def upgrade():
-    """Update trigger kwargs type to string"""
+    """Update trigger kwargs type to string and encrypt"""
     with op.batch_alter_table("trigger") as batch_op:
         batch_op.alter_column("kwargs", type_=sa.Text(), )
 
+    if not context.is_offline_mode():
+        session = get_session()
+        try:
+            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+                trigger.kwargs = trigger.kwargs
+            session.commit()
+        finally:
+            session.close()
+
 
 def downgrade():
-    """Unapply update trigger kwargs type to string"""
+    """Unapply update trigger kwargs type to string and encrypt"""
+    if context.is_offline_mode():
+        print(dedent("""
+        ------------
+        -- WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
+        -- If any trigger rows exist when you do an offline downgrade, the migration will fail.
+        ------------
+        """))
+    else:
+        session = get_session()
+        try:
+            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+                trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
+            session.commit()
+        finally:
+            session.close()
+
     with op.batch_alter_table("trigger") as batch_op:
-        batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
+        batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json')
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index d3509b377d532..670d88d6142bb 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
         from airflow.models.crypto import get_fernet
         from airflow.serialization.serialized_objects import BaseSerialization
 
-        decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
+        # We weren't able to encrypt the kwargs in all migration paths,
+        # so we need to handle the case where they are not encrypted.
+        # Triggers aren't long lasting, so we can skip encrypting them now.
+        if encrypted_kwargs.startswith("{"):
+            decrypted_kwargs = json.loads(encrypted_kwargs)
+        else:
+            decrypted_kwargs = json.loads(
+                get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
+            )
 
         return BaseSerialization.deserialize(decrypted_kwargs)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index c0d282a587343..b7997498bb47d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
         session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))
 
 
-def encrypt_trigger_kwargs(*, session: Session) -> None:
-    """Encrypt trigger kwargs."""
-    from airflow.models.trigger import Trigger
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    for trigger in session.query(Trigger):
-        # convert serialized dict to string and encrypt it
-        trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
-    session.commit()
-
-
-def decrypt_trigger_kwargs(*, session: Session) -> None:
-    """Decrypt trigger kwargs."""
-    from airflow.models.trigger import Trigger
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    if not inspect(session.bind).has_table(Trigger.__tablename__):
-        # table does not exist, nothing to do
-        # this can happen when we downgrade to an old version before the Trigger table was added
-        return
-
-    for trigger in session.scalars(select(Trigger.encrypted_kwargs)):
-        # decrypt the string and convert it to serialized dict
-        trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
-    session.commit()
-
-
 def check_conn_id_duplicates(session: Session) -> Iterable[str]:
     """
     Check unique conn_id in connection table.
@@ -1666,12 +1639,6 @@ def upgradedb(
         _reserialize_dags(session=session)
     add_default_pool_if_not_exists(session=session)
     synchronize_log_template(session=session)
-    if _revision_greater(
-        config,
-        _REVISION_HEADS_MAP["2.9.0"],
-        _get_current_revision(session=session),
-    ):
-        encrypt_trigger_kwargs(session=session)
 
 
 @provide_session
@@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
     else:
         log.info("Applying downgrade migrations.")
         command.downgrade(config, revision=to_revision, sql=show_sql_only)
-        if _revision_greater(
-            config,
-            _REVISION_HEADS_MAP["2.9.0"],
-            to_revision,
-        ):
-            decrypt_trigger_kwargs(session=session)
 
 
 def drop_airflow_models(connection):
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index 8947b7e631598..cdcf039446dd5 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8
\ No newline at end of file
+77757e21aee500cb7fe7fd75e0f158633a0037d4d74e6f45eb14238f901ebacd
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index bf4c6c94906a0..fb280ee0ea7fc 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1421,28 +1421,28 @@
 task_instance--xcom
-0..N
+1
 1
 
 task_instance--xcom
-1
+0..N
 1
 
 task_instance--xcom
-0..N
+1
 1
 
 task_instance--xcom
-1
+0..N
 1
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index d989564d91670..d858879d545fc 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -41,7 +41,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
 +=================================+===================+===================+==============================================================+
 | ``677fdbb7fc54`` (head)         | ``1949afb29106``  | ``2.10.0``        | add new executor field to db                                 |
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
-| ``1949afb29106``                | ``ee1467d4aa35``  | ``2.9.0``         | update trigger kwargs type                                   |
+| ``1949afb29106``                | ``ee1467d4aa35``  | ``2.9.0``         | update trigger kwargs type and encrypt                       |
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
 | ``ee1467d4aa35``                | ``b4078ac230a1``  | ``2.9.0``         | add display name for dag and task instance                   |
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index a3dd6ce35afbf..6be2086f34112 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import datetime
+import json
 from typing import Any, AsyncIterator
 
 import pytest
@@ -27,6 +28,7 @@
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models import TaskInstance, Trigger
 from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
     assert isinstance(trigger_row.encrypted_kwargs, str)
     assert "value1" not in trigger_row.encrypted_kwargs
     assert "value2" not in trigger_row.encrypted_kwargs
+
+
+def test_kwargs_not_encrypted():
+    """
+    Tests that we don't decrypt kwargs if they aren't encrypted.
+    We weren't able to encrypt the kwargs in all migration paths.
+    """
+    trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    # force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade
+    trigger.encrypted_kwargs = json.dumps(
+        BaseSerialization.serialize({"param1": "value1", "param2": "value2"})
+    )
+
+    assert trigger.kwargs["param1"] == "value1"
+    assert trigger.kwargs["param2"] == "value2"