Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade old DAG/task param format when deserializing from the DB #18986

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 29 additions & 37 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,32 @@ def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -
return True
return False

@classmethod
def _serialize_params_dict(cls, params: ParamsDict):
"""Serialize Params dict for a DAG/Task"""
serialized_params = {}
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only
if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
serialized_params[k] = v.dump()
else:
raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
return serialized_params

@classmethod
def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
"""Deserialize a DAGs Params dict"""
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
param_class = import_string(v['__class'])
op_params[k] = param_class(**v)
else:
# Old style params, upgrade it
op_params[k] = Param(v)

return ParamsDict(op_params)


class DependencyDetector:
"""Detects dependencies between DAGs."""
Expand Down Expand Up @@ -584,7 +610,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
elif k == "deps":
v = cls._deserialize_deps(v)
elif k == "params":
v = cls._deserialize_operator_params(v)
v = cls._deserialize_params_dict(v)
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
Expand Down Expand Up @@ -721,17 +747,6 @@ def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOper

return serialize_operator_extra_links

@classmethod
def _deserialize_operator_params(cls, encoded_op_params: Dict) -> Dict[str, Param]:
"""Deserialize Params dict of a operator"""
op_params = {}
for k, v in encoded_op_params.items():
param_class = import_string(v['__class'])
del v['__class']
op_params[k] = param_class(**v)

return ParamsDict(op_params)

@classmethod
def _serialize_operator_params(cls, op_params: ParamsDict):
"""Serialize Params dict of a operator"""
Expand Down Expand Up @@ -802,7 +817,7 @@ def serialize_dag(cls, dag: DAG) -> dict:

# Edge info in the JSON exactly matches our internal structure
serialize_dag["edge_info"] = dag.edge_info
serialize_dag["params"] = cls._serialize_dag_params(dag.params)
serialize_dag["params"] = cls._serialize_params_dict(dag.params)

# has_on_*_callback are only stored if the value is True, as the default is False
if dag.has_on_success_callback:
Expand Down Expand Up @@ -843,7 +858,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
elif k in cls._decorated_fields:
v = cls._deserialize(v)
elif k == "params":
v = cls._deserialize_dag_params(v)
v = cls._deserialize_params_dict(v)
# else use v as it is

setattr(dag, k, v)
Expand Down Expand Up @@ -915,29 +930,6 @@ def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
raise ValueError(f"Unsure how to deserialize version {ver!r}")
return cls.deserialize_dag(serialized_obj['dag'])

@classmethod
def _serialize_dag_params(cls, dag_params: ParamsDict):
"""Serialize Params dict for a DAG"""
serialized_params = {}
for k, v in dag_params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only
if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
serialized_params[k] = v.dump()
else:
raise ValueError('Params to a DAG can be only of type airflow.models.param.Param')
return serialized_params

@classmethod
def _deserialize_dag_params(cls, encoded_dag_params: Dict) -> ParamsDict:
"""Deserialize a DAGs Params dict"""
op_params = {}
for k, v in encoded_dag_params.items():
param_class = import_string(v['__class'])
del v['__class']
op_params[k] = param_class(**v)

return ParamsDict(op_params)


class SerializedTaskGroup(TaskGroup, BaseSerialization):
"""A JSON serializable representation of TaskGroup."""
Expand Down
18 changes: 18 additions & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,24 @@ def test_serialized_objects_are_sorted(self, object_to_serialized, expected_outp
serialized_obj = serialized_obj["__var"]
assert serialized_obj == expected_output

def test_params_upgrade(self):
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
"timezone": "UTC",
"params": {"none": None, "str": "str", "dict": {"a": "b"}},
},
}
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)

assert dag.params["none"] is None
assert isinstance(dict.__getitem__(dag.params, "none"), Param)
assert dag.params["str"] == "str"


def test_kubernetes_optional():
"""Serialisation / deserialisation continues to work without kubernetes installed"""
Expand Down