Skip to content

Commit

Permalink
fixbug: transformers callback yaml (#704)
Browse files Browse the repository at this point in the history
* Update config.py

* change check logic

* fix:isinstance
  • Loading branch information
SAKURA-CAT authored Sep 20, 2024
1 parent 7ad24aa commit 64cd09e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
30 changes: 20 additions & 10 deletions swanlab/data/run/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,32 @@
import json
from dataclasses import is_dataclass, asdict

BASE_TYPE = (int, float, str, bool)


def json_serializable(obj):
"""
将传入的字典转换为JSON可序列化格式。
:raises TypeError: 对象不是JSON可序列化的
"""
# 如果对象是基本类型,则直接返回
if isinstance(obj, (int, float, str, bool, type(None))):
if isinstance(obj, float) and math.isnan(obj):
return Line.nan
if isinstance(obj, float) and math.isinf(obj):
return Line.inf
return obj
# 不可以直接使用isinstance,详见issue: https://github.com/SwanHubX/SwanLab/issues/702
if obj is None:
return None
if type(obj) is float and math.isnan(obj):
return Line.nan
if type(obj) is float and math.isinf(obj):
return Line.inf

for t in BASE_TYPE:
if type(obj) is t:
return obj
# 继承的子类需要转译
if isinstance(obj, t):
return t(obj)

# 将日期和时间转换为字符串
elif isinstance(obj, (datetime.date, datetime.datetime)):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()

# 对于列表和元组,递归调用此函数
Expand Down Expand Up @@ -121,9 +131,9 @@ class SwanLabConfig(MutableMapping):
"""

def __init__(
self,
config: Union[MutableMapping, argparse.Namespace] = None,
on_setter: Optional[Callable[[RuntimeInfo], Any]] = None,
self,
config: Union[MutableMapping, argparse.Namespace] = None,
on_setter: Optional[Callable[[RuntimeInfo], Any]] = None,
):
"""
实例化配置类,如果settings不为None,说明是通过swanlab.init调用的,否则是通过swanlab.config调用的
Expand Down
21 changes: 21 additions & 0 deletions test/unit/data/run/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ class MyData:
assert yaml.dump(config) == yaml.dump(vars(config_data))


def test_parse_base_class():
"""
继承自基础的类不能绕过parse函数
"""

class StrChild(str):
pass

value = parse(StrChild("abc"))
assert value == "abc"
assert yaml.safe_dump({"value": value})

class IntChild(int):
pass

value = parse(IntChild(1))
assert value == 1
assert value != "1"
assert yaml.safe_dump({"value": value})


class TestSwanLabConfigOperation:
"""
单独测试TestSwanLabRunConfig这个类
Expand Down

0 comments on commit 64cd09e

Please sign in to comment.