diff --git a/.vscode/settings.json b/.vscode/settings.json
index 3dbe0f5c..23e3a10a 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -17,7 +17,8 @@
     "run": "core",
     "system": "core",
     "data": "guard",
-    "unit": "test"
+    "unit": "test",
+    "migrate": "Husky"
   },
   "material-icon-theme.files.associations": {
     ".env.mock": "Tune"
diff --git a/swanlab/data/run/exp.py b/swanlab/data/run/exp.py
index 9001296c..6c4742b8 100644
--- a/swanlab/data/run/exp.py
+++ b/swanlab/data/run/exp.py
@@ -108,6 +108,18 @@ class SwanLabTag:
     __slice_size = 1000
 
     def __init__(self, experiment_id, tag: str, log_dir: str) -> None:
+        """
+        Initialize the tag object.
+
+        Parameters
+        ----------
+        experiment_id : int
+            Id of the experiment.
+        tag : str
+            Name of the tag.
+        log_dir : str
+            Path of the log directory.
+        """
         self.experiment_id = experiment_id
         self.tag = tag
         self.__steps = set()
diff --git a/swanlab/db/db_connect.py b/swanlab/db/db_connect.py
index 4d3a4d37..151f5531 100644
--- a/swanlab/db/db_connect.py
+++ b/swanlab/db/db_connect.py
@@ -10,7 +10,8 @@
 from ..env import get_db_path
 import os
 from peewee import SqliteDatabase
-from .table_config import tables
+from .table_config import tables, Tag
+from .migrate import add_sort
 
 # 判断是否已经binded了
 binded = False
@@ -48,7 +49,7 @@ def connect(autocreate=False) -> SqliteDatabase:
     path_exists = os.path.exists(os.path.dirname(path))
     if not path_exists or (not db_exists and not autocreate):
         raise FileNotFoundError(f"DB file {path} not found")
-
+    # Enable foreign-key constraints
     swandb = SqliteDatabase(path, pragmas={"foreign_keys": 1})
     if not binded:
         # 动态绑定数据库
@@ -56,5 +57,13 @@ def connect(autocreate=False) -> SqliteDatabase:
         swandb.bind(tables)
         swandb.create_tables(tables)
         swandb.close()
+        # Finish the data migration: add the sort column when the tag table lacks it
+        if not Tag.field_exists("sort"):
+            # Migrate over a separate connection that does not enable foreign keys,
+            # and close it afterwards so the extra connection is not leaked
+            migrate_db = SqliteDatabase(path)
+            try:
+                add_sort(migrate_db)
+            finally:
+                migrate_db.close()
         binded = True
     return swandb
diff --git a/swanlab/db/migrate/__init__.py b/swanlab/db/migrate/__init__.py
new file mode 100644
index 00000000..661d08b9
--- /dev/null
+++ b/swanlab/db/migrate/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2024-02-19 15:41:25
+@File: swanlab/db/migrate/__init__.py
+@IDE: vscode
+@Description:
+    Database migration module
+"""
+from .tag_sort import add_sort
diff --git a/swanlab/db/migrate/tag_sort.py b/swanlab/db/migrate/tag_sort.py
new file mode 100644
index 00000000..4c6c52a9
--- /dev/null
+++ b/swanlab/db/migrate/tag_sort.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2024-02-19 15:40:58
+@File: swanlab/db/migrate/tag_sort.py
+@IDE: vscode
+@Description:
+    Add the sort column to the Tag table of the database
+"""
+from peewee import SqliteDatabase, IntegerField
+from playhouse.migrate import migrate, SqliteMigrator
+
+
+def add_sort(db: SqliteDatabase):
+    """
+    Add the integer ``sort`` column (default 0) to the ``tag`` table.
+
+    Parameters
+    ----------
+    db : SqliteDatabase
+        Database the migration is run against.
+    """
+    migrator = SqliteMigrator(db)
+    sort = IntegerField(default=0)
+    migrate(migrator.add_column("tag", "sort", sort))
diff --git a/swanlab/db/model.py b/swanlab/db/model.py
index 70e59bbb..fd6496eb 100644
--- a/swanlab/db/model.py
+++ b/swanlab/db/model.py
@@ -7,7 +7,7 @@
 @Description:
     在此处定义基础模型类
 """
-from peewee import Model
+from peewee import Model, OperationalError
 from playhouse.shortcuts import model_to_dict
 from ..utils import create_time
 import json
@@ -116,3 +116,26 @@ def create(cls, **query):
         query["create_time"] = current_time
         query["update_time"] = current_time
         return super().create(**query)
+
+    @classmethod
+    def field_exists(cls, field: str) -> bool:
+        """
+        Check whether the column of a model field exists in the table.
+
+        Parameters
+        ----------
+        field : str
+            Name of the field as declared on the model.
+
+        Returns
+        -------
+        bool
+            False when the probe query raises OperationalError, True otherwise.
+        """
+        # Probe only the requested column and fetch at most one row, so the check
+        # does not depend on other columns and does not load the whole table.
+        try:
+            list(cls.select(getattr(cls, field)).limit(1))
+        except OperationalError:
+            return False
+        return True
diff --git a/swanlab/db/models/tags.py b/swanlab/db/models/tags.py
index 072766bc..e3da9617 100644
--- a/swanlab/db/models/tags.py
+++ b/swanlab/db/models/tags.py
@@ -8,7 +8,7 @@
     实验 tag 表
 """
 from ..model import SwanModel
-from peewee import ForeignKeyField, CharField, TextField, IntegerField, IntegrityError, DatabaseProxy
+from peewee import ForeignKeyField, CharField, TextField, IntegerField, IntegrityError, DatabaseProxy, fn
 from ..error import ExistedError, ForeignExpNotExistedError
 from .experiments import Experiment
 
@@ -40,6 +40,8 @@ class Meta:
     """tag的描述,可为空"""
     system = IntegerField(default=0, choices=[0, 1])
     """标识这个tag数据由系统生成还是用户生成,0: 用户生成,1: 系统生成,默认为0"""
+    sort = IntegerField(default=0)
+    """Ordering of the tag within its experiment; smaller values come first"""
     more = TextField(null=True)
     """更多信息配置,json格式,将在表函数中检查并解析"""
     create_time = CharField(max_length=30, null=False)
@@ -101,7 +103,9 @@ def create(
         # 如果实验id不存在,则抛出异常
         if not Experiment.filter(Experiment.id == experiment_id).exists():
             raise ForeignExpNotExistedError("experiment不存在")
-
+        # Largest sort index among the experiment's tags; the first tag starts at 0
+        sort = Tag.select(fn.Max(Tag.sort)).where(Tag.experiment_id == experiment_id).scalar()
+        sort = sort + 1 if sort is not None else 0
         # 尝试创建实验tag,如果已经存在则抛出异常
         try:
             return super().create(
@@ -110,6 +114,7 @@ def create(
                 type=type,
                 description=description,
                 system=system,
+                sort=sort,
                 more=more,
             )
         except IntegrityError:
diff --git a/swanlab/server/controller/experiment.py b/swanlab/server/controller/experiment.py
index a39717c5..3172d8ba 100644
--- a/swanlab/server/controller/experiment.py
+++ b/swanlab/server/controller/experiment.py
@@ -49,6 +49,11 @@
 # 实验运行状态
 RUNNING_STATUS = Experiment.RUNNING_STATUS
 
+# File name of a tag's summary file
+TAG_SUMMARY_FILE = "_summary.json"
+# Configuration files that live in a tag's log directory next to the log files
+LOGS_CONFIGS = [TAG_SUMMARY_FILE]
+
 
 # ---------------------------------- 工具函数 ----------------------------------
 
@@ -328,21 +333,34 @@ def get_experiment_summary(experiment_id: int) -> dict:
     """
 
     experiment = Experiment.get_by_id(experiment_id)
+    # Names of every tag of the experiment, read through the foreign-key backref
     tag_list = [tag["name"] for tag in __to_list(experiment.tags)]
     experiment_path = __get_logs_dir_by_id(experiment_id)
+    # Tags that have a directory under the experiment's log directory
     tags = [f for f in os.listdir(experiment_path) if os.path.isdir(os.path.join(experiment_path, f))]
+    # Summary entries of the experiment
     summaries = []
     for tag in tag_list:
+        # A tag recorded in the database without a matching directory is a broken tag:
+        # tags is a subset of tag_list, and broken tags exist only in the database
         if quote(tag, safe="") not in tags:
             summaries.append({"key": tag, "value": "TypeError"})
             continue
         tag_path = os.path.join(experiment_path, quote(tag, safe=""))
-        logs = sorted([item for item in os.listdir(tag_path) if item != "_summary.json"])
+        logs = sorted([item for item in os.listdir(tag_path) if item not in LOGS_CONFIGS])
+        # Open the last log file in the tag directory and read its last record
         with get_a_lock(os.path.join(tag_path, logs[-1]), mode="r") as f:
             data = ujson.load(f)
             # str 转化的目的是为了防止有些不合规范的数据导致返回体对象化失败
             data = str(data["data"][-1]["data"])
             summaries.append({"key": tag, "value": data})
+    # Ordering of each tag within the experiment, as recorded in the database
+    sorts = {item["name"]: item["sort"] for item in __to_list(experiment.tags)}
+    # Legacy tags all carry sort == 0 and stay in database order. Otherwise order by
+    # the sort value with a stable sort: duplicate or non-contiguous values can then
+    # neither overwrite an entry nor index past the end of the list.
+    if any(value != 0 for value in sorts.values()):
+        summaries.sort(key=lambda item: sorts[item["key"]])
     return SUCCESS_200({"summaries": summaries})
 
 
diff --git a/swanlab/server/controller/project.py b/swanlab/server/controller/project.py
index 8d28cca6..ca8fb0a4 100644
--- a/swanlab/server/controller/project.py
+++ b/swanlab/server/controller/project.py
@@ -156,10 +156,15 @@ def get_project_summary(project_id: int = DEFAULT_PROJECT_ID) -> dict:
     ids = [item["id"] for item in exprs]
 
     # 根据 id 列表找到所有的 tag,提出不含重复 tag 名的元组
-    tags = Tag.filter(Tag.experiment_id.in_(ids))
+    tags = (
+        Tag.select()
+        .join(Experiment, on=(Tag.experiment_id == Experiment.id))
+        .where(Tag.experiment_id.in_(ids))
+        .order_by(Tag.experiment_id, Tag.create_time)
+    )
     # tag_names不用编码,因为前端需要展示
-    tag_names = list(set(tag["name"] for tag in __to_list(tags)))
-
+    # dict.fromkeys drops duplicate names while keeping first-seen order (a set would not)
+    tag_names = list(dict.fromkeys(item["name"] for item in __to_list(tags)))
     # 所有总结数据
     data = {}
     # 第一层循环对应实验层,每次探寻一个实验