Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Feb 15, 2022
1 parent 4d12599 commit 0a91c7a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
2 changes: 2 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
cwd_temp_dir,
decode_dict,
deprecated,
encode_dict,
get_python_interpreter,
module_handler,
module_to_path,
Expand All @@ -27,6 +28,7 @@
"config",
"cwd_temp_dir",
"decode_dict",
"encode_dict",
"module_handler",
"update_nb_name",
"module_to_path",
Expand Down
5 changes: 5 additions & 0 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def decode_dict(value):
return json.loads(json.dumps(value), cls=znjson.ZnDecoder)


def encode_dict(value) -> dict:
"""Encode value into a dict serialized with ZnJson"""
return json.loads(json.dumps(value, cls=znjson.ZnEncoder))


def get_auto_init(fields: typing.List[str]):
"""Automatically create a __init__ based on fields
Parameters
Expand Down
28 changes: 19 additions & 9 deletions zntrack/zn/split_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,39 @@
a parameter in another (zntrack.json)
"""
import json
import logging

import znjson
import typing

from zntrack import utils
from zntrack.core.parameter import ZnTrackOption

log = logging.getLogger(__name__)


def split_value(input_val):
"""Split input_val into data for params.yaml and zntrack.json"""
def split_value(input_val) -> (typing.Union[dict, list], typing.Union[dict, list]):
"""Split input_val into data for params.yaml and zntrack.json
Parameters
----------
input_val: dict
A dictionary of shape {_type: str, value: any} from ZnJSON
Returns
-------
params_data: dict|list
A dictionary containing the data considered a parameter
input_val: dict|list
A dictionary containing the constant data which is not considered a parameter
"""
if isinstance(input_val, (list, tuple)):
data = [split_value(x) for x in input_val]
params_data, zntrack_data = zip(*data)
else:
if input_val["_type"] in ["zn.method"]:
# zn.Method
params_data = input_val["value"].pop("kwargs")
params_data["_cls"] = input_val["value"].pop("cls")

# _ = input_val.pop("value")
else:
# things that are not zn.method and do not have kwargs, such as pathlib, ...
params_data = input_val.pop("value")
Expand Down Expand Up @@ -84,7 +94,7 @@ def save(self, instance):
where the path as string / the dataclass as dict is stored in params.yaml
"""
value = self.__get__(instance, self.owner)
serialized_value = json.loads(json.dumps(value, cls=znjson.ZnEncoder))
serialized_value = utils.encode_dict(value)

try:
# if znjson was used to serialize the data, it will have a _type key
Expand Down

0 comments on commit 0a91c7a

Please sign in to comment.