Skip to content

Commit

Permalink
Fernet-key-rotation-optimisation (#40786)
Browse files Browse the repository at this point in the history
* The current implementation of Fernet key rotation implicitly calls the `all()` method on the processed tables, which loads all rows into memory.
It has been observed that some users store additional data in the `variable` table, which leads to memory issues during the operation.
This change introduces batch processing for Fernet key rotation to avoid loading every row at once. For consistency across the tables (`variable`, `connection`, `trigger`), batching was added for all of them.

---------

Co-authored-by: bjankiewicz <bjankiewicz@google.com>
  • Loading branch information
bjankie1 and bjankiewicz authored Jul 17, 2024
1 parent bc93a94 commit 79722e3
Showing 1 changed file with 46 additions and 8 deletions.
54 changes: 46 additions & 8 deletions airflow/cli/commands/rotate_fernet_key_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,55 @@
from airflow.utils import cli as cli_utils
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import create_session
from airflow.utils.sqlalchemy import is_sqlalchemy_v1


@cli_utils.action_cli
@providers_configuration_loaded
def rotate_fernet_key(args):
    """
    Rotates all encrypted connection credentials, triggers and variables.

    Rows are processed in fixed-size batches instead of being loaded all at once.
    The batching helper is chosen according to ``is_sqlalchemy_v1()``.

    :param args: parsed CLI arguments; not read by this command.
    """
    batch_size = 100
    rotate_method = rotate_items_in_batches_v1 if is_sqlalchemy_v1() else rotate_items_in_batches_v2
    with create_session() as session:
        with session.begin():  # Start a single transaction
            # Only connections carrying an encrypted password or encrypted extra are selected.
            rotate_method(
                session,
                Connection,
                filter_condition=Connection.is_encrypted | Connection.is_extra_encrypted,
                batch_size=batch_size,
            )
            rotate_method(session, Variable, filter_condition=Variable.is_encrypted, batch_size=batch_size)
            # Triggers are rotated without a filter, i.e. every row is visited.
            rotate_method(session, Trigger, filter_condition=None, batch_size=batch_size)


def rotate_items_in_batches_v1(session, model_class, filter_condition=None, batch_size=100):
    """
    Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage.

    This function is a replacement for yield_per, which is not available in SQLAlchemy 1.x.
    Rows are fetched page by page with OFFSET/LIMIT, ordered by the model's primary key.

    :param session: SQLAlchemy session used to execute the paged queries.
    :param model_class: mapped model class whose instances provide ``rotate_fernet_key()``.
    :param filter_condition: optional WHERE criterion; ``None`` selects every row.
    :param batch_size: number of rows fetched per page; must be at least 1.
    :raises ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # LIMIT 0 would return an empty first page and silently rotate nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # The base query does not change between pages, so build it once.
    query = select(model_class)
    if filter_condition is not None:
        query = query.where(filter_condition)
    # OFFSET/LIMIT paging is only stable with a deterministic ordering; without
    # ORDER BY the database may return pages that overlap or skip rows.
    query = query.order_by(*model_class.__mapper__.primary_key)

    offset = 0
    while True:
        items = session.scalars(query.offset(offset).limit(batch_size)).all()
        if not items:
            break  # No more items to process
        for item in items:
            item.rotate_fernet_key()
        offset += batch_size


def rotate_items_in_batches_v2(session, model_class, filter_condition=None, batch_size=100):
    """
    Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage.

    This function is taking advantage of yield_per available in SQLAlchemy 2.x.

    :param session: SQLAlchemy session used to execute the query.
    :param model_class: mapped model class whose instances provide ``rotate_fernet_key()``.
    :param filter_condition: optional WHERE criterion; ``None`` selects every row.
    :param batch_size: number of rows requested per fetch via ``yield_per``; must be at least 1.
    :raises ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    query = select(model_class)
    if filter_condition is not None:
        query = query.where(filter_condition)
    # A single pass over the result is enough: yield_per already fetches the rows
    # in chunks. Wrapping this in an unconditional loop would re-run the query and
    # re-rotate every row forever.
    for item in session.scalars(query).yield_per(batch_size):
        item.rotate_fernet_key()

0 comments on commit 79722e3

Please sign in to comment.