Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix trigger kwarg encryption migration #39246

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@
# specific language governing permissions and limitations
# under the License.

"""update trigger kwargs type
"""update trigger kwargs type and encrypt

Revision ID: 1949afb29106
Revises: ee1467d4aa35
Create Date: 2024-03-17 22:09:09.406395

"""
import json
from textwrap import dedent

from alembic import context, op
import sqlalchemy as sa
from sqlalchemy.orm import lazyload

from airflow.serialization.serialized_objects import BaseSerialization
from airflow.models.trigger import Trigger
from alembic import op

from airflow.utils.sqlalchemy import ExtendedJSON

# revision identifiers, used by Alembic.
Expand All @@ -38,13 +42,43 @@
airflow_version = "2.9.0"


def get_session() -> sa.orm.Session:
conn = op.get_bind()
sessionmaker = sa.orm.sessionmaker()
return sessionmaker(bind=conn)

def upgrade():
"""Update trigger kwargs type to string"""
"""Update trigger kwargs type to string and encrypt"""
with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=sa.Text(), )

if not context.is_offline_mode():
dstandish marked this conversation as resolved.
Show resolved Hide resolved
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
trigger.kwargs = trigger.kwargs
session.commit()
finally:
session.close()
dstandish marked this conversation as resolved.
Show resolved Hide resolved


def downgrade():
"""Unapply update trigger kwargs type to string"""
"""Unapply update trigger kwargs type to string and encrypt"""
if context.is_offline_mode():
print(dedent("""
------------
-- WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
uranusjr marked this conversation as resolved.
Show resolved Hide resolved
-- If any trigger rows exist when you do an offline downgrade, the migration will fail.
------------
"""))
else:
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()
finally:
session.close()

with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json')
10 changes: 9 additions & 1 deletion airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization

decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
# We weren't able to encrypt the kwargs in all migration paths,
# so we need to handle the case where they are not encrypted.
# Triggers aren't long lasting, so we can skip encrypting them now.
jedcunningham marked this conversation as resolved.
Show resolved Hide resolved
if encrypted_kwargs.startswith("{"):
decrypted_kwargs = json.loads(encrypted_kwargs)
jedcunningham marked this conversation as resolved.
Show resolved Hide resolved
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
else:
decrypted_kwargs = json.loads(
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
)

return BaseSerialization.deserialize(decrypted_kwargs)

Expand Down
39 changes: 0 additions & 39 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


def encrypt_trigger_kwargs(*, session: Session) -> None:
"""Encrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

for trigger in session.query(Trigger):
# convert serialized dict to string and encrypt it
trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
session.commit()


def decrypt_trigger_kwargs(*, session: Session) -> None:
"""Decrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

if not inspect(session.bind).has_table(Trigger.__tablename__):
# table does not exist, nothing to do
# this can happen when we downgrade to an old version before the Trigger table was added
return

for trigger in session.scalars(select(Trigger.encrypted_kwargs)):
# decrypt the string and convert it to serialized dict
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table.
Expand Down Expand Up @@ -1666,12 +1639,6 @@ def upgradedb(
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
synchronize_log_template(session=session)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
_get_current_revision(session=session),
):
encrypt_trigger_kwargs(session=session)


@provide_session
Expand Down Expand Up @@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
else:
log.info("Applying downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
to_revision,
):
decrypt_trigger_kwargs(session=session)


def drop_airflow_models(connection):
Expand Down
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8
77757e21aee500cb7fe7fd75e0f158633a0037d4d74e6f45eb14238f901ebacd
8 changes: 4 additions & 4 deletions docs/apache-airflow/img/airflow_erd.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/apache-airflow/migrations-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
+=================================+===================+===================+==============================================================+
| ``677fdbb7fc54`` (head) | ``1949afb29106`` | ``2.10.0`` | add new executor field to db |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type |
| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type and encrypt |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``ee1467d4aa35`` | ``b4078ac230a1`` | ``2.9.0`` | add display name for dag and task instance |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
Expand Down
17 changes: 17 additions & 0 deletions tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import datetime
import json
from typing import Any, AsyncIterator

import pytest
Expand All @@ -27,6 +28,7 @@
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
assert isinstance(trigger_row.encrypted_kwargs, str)
assert "value1" not in trigger_row.encrypted_kwargs
assert "value2" not in trigger_row.encrypted_kwargs


def test_kwargs_not_encrypted():
"""
Tests that we don't decrypt kwargs if they aren't encrypted.
We weren't able to encrypt the kwargs in all migration paths.
"""
trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
# force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade
trigger.encrypted_kwargs = json.dumps(
BaseSerialization.serialize({"param1": "value1", "param2": "value2"})
)

assert trigger.kwargs["param1"] == "value1"
assert trigger.kwargs["param2"] == "value2"