Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def deserialize_value(result) -> Any:
try:
with path.open(mode="rb", compression="infer") as f:
return json.load(f, cls=XComDecoder)
except (TypeError, ValueError):
except (FileNotFoundError, TypeError, ValueError):
return data

@staticmethod
Expand Down
37 changes: 28 additions & 9 deletions providers/common/io/tests/unit/common/io/xcom/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

Expand All @@ -35,9 +35,11 @@

if AIRFLOW_V_3_0_PLUS:
from airflow.models.xcom import XComModel
from airflow.sdk import ObjectStoragePath
from airflow.sdk.execution_time.comms import XComResult
from airflow.sdk.execution_time.xcom import resolve_xcom_backend
else:
from airflow.io.path import ObjectStoragePath
from airflow.models.xcom import BaseXCom, resolve_xcom_backend # type: ignore[no-redef]


Expand Down Expand Up @@ -373,6 +375,11 @@ def test_compression(self, task_instance, session, mock_supervisor_comms):
@pytest.mark.parametrize(
"value, expected_value",
[
pytest.param(
"file://airflow/xcoms/non_existing_file.json",
"file://airflow/xcoms/non_existing_file.json",
id="str",
),
pytest.param(1, 1, id="int"),
pytest.param(1.0, 1.0, id="float"),
pytest.param("string", "string", id="str"),
Expand All @@ -385,11 +392,23 @@ def test_compression(self, task_instance, session, mock_supervisor_comms):
],
)
def test_serialization_deserialization_basic(self, value, expected_value):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom

serialized_data = XCom.serialize_value(value)
mock_xcom_ser = MagicMock(value=serialized_data)
deserialized_data = XCom.deserialize_value(mock_xcom_ser)

assert deserialized_data == expected_value
def conditional_side_effect(data) -> ObjectStoragePath:
if isinstance(data, str) and data.startswith("file://"):
return ObjectStoragePath(data)
return original_get_full_path(data)

original_get_full_path = XComObjectStorageBackend._get_full_path

with patch.object(
XComObjectStorageBackend,
"_get_full_path",
side_effect=conditional_side_effect,
):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom

serialized_data = XCom.serialize_value(value)
mock_xcom_ser = MagicMock(value=serialized_data)
deserialized_data = XCom.deserialize_value(mock_xcom_ser)

assert deserialized_data == expected_value