Skip to content

Commit

Permalink
Remove XCom pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Nov 12, 2024
1 parent 0360b99 commit fd62961
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 193 deletions.
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_xcom_entry(
stub.value = XCom.deserialize_value(stub)
item = stub

if stringify or conf.getboolean("core", "enable_xcom_pickling"):
if stringify:
return xcom_schema_string.dump(item)

return xcom_schema_native.dump(item)
9 changes: 0 additions & 9 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,6 @@ core:
type: string
example: ~
default: "False"
enable_xcom_pickling:
description: |
Whether to enable pickling for xcom (note that this is insecure and allows for
RCE exploits).
version_added: ~
type: string
example: ~
default: "False"
see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json"
allowed_deserialization_classes:
description: |
What classes can be imported during deserialization. This is a multi line value.
Expand Down
4 changes: 1 addition & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3447,9 +3447,7 @@ def xcom_push(
Make an XCom available for tasks to pull.
:param key: Key to store the value under.
:param value: Value to store. What types are possible depends on whether
``enable_xcom_pickling`` is true or not. If so, this can be any
picklable object; only be JSON-serializable may be used otherwise.
:param value: Value to store. Only JSON-serializable values may be used.
"""
XCom.set(
key=key,
Expand Down
26 changes: 3 additions & 23 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import inspect
import json
import logging
import pickle
from typing import TYPE_CHECKING, Any, Iterable, cast

from sqlalchemy import (
Expand Down Expand Up @@ -456,20 +455,7 @@ def serialize_value(
map_index: int | None = None,
) -> Any:
"""Serialize XCom value to str or pickled object."""
if conf.getboolean("core", "enable_xcom_pickling"):
return pickle.dumps(value)
try:
return json.dumps(value, cls=XComEncoder).encode("UTF-8")
except (ValueError, TypeError) as ex:
log.error(
"%s."
" If you are using pickle instead of JSON for XCom,"
" then you need to enable pickle support for XCom"
" in your airflow config or make sure to decorate your"
" object with attr.",
ex,
)
raise
return json.dumps(value, cls=XComEncoder).encode("UTF-8")

@staticmethod
def _deserialize_value(result: XCom, orm: bool) -> Any:
Expand All @@ -479,14 +465,8 @@ def _deserialize_value(result: XCom, orm: bool) -> Any:

if result.value is None:
return None
if conf.getboolean("core", "enable_xcom_pickling"):
try:
return pickle.loads(result.value)
except pickle.UnpicklingError:
return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
else:
# Since xcom_pickling is disabled, we should only try to deserialize with JSON
return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)

return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)

@staticmethod
def deserialize_value(result: XCom) -> Any:
Expand Down
10 changes: 10 additions & 0 deletions newsfragments/aip-72.significant.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,13 @@ As part of this change the following breaking changes have occurred:
- Shipping DAGs via pickle is no longer supported

This was a feature that was not widely used and was a security risk. It has been removed.

- Pickling is no longer supported for XCom serialization.

XCom data will no longer support pickling. This change is intended to improve security and simplify data
handling by supporting JSON-only serialization. DAGs that depend on XCom pickling must update to use JSON-serializable data.

As part of that change, the ``[core] enable_xcom_pickling`` configuration option has been removed.

If you still need to use pickling, you can use a custom XCom backend that stores references in the metadata DB and
keeps the pickled data in separate storage such as S3.
30 changes: 0 additions & 30 deletions tests/api_connexion/endpoints/test_xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,36 +158,6 @@ def test_should_respond_200_native(self):
"value": {"key": "value"},
}

@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_should_respond_200_native_for_pickled(self):
dag_id = "test-dag-id"
task_id = "test-task-id"
execution_date = "2005-04-02T00:00:00+00:00"
xcom_key = "test-xcom-key"
execution_date_parsed = timezone.parse(execution_date)
run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
value_non_serializable_key = {("201009_NB502104_0421_AHJY23BGXG (SEQ_WF: 138898)", None): 82359}
self._create_xcom_entry(
dag_id, run_id, execution_date_parsed, task_id, xcom_key, {"key": value_non_serializable_key}
)
response = self.client.get(
f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}",
environ_overrides={"REMOTE_USER": "test"},
)
assert 200 == response.status_code

current_data = response.json
current_data["timestamp"] = "TIMESTAMP"
assert current_data == {
"dag_id": dag_id,
"execution_date": execution_date,
"key": xcom_key,
"task_id": task_id,
"map_index": -1,
"timestamp": "TIMESTAMP",
"value": f"{{'key': {str(value_non_serializable_key)}}}",
}

def test_should_raise_404_for_non_existent_xcom(self):
dag_id = "test-dag-id"
task_id = "test-task-id"
Expand Down
51 changes: 0 additions & 51 deletions tests/api_connexion/schemas/test_xcom_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@
# under the License.
from __future__ import annotations

import pickle

import pytest
from sqlalchemy import or_, select

from airflow.api_connexion.schemas.xcom_schema import (
XComCollection,
xcom_collection_item_schema,
xcom_collection_schema,
xcom_schema_string,
)
from airflow.models import DagRun, XCom
from airflow.utils import timezone
from airflow.utils.session import create_session

from tests_common.test_utils.config import conf_vars

pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]


Expand Down Expand Up @@ -184,49 +179,3 @@ def test_serialize(self, create_xcom, session):
"total_entries": 2,
},
)


class TestXComSchema:
default_time = "2016-04-02T21:00:00+00:00"
default_time_parsed = timezone.parse(default_time)

@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_serialize(self, create_xcom, session):
create_xcom(
dag_id="test_dag",
task_id="test_task_id",
execution_date=self.default_time_parsed,
key="test_key",
value=pickle.dumps(b"test_binary"),
)
xcom_model = session.query(XCom).first()
deserialized_xcom = xcom_schema_string.dump(xcom_model)
assert deserialized_xcom == {
"key": "test_key",
"timestamp": self.default_time,
"execution_date": self.default_time,
"task_id": "test_task_id",
"dag_id": "test_dag",
"value": "test_binary",
"map_index": -1,
}

@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_deserialize(self):
xcom_dump = {
"key": "test_key",
"timestamp": self.default_time,
"execution_date": self.default_time,
"task_id": "test_task_id",
"dag_id": "test_dag",
"value": b"test_binary",
}
result = xcom_schema_string.load(xcom_dump)
assert result == {
"key": "test_key",
"timestamp": self.default_time_parsed,
"execution_date": self.default_time_parsed,
"task_id": "test_task_id",
"dag_id": "test_dag",
"value": "test_binary",
}
77 changes: 1 addition & 76 deletions tests/models/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import datetime
import operator
import os
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import MagicMock
Expand All @@ -28,7 +27,6 @@
from airflow.configuration import conf
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.settings import json
Expand Down Expand Up @@ -109,77 +107,19 @@ def test_resolve_xcom_class(self):
cls = resolve_xcom_backend()
assert issubclass(cls, CustomXCom)

@conf_vars({("core", "xcom_backend"): "", ("core", "enable_xcom_pickling"): "False"})
@conf_vars({("core", "xcom_backend"): ""})
def test_resolve_xcom_class_fallback_to_basexcom(self):
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"

@conf_vars({("core", "enable_xcom_pickling"): "False"})
@conf_vars({("core", "xcom_backend"): "to be removed"})
def test_resolve_xcom_class_fallback_to_basexcom_no_config(self):
conf.remove_option("core", "xcom_backend")
cls = resolve_xcom_backend()
assert issubclass(cls, BaseXCom)
assert cls.serialize_value([1]) == b"[1]"

def test_xcom_deserialize_with_json_to_pickle_switch(self, task_instance, session):
ti_key = TaskInstanceKey(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
)
with conf_vars({("core", "enable_xcom_pickling"): "False"}):
XCom.set(
key="xcom_test3",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
with conf_vars({("core", "enable_xcom_pickling"): "True"}):
ret_value = XCom.get_value(key="xcom_test3", ti_key=ti_key, session=session)
assert ret_value == {"key": "value"}

@pytest.mark.skip_if_database_isolation_mode
def test_xcom_deserialize_pickle_when_xcom_pickling_is_disabled(self, task_instance, session):
with conf_vars({("core", "enable_xcom_pickling"): "True"}):
XCom.set(
key="xcom_test3",
value={"key": "value"},
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
with conf_vars({("core", "enable_xcom_pickling"): "False"}):
with pytest.raises(UnicodeDecodeError):
XCom.get_one(
key="xcom_test3",
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)

@pytest.mark.skip_if_database_isolation_mode
@conf_vars({("core", "xcom_enable_pickling"): "False"})
def test_xcom_disable_pickle_type_fail_on_non_json(self, task_instance, session):
class PickleRce:
def __reduce__(self):
return os.system, ("ls -alt",)

with pytest.raises(TypeError):
XCom.set(
key="xcom_test3",
value=PickleRce(),
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)

@mock.patch("airflow.models.xcom.XCom.orm_deserialize_value")
def test_xcom_init_on_load_uses_orm_deserialize_value(self, mock_orm_deserialize):
instance = BaseXCom(
Expand Down Expand Up @@ -216,7 +156,6 @@ def test_get_one_custom_backend_no_use_orm_deserialize_value(self, task_instance
XCom.orm_deserialize_value.assert_not_called()

@pytest.mark.skip_if_database_isolation_mode
@conf_vars({("core", "enable_xcom_pickling"): "False"})
@mock.patch("airflow.models.xcom.conf.getimport")
def test_set_serialize_call_current_signature(self, get_import, task_instance):
"""
Expand Down Expand Up @@ -266,17 +205,6 @@ def serialize_value(
)


@pytest.fixture(
params=[
pytest.param("true", id="enable_xcom_pickling=true"),
pytest.param("false", id="enable_xcom_pickling=false"),
],
)
def setup_xcom_pickling(request):
with conf_vars({("core", "enable_xcom_pickling"): str(request.param)}):
yield


@pytest.fixture
def push_simple_json_xcom(session):
def func(*, ti: TaskInstance, key: str, value):
Expand All @@ -292,7 +220,6 @@ def func(*, ti: TaskInstance, key: str, value):
return func


@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComGet:
@pytest.fixture
def setup_for_xcom_get_one(self, task_instance, push_simple_json_xcom):
Expand Down Expand Up @@ -403,7 +330,6 @@ def test_xcom_get_many_from_prior_dates(self, session, tis_for_xcom_get_many_fro
assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date]


@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComSet:
def test_xcom_set(self, session, task_instance):
XCom.set(
Expand Down Expand Up @@ -439,7 +365,6 @@ def test_xcom_set_again_replace(self, session, task_instance):
assert session.query(XCom).one().value == {"key2": "value2"}


@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComClear:
@pytest.fixture
def setup_for_xcom_clear(self, task_instance, push_simple_json_xcom):
Expand Down

0 comments on commit fd62961

Please sign in to comment.