Skip to content

Commit

Permalink
Fix trigger kwarg encryption migration (#39246)
Browse files Browse the repository at this point in the history
Do the encryption in the migration itself, and fix support for offline
migrations as well.

The offline up migration won't actually encrypt the trigger kwargs, as there
isn't a safe way to accomplish that, so the decryption process checks whether
the kwargs are encrypted and short-circuits if they aren't.

The offline down migration will now print out a warning that the offline
migration will fail if there are any running triggers. I think this is
the best we can do for that scenario (and folks willing to do offline
migrations will hopefully be able to understand the situation).

This also solves the "encrypting the already encrypted kwargs" bug in
2.9.0.

(cherry picked from commit adeb7f7)
  • Loading branch information
jedcunningham committed Apr 26, 2024
1 parent d5d8b58 commit c9cc726
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@
# specific language governing permissions and limitations
# under the License.

"""update trigger kwargs type
"""update trigger kwargs type and encrypt
Revision ID: 1949afb29106
Revises: ee1467d4aa35
Create Date: 2024-03-17 22:09:09.406395
"""
import json
from textwrap import dedent

from alembic import context, op
import sqlalchemy as sa
from sqlalchemy.orm import lazyload

from airflow.serialization.serialized_objects import BaseSerialization
from airflow.models.trigger import Trigger
from alembic import op

from airflow.utils.sqlalchemy import ExtendedJSON

# revision identifiers, used by Alembic.
Expand All @@ -38,13 +42,43 @@
airflow_version = "2.9.0"


def get_session() -> sa.orm.Session:
    """Build an ORM session bound to the connection Alembic is migrating with.

    Returns:
        A new ``Session`` bound to ``op.get_bind()``. The caller is responsible
        for committing and closing it.
    """
    bind = op.get_bind()
    return sa.orm.sessionmaker()(bind=bind)

def upgrade():
    """Update trigger kwargs type to string and encrypt.

    The ``kwargs`` column of the ``trigger`` table is altered to ``Text``.
    When the migration runs online, every trigger row is then re-saved by
    assigning ``kwargs`` back to itself. In offline mode no rows are touched.
    """
    with op.batch_alter_table("trigger") as batch_op:
        batch_op.alter_column("kwargs", type_=sa.Text())

    if context.is_offline_mode():
        # Offline mode has no live connection to read rows from and rewrite.
        return

    session = get_session()
    try:
        # lazyload keeps the query from pulling in the related task instance.
        rows = session.query(Trigger).options(lazyload(Trigger.task_instance))
        for row in rows:
            # Read the value and write it straight back; presumably the
            # Trigger.kwargs property encrypts on assignment (the property is
            # defined on the model, not in this migration).
            row.kwargs = row.kwargs
        session.commit()
    finally:
        session.close()


def downgrade():
    """Unapply update trigger kwargs type to string and encrypt.

    Online: every trigger row has its raw ``encrypted_kwargs`` value replaced
    by the JSON dump of the serialized ``kwargs``, and the change is committed.
    Offline: rows cannot be rewritten, so only a warning is printed.

    In both modes the ``kwargs`` column is then altered back to
    ``ExtendedJSON`` exactly once, using a ``kwargs::json`` cast on PostgreSQL.
    """
    if context.is_offline_mode():
        # The text below is emitted verbatim (after dedent) into the output.
        print(dedent("""
        ------------
        -- WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
        -- If any trigger rows exist when you do an offline downgrade, the migration will fail.
        ------------
        """))
    else:
        session = get_session()
        try:
            # lazyload keeps the query from pulling in the related task instance.
            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
                # Store plain serialized JSON in the raw column so the cast to
                # JSON below can succeed.
                trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
            session.commit()
        finally:
            session.close()

    # The values must already be plain JSON text before the column type changes.
    with op.batch_alter_table("trigger") as batch_op:
        batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
10 changes: 9 additions & 1 deletion airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization

decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
# We weren't able to encrypt the kwargs in all migration paths,
# so we need to handle the case where they are not encrypted.
# Triggers aren't long lasting, so we can skip encrypting them now.
if encrypted_kwargs.startswith("{"):
decrypted_kwargs = json.loads(encrypted_kwargs)
else:
decrypted_kwargs = json.loads(
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
)

return BaseSerialization.deserialize(decrypted_kwargs)

Expand Down
39 changes: 0 additions & 39 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


def encrypt_trigger_kwargs(*, session: Session) -> None:
    """Encrypt trigger kwargs that are still stored as plain serialized JSON.

    Rows whose raw value does not start with ``{`` are skipped, so running
    this more than once does not try to JSON-decode (and re-encrypt) a value
    that is no longer plain JSON.

    :param session: ORM session used to load the triggers and commit the change.
    """
    from airflow.models.trigger import Trigger
    from airflow.serialization.serialized_objects import BaseSerialization

    for trigger in session.query(Trigger):
        raw_kwargs = trigger.encrypted_kwargs
        if not raw_kwargs.startswith("{"):
            # Not a plain JSON object, so presumably already encrypted.
            continue
        # Deserialize the stored JSON and assign it through ``kwargs``, which
        # is expected to store it encrypted.
        trigger.kwargs = BaseSerialization.deserialize(json.loads(raw_kwargs))
    session.commit()


def decrypt_trigger_kwargs(*, session: Session) -> None:
    """Replace each trigger's stored kwargs with plain serialized JSON.

    Does nothing when the trigger table does not exist, which can happen when
    downgrading to a version from before the table was added.

    :param session: ORM session used to load the triggers and commit the change.
    """
    from sqlalchemy.orm import lazyload

    from airflow.models.trigger import Trigger
    from airflow.serialization.serialized_objects import BaseSerialization

    if not inspect(session.bind).has_table(Trigger.__tablename__):
        # table does not exist, nothing to do
        return

    # Select the Trigger entities, not the ``encrypted_kwargs`` column:
    # selecting the column yields plain strings, which have neither a
    # ``kwargs`` attribute to read nor an ``encrypted_kwargs`` one to assign.
    # lazyload keeps the query from pulling in the related task instance.
    query = select(Trigger).options(lazyload(Trigger.task_instance))
    for trigger in session.scalars(query):
        # read the kwargs via the model and store them back as serialized JSON
        trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
    session.commit()


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table.
Expand Down Expand Up @@ -1666,12 +1639,6 @@ def upgradedb(
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
synchronize_log_template(session=session)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
_get_current_revision(session=session),
):
encrypt_trigger_kwargs(session=session)


@provide_session
Expand Down Expand Up @@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
else:
log.info("Applying downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
to_revision,
):
decrypt_trigger_kwargs(session=session)


def drop_airflow_models(connection):
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2a24225537326f38be5df14e0b7a8dca867122093e0fa932f1a11ac12d1fb11c
3eb263f117248f914f64bf7cf44757526ecc00f222677629a11602f3bae7cdf0
4 changes: 2 additions & 2 deletions docs/apache-airflow/img/airflow_erd.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/apache-airflow/migrations-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=================================+===================+===================+==============================================================+
| ``1949afb29106`` (head) | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type |
| ``1949afb29106`` (head) | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type and encrypt |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``ee1467d4aa35`` | ``b4078ac230a1`` | ``2.9.0`` | add display name for dag and task instance |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
Expand Down
17 changes: 17 additions & 0 deletions tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import datetime
import json
from typing import Any, AsyncIterator

import pytest
Expand All @@ -27,6 +28,7 @@
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
assert isinstance(trigger_row.encrypted_kwargs, str)
assert "value1" not in trigger_row.encrypted_kwargs
assert "value2" not in trigger_row.encrypted_kwargs


def test_kwargs_not_encrypted():
    """
    Check that kwargs stored as plain serialized JSON can still be read.

    Not every migration path was able to encrypt the kwargs, so reading them
    must also work when the stored value was never encrypted.
    """
    expected = {"param1": "value1", "param2": "value2"}
    trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
    # Overwrite the stored value with unencrypted JSON, as it would look after
    # an offline upgrade.
    trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(expected))

    for key, value in expected.items():
        assert trigger.kwargs[key] == value

0 comments on commit c9cc726

Please sign in to comment.