Skip to content

Commit

Permalink
Fixbug/tag sort (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
Feudalman authored Feb 19, 2024
1 parent 6ab9206 commit 7b964e9
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"run": "core",
"system": "core",
"data": "guard",
"unit": "test"
"unit": "test",
"migrate": "Husky"
},
"material-icon-theme.files.associations": {
".env.mock": "Tune"
Expand Down
12 changes: 12 additions & 0 deletions swanlab/data/run/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ class SwanLabTag:
__slice_size = 1000

def __init__(self, experiment_id, tag: str, log_dir: str) -> None:
"""
初始化tag对象
Parameters
----------
experiment_id : int
实验id
tag : str
tag名称
log_dir : str
log文件夹路径
"""
self.experiment_id = experiment_id
self.tag = tag
self.__steps = set()
Expand Down
9 changes: 7 additions & 2 deletions swanlab/db/db_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from ..env import get_db_path
import os
from peewee import SqliteDatabase
from .table_config import tables
from .table_config import tables, Tag
from .migrate import *

# 判断是否已经binded了
binded = False
Expand Down Expand Up @@ -48,13 +49,17 @@ def connect(autocreate=False) -> SqliteDatabase:
path_exists = os.path.exists(os.path.dirname(path))
if not path_exists or (not db_exists and not autocreate):
raise FileNotFoundError(f"DB file {path} not found")

# 启用外键约束
swandb = SqliteDatabase(path, pragmas={"foreign_keys": 1})
if not binded:
# 动态绑定数据库
swandb.connect()
swandb.bind(tables)
swandb.create_tables(tables)
swandb.close()
# 完成数据迁移,如果tag表中没有sort字段,则添加
if not Tag.field_exists("sort"):
# 不启用外键约束
add_sort(SqliteDatabase(path))
binded = True
return swandb
10 changes: 10 additions & 0 deletions swanlab/db/migrate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-02-19 15:41:25
@File: swanlab/db/migrate/__init__.py
@IDE: vscode
@Description:
数据库迁移模块
"""
from .tag_sort import add_sort
20 changes: 20 additions & 0 deletions swanlab/db/migrate/tag_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-02-19 15:40:58
@File: swanlab/db/migrate/tag_sort.py
@IDE: vscode
@Description:
为数据库的Tag表添加sort字段
"""
from peewee import SqliteDatabase, IntegerField
from playhouse.migrate import migrate, SqliteMigrator


def add_sort(db: SqliteDatabase):
"""
为数据库的Tag表添加sort字段
"""
migrator = SqliteMigrator(db)
sort = IntegerField(default=0)
migrate(migrator.add_column("tag", "sort", sort))
20 changes: 19 additions & 1 deletion swanlab/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Description:
在此处定义基础模型类
"""
from peewee import Model
from peewee import Model, OperationalError
from playhouse.shortcuts import model_to_dict
from ..utils import create_time
import json
Expand Down Expand Up @@ -116,3 +116,21 @@ def create(cls, **query):
query["create_time"] = current_time
query["update_time"] = current_time
return super().create(**query)

@classmethod
def field_exists(cls, field: str) -> bool:
"""
判断某个字段是否存在于表中
Parameters
----------
field : str
字段名
"""
# 进行一次查询
try:
a = cls.select().where(cls.__getattribute__(cls, field) is not None)
cls.search2dict(a)
except OperationalError:
return False
return True
9 changes: 7 additions & 2 deletions swanlab/db/models/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
实验 tag 表
"""
from ..model import SwanModel
from peewee import ForeignKeyField, CharField, TextField, IntegerField, IntegrityError, DatabaseProxy
from peewee import ForeignKeyField, CharField, TextField, IntegerField, IntegrityError, DatabaseProxy, fn
from ..error import ExistedError, ForeignExpNotExistedError
from .experiments import Experiment

Expand Down Expand Up @@ -40,6 +40,8 @@ class Meta:
"""tag的描述,可为空"""
system = IntegerField(default=0, choices=[0, 1])
"""标识这个tag数据由系统生成还是用户生成,0: 用户生成,1: 系统生成,默认为0"""
sort = IntegerField(default=0)
"""tag在实验中的排序,值越小越靠前"""
more = TextField(null=True)
"""更多信息配置,json格式,将在表函数中检查并解析"""
create_time = CharField(max_length=30, null=False)
Expand Down Expand Up @@ -101,7 +103,9 @@ def create(
# 如果实验id不存在,则抛出异常
if not Experiment.filter(Experiment.id == experiment_id).exists():
raise ForeignExpNotExistedError("experiment不存在")

# 获取当前实验下tag的最大排序索引,如果没有则为0
sort = Tag.select(fn.Max(Tag.sort)).where(Tag.experiment_id == experiment_id).scalar()
sort = sort + 1 if sort is not None else 0
# 尝试创建实验tag,如果已经存在则抛出异常
try:
return super().create(
Expand All @@ -110,6 +114,7 @@ def create(
type=type,
description=description,
system=system,
sort=sort,
more=more,
)
except IntegrityError:
Expand Down
21 changes: 20 additions & 1 deletion swanlab/server/controller/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
# 实验运行状态
RUNNING_STATUS = Experiment.RUNNING_STATUS

# tag 总结文件名
TAG_SUMMARY_FILE = "_summary.json"
# logs 目录下的配置文件
LOGS_CONFIGS = [TAG_SUMMARY_FILE]


# ---------------------------------- 工具函数 ----------------------------------

Expand Down Expand Up @@ -328,21 +333,35 @@ def get_experiment_summary(experiment_id: int) -> dict:
"""

experiment = Experiment.get_by_id(experiment_id)
# 通过外键反链获取实验下的所有tag
tag_list = [tag["name"] for tag in __to_list(experiment.tags)]
experiment_path = __get_logs_dir_by_id(experiment_id)
# 通过目录结构获取所有正常的tag
tags = [f for f in os.listdir(experiment_path) if os.path.isdir(os.path.join(experiment_path, f))]
# 实验总结数据
summaries = []
for tag in tag_list:
# 如果 tag 记录在数据库,但是没有对应目录,说明 tag 有问题
# 所以 tags 是 tag_list 的子集,出现异常的 tag 会记录在数据库但不会添加到目录结构中
if quote(tag, safe="") not in tags:
summaries.append({"key": tag, "value": "TypeError"})
continue
tag_path = os.path.join(experiment_path, quote(tag, safe=""))
logs = sorted([item for item in os.listdir(tag_path) if item != "_summary.json"])
logs = sorted([item for item in os.listdir(tag_path) if not item in LOGS_CONFIGS])
# 打开 tag 目录下最后一个存储文件,获取最后一条数据
with get_a_lock(os.path.join(tag_path, logs[-1]), mode="r") as f:
data = ujson.load(f)
# str 转化的目的是为了防止有些不合规范的数据导致返回体对象化失败
data = str(data["data"][-1]["data"])
summaries.append({"key": tag, "value": data})
# 获取数据库记录时在实验下的排序
sorts = {item["name"]: item["sort"] for item in __to_list(experiment.tags)}
# 如果 sorts 中的值不都为 0,说明是新版添加排序后的 tag,这里进行排序 (如果是旧版没有排序的tag,直接按照数据库顺序即可)
if not all(value == 0 for value in sorts.values()):
temp = [0] * len(summaries)
for item in summaries:
temp[sorts[item["key"]]] = item
summaries = temp

return SUCCESS_200({"summaries": summaries})

Expand Down
18 changes: 15 additions & 3 deletions swanlab/server/controller/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,22 @@ def get_project_summary(project_id: int = DEFAULT_PROJECT_ID) -> dict:
ids = [item["id"] for item in exprs]

# 根据 id 列表找到所有的 tag,提出不含重复 tag 名的元组
tags = Tag.filter(Tag.experiment_id.in_(ids))
# tags = Tag.filter(Tag.experiment_id.in_(ids)).order_by(Tag.experiment_id)
tags = (
Tag.select()
.join(Experiment, on=(Tag.experiment_id == Experiment.id))
.where(Tag.experiment_id.in_(ids))
.order_by(Tag.experiment_id, Tag.create_time)
)
# tag_names不用编码,因为前端需要展示
tag_names = list(set(tag["name"] for tag in __to_list(tags)))

# 要保持顺序,需要额外使用一个列表,直接用 set 会打乱顺序
unique_values = []
seen = set()
for value in [item["name"] for item in __to_list(tags)]:
if value not in seen:
unique_values.append(value)
seen.add(value)
tag_names = list(unique_values)
# 所有总结数据
data = {}
# 第一层循环对应实验层,每次探寻一个实验
Expand Down

0 comments on commit 7b964e9

Please sign in to comment.