Skip to content

Commit

Permalink
feat: support nest log (#808)
Browse files Browse the repository at this point in the history
* add flatten

* update test
  • Loading branch information
Zeyi-Lin authored Feb 3, 2025
1 parent 5d66953 commit 2a6ba54
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
21 changes: 18 additions & 3 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,17 @@ def config(self) -> SwanLabConfig:
"""
return self.__config

def __flatten_dict(self, d: dict, parent_key='', sep='.') -> dict:
    """
    Flatten a nested dict into a single-level dict whose keys are joined with `sep`.

    Parameters
    ----------
    d : dict
        The (possibly nested) dict to flatten. It is not modified.
    parent_key : str, optional
        Prefix prepended to every key of `d`. A falsy value (the default) means no prefix.
    sep : str, optional
        Separator placed between key segments, "." by default.

    Returns
    -------
    dict
        A new dict in which every non-dict value is stored under its joined path,
        e.g. {'a': {'b': 1}, 'c': 2} becomes {'a.b': 1, 'c': 2}.
        Keys that have no prefix are kept as they are (they are not converted to str).
        Empty nested dicts contribute no entries.
        If two paths produce the same joined key, the one visited last wins.
    """
    # Unique marker for "no prefix". A truthiness test on the prefix would mistake falsy
    # keys such as 0 or "" for the top level and silently drop them from the joined key.
    no_prefix = object()

    def walk(mapping, prefix):
        # Depth-first traversal in insertion order, yielding (joined_key, leaf_value) pairs.
        for key, value in mapping.items():
            full_key = key if prefix is no_prefix else f"{prefix}{sep}{key}"
            if isinstance(value, dict):
                yield from walk(value, full_key)
            else:
                yield full_key, value

    return dict(walk(d, parent_key if parent_key else no_prefix))

def log(self, data: dict, step: int = None):
"""
Log a row of data to the current run. Unlike `swanlab.log`, this api will be called directly based on the
Expand All @@ -320,7 +331,8 @@ def log(self, data: dict, step: int = None):
data : Dict[str, DataType]
Data must be a dict.
The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/".
The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
The value must be a `float`, `float convertible object`, `int`, `swanlab.data.BaseType`, or a nested dict.
For nested dicts, keys will be joined with dots (e.g., {'a': {'b': 1}} becomes {'a.b': 1}).
step : int, optional
The step number of the current data, if not provided, it will be automatically incremented.
If step is duplicated, the data will be ignored.
Expand All @@ -347,15 +359,18 @@ def log(self, data: dict, step: int = None):
)
step = None

# 展平嵌套字典
flattened_data = self.__flatten_dict(data)

log_return = {}
# 遍历data,记录data
for k, v in data.items():
for k, v in flattened_data.items():
_k = k
k = check_key_format(k, auto_cut=True)
if k != _k:
# 超过255字符,截断
swanlog.warning(f"Key {_k} is too long, cut to 255 characters.")
if k in data.keys():
if k in flattened_data.keys():
raise ValueError(f'tag: Not supported too long Key "{_k}" and auto cut failed')
# ---------------------------------- 包装数据 ----------------------------------
# 输入为可转换为float的数据类型
Expand Down
9 changes: 5 additions & 4 deletions test/unit/data/run/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,14 @@ def test_log_column_with_id(self):
# ---------------------------------- 解析log数字/Line ----------------------------------
def test_log_number_ok(self):
run = SwanLabRun()
data = {"a": 1, "b": 0.1, "math.nan": math.nan, "math.inf": math.inf}
data = {"a": 1, "b": 0.1, "c": {"d": 2}, "math.nan": math.nan, "math.inf": math.inf}
ll = run.log(data)
assert len(ll) == 4
assert len(ll) == 5
# 都没有错误
assert all([ll[k].is_error is False for k in ll])
assert ll["a"].data == 1
assert ll["b"].data == 0.1
assert ll["c.d"].data == 2
assert ll["math.nan"].data == Line.nan
assert ll["math.inf"].data == Line.inf
assert all([ll[k].column_info.chart_type == ll[k].column_info.chart_type.LINE for k in ll])
Expand All @@ -146,9 +147,9 @@ def test_log_number_use_line(self):
使用Line类型log,本质上应该与数字类型一样,数字类型是Line类型的语法糖
"""
run = SwanLabRun()
data = {"a": Line(1), "b": Line(0.1), "math.nan": Line(math.nan), "math.inf": Line(math.inf)}
data = {"a": Line(1), "b": Line(0.1), "c": {"d": Line(2)}, "math.nan": Line(math.nan), "math.inf": Line(math.inf)}
ll = run.log(data)
assert len(ll) == 4
assert len(ll) == 5
# line(1)和[line(1)]是一样的
ll2 = run.log({"a": [Line(1)]})
assert ll2["a"].data == ll["a"].data
Expand Down

0 comments on commit 2a6ba54

Please sign in to comment.