diff --git a/.eslintignore b/.eslintignore
index 3a2b7636..3005c266 100644
--- a/.eslintignore
+++ b/.eslintignore
@@ -17,3 +17,6 @@ dist/
package.json
*.md
+
+
+swanlab/
diff --git "a/.github/ISSUE_TEMPLATE/bug\346\217\220\344\272\244.md" "b/.github/ISSUE_TEMPLATE/bug\346\217\220\344\272\244.md"
new file mode 100644
index 00000000..3732203d
--- /dev/null
+++ "b/.github/ISSUE_TEMPLATE/bug\346\217\220\344\272\244.md"
@@ -0,0 +1,29 @@
+---
+name: Bug提交
+about: 向开发者们反映出现的bug
+title: '[BUG] '
+labels: BUG
+assignees: ''
+---
+
+## Bug 描述
+
+> 描述 bug 的主要内容
+
+## 如何复现
+
+> 在此处向开发者描述 bug 的复现过程,在必要时请附上截图
+
+1. 前往 '....'
+
+2. 点击 '....'
+
+3. 出现如下问题 '....'
+
+## 预期行为
+
+> 向开发者描述如果没有此 bug,应该是什么样的
+
+## 录屏
+
+> 如果需要,可以附上截图
diff --git a/.gitignore b/.gitignore
index 2e93538f..6201c874 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ package-lock.json
pnpm-lock.yaml
vue/components.d.ts
vue/auto-imports.d.ts
+swanlog/
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/.prettierignore b/.prettierignore
index a0eb825f..741c0385 100644
--- a/.prettierignore
+++ b/.prettierignore
@@ -64,3 +64,5 @@ README.md
doc/**
**.md
+
+swanlab/
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 864a5c9e..ae542daa 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -21,15 +21,6 @@
"runtimeExecutable": "npm",
"runtimeArgs": ["run", "dev:mock"]
},
- // 启动前端编译服务
- {
- "name": "编译前端项目",
- "request": "launch",
- "cwd": "${workspaceRoot}",
- "type": "node",
- "runtimeExecutable": "npm",
- "runtimeArgs": ["run", "build"]
- },
// 启动后端开发服务
{
"name": "后端开发",
@@ -42,25 +33,27 @@
//sys.path 会加入顶层目录,影响模块导入查询路径
"env": { "PYTHONPATH": "${workspaceFolder}" }
},
+ // 打包命令
{
- "name": "开启一个实验",
+ "name": "构建项目",
"type": "python",
"request": "launch",
- "program": "${workspaceFolder}/test/create_experiment.py",
+ "program": "${workspaceFolder}/build_pypi.py",
"console": "integratedTerminal",
"justMyCode": true,
- "cwd": "${workspaceFolder}",
- //sys.path 会加入顶层目录,影响模块导入查询路径
- "env": { "PYTHONPATH": "${workspaceFolder}" }
+ "cwd": "${workspaceFolder}"
},
+ // 模拟实验开启
{
- "name": "构建项目",
+ "name": "开启一个实验",
"type": "python",
"request": "launch",
- "program": "${workspaceFolder}/build_pypi.py",
+ "program": "${workspaceFolder}/test/create_experiment.py",
"console": "integratedTerminal",
"justMyCode": true,
- "cwd": "${workspaceFolder}"
+ "cwd": "${workspaceFolder}",
+ //sys.path 会加入顶层目录,影响模块导入查询路径
+ "env": { "PYTHONPATH": "${workspaceFolder}" }
},
// python运行当前文件
{
diff --git a/.vscode/settings.json b/.vscode/settings.json
index a1a7f399..03743bd4 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -5,7 +5,7 @@
"material-icon-theme.folders.associations": {
"swanlab": "api",
"wandb": "database",
- ".swanlab": "admin",
+ "swanlog": "log",
".config": "config",
"store": "database",
"help": "java",
diff --git a/build_pypi.py b/build_pypi.py
index 97bf530a..f58093f4 100644
--- a/build_pypi.py
+++ b/build_pypi.py
@@ -12,7 +12,7 @@
import os
# 构建node项目
-subprocess.run("npm run build", shell=True)
+subprocess.run("npm run build.release", shell=True)
# 如果dist文件夹存在则删除
if os.path.exists("dist"):
shutil.rmtree("dist")
diff --git a/package.json b/package.json
index a7578112..03b6760f 100644
--- a/package.json
+++ b/package.json
@@ -1,12 +1,13 @@
{
"name": "swanlab-ui",
"private": true,
- "version": "0.0.1beta6",
+ "version": "0.0.2",
"type": "module",
"scripts": {
"dev": "vite",
"dev:mock": "vite --mode mock",
"build": "vite build",
+ "build.release": "vite build --mode release",
"preview": "vite preview"
},
"dependencies": {
@@ -16,6 +17,7 @@
"moment": "^2.29.4",
"pinia": "^2.1.7",
"sass": "^1.69.5",
+ "terser": "^5.26.0",
"vue": "^3.3.8",
"vue-i18n": "^9.8.0",
"vue-router": "^4.2.5",
diff --git a/requirements.txt b/requirements.txt
index 9e14e528..7eeafa12 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,6 @@ click
ujson
portalocker
+# Information collection
+psutil
+
diff --git a/swanlab/__init__.py b/swanlab/__init__.py
index 38e59238..320e793e 100755
--- a/swanlab/__init__.py
+++ b/swanlab/__init__.py
@@ -1,4 +1,12 @@
-from .database import sd
+from .log import swanlog as swl
+from .env import swc
+
+_sd = None
+"""
+swandatabase对象
+使用动态导入的方式有助于环境隔离
+比如cli不需要此对象,就不需要导入
+"""
def init(experiment_name: str = None, description: str = "", config: dict = {}):
@@ -13,11 +21,24 @@ def init(experiment_name: str = None, description: str = "", config: dict = {}):
config : dict, optional
实验可选配置,在此处可以记录一些实验的超参数等信息
"""
- sd.init(
+ global _sd
+ if _sd is not None:
+ raise RuntimeError("swanlab has been initialized")
+ from .database import swandatabase as sd
+
+ # 挂载对象
+ _sd = sd
+ # 初始化数据库
+ _sd.init(
experiment_name=experiment_name,
description=description,
config=config,
)
+ # 初始化日志对象
+ swl.init(swc.output)
+ swl.info("Run data will be saved locally in " + swc.exp_folder)
+ swl.info("Experiment_name: " + _sd.experiment.name)
+ swl.info("Run `swanlab watch` to view SwanLab Experiment Dashboard")
def log(data: dict):
@@ -31,11 +52,14 @@ def log(data: dict):
data : dict
此处填写需要记录的数据
"""
+ if _sd is None:
+ raise RuntimeError("swanlab has not been initialized")
+
if not isinstance(data, dict):
raise TypeError("log data must be a dict")
for key in data:
# 遍历字典的key,记录到本地文件中
- sd.add(key, data[key])
+ _sd.add(key, data[key])
__all__ = ["init", "log"]
diff --git a/swanlab/cli/main.py b/swanlab/cli/main.py
index e4081366..02a158ac 100644
--- a/swanlab/cli/main.py
+++ b/swanlab/cli/main.py
@@ -9,7 +9,7 @@
"""
import click
-import uvicorn
+from .utils import is_vaild_ip, is_available_port
@click.group()
@@ -18,38 +18,57 @@ def cli():
@cli.command()
+# 控制服务发布的ip地址
@click.option(
- "--share",
- is_flag=True,
- help="When shared, swanlab web will run on localhost",
-)
-@click.option(
- "--debug",
- is_flag=True,
- help="Show more logs when use debug mode",
+ "--host",
+ "-h",
+ default="127.0.0.1",
+ type=str,
+ help="The host of swanlab web, default by 127.0.0.1",
+ callback=is_vaild_ip,
)
+# 控制服务发布的端口,默认5092
@click.option(
"--port",
"-p",
default=5092,
+ type=int,
help="The port of swanlab web, default by 5092",
)
-def watch(share, debug, port):
+# 日志等级
+@click.option(
+ "--log-level",
+ default="info",
+ type=click.Choice(["debug", "info", "warning", "error", "critical"]),
+ help="The level of log, default by info; You can choose one of [debug, info, warning, error, critical]",
+)
+def watch(log_level: str, host: tuple, port: int):
"""Run this command to turn on the swanlab service."""
- # print("share", share)
- # print("debug", debug)
- # print("port", port)
+ # 导入必要的模块
+ from ..log import swanlog as swl
from ..server import app
+ import uvicorn
- # 服务地址
- host = "localhost" if share else "127.0.0.1"
- # 日志等级
- log_level = "info" if debug else "warning"
- click.echo(f"swanlab running on \033[1mhttp://{host}:{port}\033[0m")
+ # ---------------------------------- 日志等级处理 ----------------------------------
+ swl.setLevel(log_level)
+ # ---------------------------------- 服务地址处理 ----------------------------------
+ # 拿到当前本机可用的所有ip地址
+ ip, ipv4 = host
+ ips = [f"http://{ip}:{port}" for ip in ipv4]
+ # 判断ip:port是否被占用
+ is_available_port(ip, port)
+ # ---------------------------------- 日志打印 ----------------------------------
+ if ip == "0.0.0.0":
+ # 检查每个ip地址的端口占用情况
+ swl.info(f"SwanLab Experiment Dashboard running...")
+ swl.info(f"Available on: \n" + "\n".join(ips))
+ else:
+ swl.info(f"SwanLab Experiment Dashboard running on \033[1mhttp://{ip}:{port}\033[0m")
+ # ---------------------------------- 启动服务 ----------------------------------
- # 使用 uvicorn 启动 FastAPI 应用
- uvicorn.run(app, host=host, port=port, log_level=log_level)
+ # 使用 uvicorn 启动 FastAPI 应用,关闭原生日志
+ uvicorn.run(app, host=ip, port=port, log_level="critical")
if __name__ == "__main__":
- cli()
+ watch()
diff --git a/swanlab/cli/utils.py b/swanlab/cli/utils.py
new file mode 100644
index 00000000..fc31ec35
--- /dev/null
+++ b/swanlab/cli/utils.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2023-12-15 15:53:21
+@File: swanlab/cli/utils.py
+@IDE: vscode
+@Description:
+ 命令行工具
+"""
+import re
+import psutil
+import socket
+import click
+
+
+def is_vaild_ip(ctx, param, ip: str) -> tuple:
+ """检测是否是合法的ip地址
+
+ Parameters
+ ----------
+ ctx : click.Context
+ 上下文
+ param : click.Parameter
+ 参数
+ ip : str
+ 带检测的字符串
+ """
+ ip = str(ip)
+ pattern = re.compile(r"^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$")
+ if not pattern.match(ip):
+ raise click.BadParameter("Invalid IP address format: " + ip)
+ # 没有问题,获取当前机器的所有ip地址
+ interfaces = psutil.net_if_addrs()
+ ipv4 = []
+ for _, addresses in interfaces.items():
+ for address in addresses:
+ # 如果是ipv4地址
+ if address.family == socket.AddressFamily.AF_INET:
+ ipv4.append(address.address)
+ if ip not in ipv4 and ip != "0.0.0.0":
+ raise click.BadParameter("IP address '" + ip + "' should be one of " + str(ipv4) + ".")
+ return ip, ipv4
+
+
+def is_available_port(host, port):
+ """检测端口是否可用
+
+ Parameters
+ ----------
+ host : str
+ ip地址
+ port : int
+ 端口号
+ """
+ try:
+ with socket.create_server((host, port), reuse_port=True):
+ pass
+ except:
+ raise OSError("Port '" + str(port) + "' is not available on " + host + ".")
diff --git a/swanlab/database/__init__.py b/swanlab/database/__init__.py
index bd62cb1f..5ef2d146 100644
--- a/swanlab/database/__init__.py
+++ b/swanlab/database/__init__.py
@@ -1,14 +1,55 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-r"""
-@DATE: 2023-11-26 16:55:09
-@File: swanlab\database\__init__.py
-@IDE: vscode
-@Description:
- 数据库模块,用于创建数据库连接并执行一些数据库操作,也封装数据搜集函数
-"""
-from .database import sd
-from .project import PT
-
-
-# print("swanlab.database init")
+import atexit, sys, traceback, os
+from datetime import datetime
+from ..env import swc
+from ..log import swanlog as swl
+
+# 注册环境变量,需要在初始化数据库之前注册
+swc.init(swc.getcwd(), "train")
+# 初始化数据库
+from .main import SwanDatabase
+
+swandatabase = SwanDatabase()
+
+
+# 定义清理函数
+def clean_handler():
+ if not swl.isError:
+ swl.info("train successfully")
+ swandatabase.success()
+ swl.setSuccess()
+ swl.reset_console()
+
+
+# 定义异常处理函数
+def except_handler(tp, val, tb):
+ swl.error("Error happended while training, SwanLab will throw it")
+ # 标记实验失败
+ swandatabase.fail()
+ swl.setError()
+ # 记录异常信息
+ # 追踪信息
+ traceList = traceback.format_tb(tb)
+ html = repr(tp) + "\n"
+ html += repr(val) + "\n"
+ for line in traceList:
+ html += line + "\n"
+
+ if os.path.exists(swc.error):
+ swl.warning("Error log file already exists, append error log to it")
+ # 写入日志文件
+ with open(swc.error, "a") as fError:
+ print(datetime.now(), file=fError)
+ print(html, file=fError)
+ # 重置控制台记录器
+ swl.reset_console()
+ raise tp(val)
+
+
+# 注册异常处理函数
+sys.excepthook = except_handler
+
+
+# 注册清理函数
+atexit.register(clean_handler)
+
+__all__ = ["swandatabase"]
diff --git a/swanlab/database/chart.py b/swanlab/database/chart.py
index e6d5b1f0..b762266a 100644
--- a/swanlab/database/chart.py
+++ b/swanlab/database/chart.py
@@ -1,6 +1,7 @@
from .table import ProjectTablePoxy
import ujson
import os
+from ..env import swc
from ..utils import create_time, get_a_lock
from typing import List, Union
@@ -10,12 +11,12 @@ class ChartTable(ProjectTablePoxy):
default_data = {"_sum": 0, "charts": []}
- def __init__(self, base_path: str, experiment_id: int):
+ def __init__(self, experiment_id: int):
"""初始化图表管理类"""
# 判断path是否存在,如果存在,则加载数据,否则创建
self.experiment_id = experiment_id
# 文件保存路径
- self.path = os.path.join(base_path, "charts.json")
+ self.path = swc.chart
if os.path.exists(self.path):
with open(self.path, "r", encoding="utf-8") as f:
data = ujson.load(f)
diff --git a/swanlab/database/expriment.py b/swanlab/database/expriment.py
index 43583fe4..c854d94d 100644
--- a/swanlab/database/expriment.py
+++ b/swanlab/database/expriment.py
@@ -10,7 +10,7 @@
from .table import ExperimentPoxy
from .chart import ChartTable
import os
-from ..env import SWANLAB_LOGS_FOLDER
+from ..env import swc
from typing import Union
from ..utils import create_time, generate_color
from .system import get_system_info
@@ -23,10 +23,10 @@ class ExperimentTable(ExperimentPoxy):
def __init__(self, experiment_id: int, name: str, description: str, config: dict, index: int):
# 初始化一个实验配置,name必须保证唯一,但是不在此处检查,而是在创建实验的时候检查
# 创建name对应的文件夹
- path = os.path.join(SWANLAB_LOGS_FOLDER, name)
- if not os.path.exists(path):
- os.mkdir(path)
- super().__init__(path)
+ swc.add_exp(name)
+ if not os.path.exists(swc.logs_folder):
+ os.makedirs(swc.logs_folder)
+ super().__init__(swc.logs_folder)
self.experiment_id = experiment_id
self.name = name
# tags数据不会被序列化
@@ -37,8 +37,8 @@ def __init__(self, experiment_id: int, name: str, description: str, config: dict
self.argv = sys.argv
self.index = index
self.status = 0 # 0: 正在运行,1: 运行成功,-1: 运行失败
- self.__chart = ChartTable(base_path=path, experiment_id=experiment_id)
- self.color = generate_color()
+ self.__chart = ChartTable(experiment_id=experiment_id)
+ self.color = generate_color(experiment_id)
def __dict__(self) -> dict:
"""序列化此对象
@@ -110,3 +110,8 @@ def success(self):
"""实验成功完成,更新实验状态"""
self.update_time = create_time()
self.status = 1
+
+ def fail(self):
+ """实验失败,更新状态"""
+ self.update_time = create_time()
+ self.status = -1
diff --git a/swanlab/database/database.py b/swanlab/database/main.py
similarity index 78%
rename from swanlab/database/database.py
rename to swanlab/database/main.py
index b2e01a0e..c92a7112 100644
--- a/swanlab/database/database.py
+++ b/swanlab/database/main.py
@@ -2,23 +2,19 @@
# -*- coding: utf-8 -*-
r"""
@DATE: 2023-12-02 00:38:37
-@File: swanlab\database\database.py
+@File: swanlab\database\main.py
@IDE: vscode
@Description:
数据库模块,连接project表单对象
"""
import os
-from ..env import SWANLAB_LOGS_FOLDER
+from ..env import swc, SwanlabConfig
from .project import ProjectTable
from .expriment import ExperimentTable
from ..utils import lock_file
from typing import Union
from io import TextIOWrapper
import ujson
-import atexit
-
-# flag,代表已经执行了inited函数
-inited = False
class SwanDatabase(object):
@@ -36,15 +32,10 @@ def __init__(self):
"""
# 此时必须保证.swanlab文件夹存在,但是这并不是本类的职责,所以不检查
- # TODO 但是目前还是先在这里创建
- from ..env import SWANLAB_FOLDER
-
- if not os.path.exists(SWANLAB_FOLDER):
- os.mkdir(SWANLAB_FOLDER)
+ swc.init(SwanlabConfig.getcwd(), "train")
+ if not os.path.exists(swc.root):
+ os.mkdir(swc.root)
- # 需要检查logs文件夹是否存在,不存在则创建
- if not os.path.exists(SWANLAB_LOGS_FOLDER):
- os.mkdir(SWANLAB_LOGS_FOLDER)
# 项目基础表单
self.__project: ProjectTable = None
# 如果项目配置文件不存在,创建
@@ -72,18 +63,6 @@ def init(
file : TextIOWrapper, optional
文件对象,用于文件锁定, by default None
"""
- # 同一运行时不允许运行两次此函数,通过flag来实现
- global inited
- if inited:
- raise RuntimeError("Swanlab has already been inited!")
- inited = True
-
- # 创建回调函数函数
- def callback():
- sd.success()
-
- # 注册此函数
- atexit.register(callback)
# 检查实验名称是否存在
project_exist = os.path.exists(ProjectTable.path) and os.path.getsize(ProjectTable.path) != 0
# 初始化项目对象
@@ -115,14 +94,12 @@ def add(self, tag: str, data: Union[str, float], namespace: str = "charts"):
namespace : str, optional
命名空间,用于区分不同的数据资源(对应{experiment_name}$chart中的tag), by default "charts"
"""
- global inited
- if not inited:
- raise RuntimeError("Swanlab must be inited first!")
self.__project.experiment.add(tag, data, namespace)
def success(self):
"""标记实验成功"""
self.__project.success()
-
-sd = SwanDatabase()
+ def fail(self):
+ """标记实验失败"""
+ self.__project.fail()
diff --git a/swanlab/database/project.py b/swanlab/database/project.py
index c77b7b8a..af7a6cf5 100644
--- a/swanlab/database/project.py
+++ b/swanlab/database/project.py
@@ -8,7 +8,7 @@
项目模块,创建项目级别数据库库,接下来针对实验级别的数据在此基础上进行操作
"""
import os
-from ..env import SWANLAB_LOGS_FOLDER
+from ..env import swc
from .experiments_name import generate_random_tree_name, check_experiment_name, make_experiment_name_unique
from .table import ProjectTablePoxy
from .expriment import ExperimentTable
@@ -24,7 +24,7 @@ class ProjectTable(ProjectTablePoxy):
data: dict,实验管理类的数据,json格式
"""
- path = os.path.join(SWANLAB_LOGS_FOLDER, "project.json")
+ path = swc.project
default_data = {"_sum": 0, "experiments": []}
def __init__(self, data: dict):
@@ -64,10 +64,13 @@ def add_experiment(self, name: str = None, description: str = None, config: dict
"""
# 获取当前已经存在的实验名称集合
experiments = [item["name"] for item in self["experiments"]]
+
+ # 获取实验名称
if name is None:
name = generate_random_tree_name(experiments)
else:
check_experiment_name(name)
+ # 获取实验描述和配置
if description is None:
description = ""
if config is None:
@@ -79,7 +82,6 @@ def add_experiment(self, name: str = None, description: str = None, config: dict
self.__experiment = ExperimentTable(self.sum, name, description, config, len(experiments) + 1)
# 添加一个实验到self["experiments"]中
self["experiments"].append(self.__experiment.__dict__())
- print("add experiment")
@lock_file(file_path=path, mode="r+")
def success(self, file: TextIOWrapper):
@@ -94,13 +96,15 @@ def success(self, file: TextIOWrapper):
break
self.save(file, project)
-
-class PT(object):
- """后端层面上的项目管理类,适配后端的项目管理接口,提供项目管理的相关功能"""
-
- path = ProjectTable.path
-
- @lock_file(file_path=path, mode="r")
- def get(self, file: TextIOWrapper):
- """获取实验信息"""
- return ujson.load(file)
+ @lock_file(file_path=path, mode="r+")
+ def fail(self, file: TextIOWrapper):
+ """实验失败,更新实验状态,再次保存实验信息"""
+ # 锁上文件,更新实验状态
+ project = ujson.load(file)
+ self.__experiment.fail()
+ for index, experiment in enumerate(project["experiments"]):
+ if experiment["experiment_id"] == self.__experiment.experiment_id:
+ project["experiments"][index] = self.__experiment.__dict__()
+ # print("success experiment ", project["experiments"][index])
+ break
+ self.save(file, project)
diff --git a/swanlab/env.py b/swanlab/env.py
index 6262c939..b17bfea2 100644
--- a/swanlab/env.py
+++ b/swanlab/env.py
@@ -1,5 +1,4 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
+#!/usr/bin/env python# -*- coding: utf-8 -*-
r"""
@DATE: 2023-11-30 21:20:13
@File: swanlab\env.py
@@ -9,10 +8,8 @@
"""
import os
import mimetypes
+from functools import wraps
-# 默认存放数据的目录为用户执行python命令时的目录
-SWANLAB_FOLDER = os.path.join(os.getcwd(), ".swanlab")
-SWANLAB_LOGS_FOLDER = os.path.join(SWANLAB_FOLDER, "logs")
"""
在此处注册静态文件路径,因为静态文件由vue框架编译后生成,在配置中,编译后的文件存储在/swanlab/template中
入口文件为index.html,网页图标为logo.ico,其他文件为assets文件夹中的文件
@@ -26,3 +23,163 @@
TEMPLATE_PATH = os.path.join(FILEPATH, "template")
ASSETS = os.path.join(TEMPLATE_PATH, "assets")
INDEX = os.path.join(TEMPLATE_PATH, "index.html")
+
+
+class SwanlabConfig(object):
+ """Swanlab全局配置对象"""
+
+ def __init__(self) -> None:
+ # 标志位,用于判断是否已经初始化
+ self.__init = False
+ # 根目录,这将决定日志输出的位置以及服务读取的位置
+ self.__folder = None
+ # 当前实验名称
+ self.__exp_name = None
+ # 当前模式,可选值: train, server; 前者代表日志记录模式,后者代表服务模式
+ self.__mode = None
+
+ def _should_initialized(func):
+ """装饰器:必须在初始化完毕以后才能执行"""
+
+ def wrapper(cls, *args, **kwargs):
+ if cls.__init is False:
+ raise ValueError("config has not been initialized")
+ result = func(cls, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ def _should_added_exp(func):
+ """装饰器:比如已经添加了实验"""
+
+ def wrapper(cls, *args, **kwargs):
+ if cls.__exp_name is None:
+ raise ValueError("config has not add experiment")
+ result = func(cls, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ def _should_server_mode(func):
+ """装饰器:必须是server mode"""
+
+ def wrapper(cls, *args, **kwargs):
+ if cls.__mode != "server":
+ raise ValueError(f"{func.__name__} is only available in server mode")
+ result = func(cls, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ def _should_train_mode(func):
+ """装饰器:必须是train mode"""
+
+ def wrapper(cls, *args, **kwargs):
+ if cls.__mode != "train":
+ raise ValueError(f"{func.__name__} is only available in train mode")
+ result = func(cls, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ def init(self, root: str, mode: str):
+ """初始化配置对象"""
+ if self.__init:
+ # TODO debug输出一下,已经初始化了
+ return
+ self.__folder = root
+ if mode not in ["train", "server"]:
+ raise ValueError("mode must be train or server")
+ self.__mode = mode
+ self.__init = True
+
+ def add_exp(self, exp_name: str):
+ if self.__exp_name is not None and self.__mode == "train":
+ raise ValueError("config has been added experiment in train mode")
+ self.__exp_name = exp_name
+
+ @property
+ @_should_initialized
+ def isTrain(self) -> str:
+ """当前模式是否为训练模式"""
+ return self.__mode == "train"
+
+ @staticmethod
+ def getcwd() -> str:
+ """当前程序运行路径,不包括文件名"""
+ return os.getcwd()
+
+ @property
+ @_should_initialized
+ def root(self) -> str:
+ """项目输出根目录,必须先被初始化"""
+ r = os.path.join(self.__folder, "swanlog")
+ if not os.path.exists(r):
+ os.mkdir(r)
+ return r
+
+ @property
+ @_should_initialized
+ def project(self) -> str:
+ """项目配置文件路径,必须是训练模式"""
+ return os.path.join(self.root, "project.json")
+
+ @property
+ @_should_initialized
+ def output(self) -> str:
+ """服务日志输出文件路径或者训练时swanlab的日志输出文件路径"""
+ if self.__mode == "train":
+ return os.path.join(self.exp_folder, "output.log")
+ else:
+ #
+ return os.path.join(self.root, "output.log")
+
+ @property
+ @_should_initialized
+ @_should_train_mode
+ @_should_added_exp
+ def exp_folder(self) -> str:
+ """实验存储路径"""
+ return os.path.join(self.root, self.__exp_name)
+
+ @property
+ @_should_initialized
+ @_should_train_mode
+ @_should_added_exp
+ def logs_folder(self) -> str:
+ """日志输出根目录,必须是训练模式"""
+ return os.path.join(self.root, self.__exp_name, "logs")
+
+ @property
+ @_should_initialized
+ @_should_train_mode
+ @_should_added_exp
+ def chart(self) -> str:
+ """表格路径"""
+ return os.path.join(self.root, self.__exp_name, "chart.json")
+
+ @property
+ @_should_initialized
+ @_should_train_mode
+ @_should_added_exp
+ def console_folder(self) -> str:
+ """终端监听文件根目录,必须是训练模式"""
+ return os.path.join(self.root, self.__exp_name, "console")
+
+ @property
+ @_should_initialized
+ @_should_train_mode
+ @_should_added_exp
+ def error(self) -> str:
+ """终端错误日志打印路径"""
+ return os.path.join(self.root, self.__exp_name, "console", "error.log")
+
+
+swc = SwanlabConfig()
+
+
+if __name__ == "__main__":
+ swc.init("test", "server")
+ print(swc.root)
+ print(swc.output)
+ print(swc.console_folder)
diff --git a/swanlab/log/__init__.py b/swanlab/log/__init__.py
new file mode 100644
index 00000000..adc29100
--- /dev/null
+++ b/swanlab/log/__init__.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2023-12-15 17:33:56
+@File: swanlab/log/__init__.py
+@IDE: vscode
+@Description:
+ 日志记录模块
+"""
+from .log import Swanlog
+
+swanlog = Swanlog("SwanLab")
diff --git a/swanlab/log/console.py b/swanlab/log/console.py
new file mode 100644
index 00000000..99ecc2b2
--- /dev/null
+++ b/swanlab/log/console.py
@@ -0,0 +1,70 @@
+import sys
+import os
+from datetime import datetime
+
+
+class Consoler(sys.stdout.__class__):
+ def __init__(self):
+ super().__init__(sys.stdout.buffer)
+ self.original_stdout = sys.stdout # 保存原始的 sys.stdout
+
+ def init(self, path):
+ # 通过当前日期生成日志文件名
+ self.now = datetime.now().strftime("%Y-%m-%d")
+ self.console_folder = path
+ # path 是否存在
+ if not os.path.exists(path):
+ os.makedirs(path)
+ # 日志文件路径
+ console_path = os.path.join(path, f"{self.now}.log")
+ # 日志文件
+ self.console = open(console_path, "a")
+
+ # 检查当前日期是否和控制台日志文件名一致
+ def _check_file_name(func):
+ """装饰器,判断是否需要根据日期对控制台输出进行分片存储"""
+
+ def wrapper(self, *args, **kwargs):
+ now = datetime.now().strftime("%Y-%m-%d")
+ # 检测now是否和self.now一致
+ if now != self.now:
+ self.now = now
+ if hasattr(self, "console") and not self.console.closed:
+ self.console.close()
+ self.console = open(os.path.join(self.console_folder, self.now + ".log"), "a")
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ @_check_file_name
+ def write(self, message):
+ self.console.write(message)
+ self.console.flush()
+ self.original_stdout.write(message) # 同时写入原始 sys.stdout
+ self.original_stdout.flush()
+
+ @_check_file_name
+ def add(self, message: str):
+ """此接口用于写入额外的信息到日志文件中,但是不会写入到控制台
+
+ Parameters
+ ----------
+ message : str
+ 写入的信息
+ """
+ self.console.write(message)
+ self.console.flush()
+
+
+class SwanConsoler:
+ def __init__(self):
+ self.consoler: Consoler = Consoler()
+ self.add: function = self.consoler.add
+
+ def init(self, path):
+ self.consoler.init(path)
+ sys.stdout = self.consoler
+
+ def reset(self):
+ """重置输出为原本的样子"""
+ sys.stdout = self.consoler.original_stdout
diff --git a/swanlab/log/log.py b/swanlab/log/log.py
new file mode 100644
index 00000000..c388312e
--- /dev/null
+++ b/swanlab/log/log.py
@@ -0,0 +1,206 @@
+import logging
+import logging.config
+import logging.handlers
+from .console import SwanConsoler
+from ..env import swc
+
+
+class Logsys:
+ # 日志系统状态:running / success / error
+ __status = "running"
+
+ def __init__(self):
+ self.__status = "running"
+
+ def setSuccess(self):
+ if self.isRunning:
+ self.__status = "success"
+ else:
+ raise Exception("current status is %s. You can only set success while runnging" % self.__status)
+
+ def setError(self):
+ if self.isRunning:
+ self.__status = "error"
+ else:
+ raise Exception("current status is %s. You can only set success while runnging" % self.__status)
+
+ @property
+ def isSuccess(self) -> bool:
+ return self.__status == "success"
+
+ @property
+ def isError(self) -> bool:
+ return self.__status == "error"
+
+ @property
+ def isRunning(self) -> bool:
+ return self.__status == "running"
+
+
+# 新增的带颜色的格式化类
+class ColoredFormatter(logging.Formatter):
+ def __init__(self, fmt=None, datefmt=None, style="%", handle=None):
+ super().__init__(fmt, datefmt, style)
+ self.__handle = handle
+
+ _color_mapping = {
+ logging.DEBUG: "\033[37m", # White
+ logging.INFO: "\033[32m", # Green
+ logging.WARNING: "\033[33m", # Yellow
+ logging.ERROR: "\033[91m", # Red
+ logging.CRITICAL: "\033[1;31m", # Bold Red
+ }
+
+ def format(self, record):
+ log_message = super().format(record)
+ self.__handle(log_message + "\n") if self.__handle else None
+ color = self._color_mapping.get(record.levelno, "\033[0m") # Default: Reset color
+ reset_color = "\033[0m"
+ # 分割消息,分别处理头尾
+ messages: list = log_message.split(":", 1)
+ target_length = 20
+ message_header = messages[0] + ":" + " " * max(0, target_length - len(messages[0]))
+ return f"{color}{message_header}{reset_color} {messages[1]}"
+
+
+class Swanlog(Logsys):
+ # 日志系统支持的输出等级
+ __levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+ }
+
+ def __init__(self, name=__name__, level="debug"):
+ super()
+ self.logger = logging.getLogger(name)
+ self.logger.setLevel(self._getLevel(level))
+ self.__consoler: SwanConsoler = None
+
+ def init(self, path, level=None, console_level=None, file_level=None):
+ # 初始化的顺序最好别变,下面的一些设置方法没有使用查找式获取处理器,而是直接用索引获取的
+ # 所以 handlers 列表中,第一个是控制台处理器,第二个是日志文件处理器
+
+ # 初始化控制台记录器
+ if self.__consoler is None and swc.isTrain:
+ self.debug("init consoler")
+ self.__consoler = SwanConsoler()
+ self.__consoler.init(swc.console_folder)
+
+ self._create_console_handler()
+ self._create_file_handler(path)
+ if level:
+ self.logger.setLevel(self._getLevel(level))
+ if console_level:
+ self.setConsoleLevel(console_level)
+ if file_level:
+ self.setFileLevel(file_level)
+
+ # 检测日志处理器是否重复注册
+ def _check_init(func):
+ """装饰器,防止多次注册处理器"""
+
+ def wrapper(self, *args, **kwargs):
+ if len(self.logger.handlers) == 2:
+ return self.debug("init more than once")
+ result = func(self, *args, **kwargs)
+ return result
+
+ return wrapper
+
+ # 创建控制台记录器
+ @_check_init
+ def _create_console_handler(self, level="debug"):
+ console_handler = logging.StreamHandler()
+ handle = None if self.__consoler is None else self.__consoler.add
+ # 添加颜色格式化,并在此处设置格式化后的输出流是否可以被其他处理器处理
+ colored_formatter = ColoredFormatter("[%(name)s-%(levelname)s]: %(message)s", handle=handle)
+ console_handler.setFormatter(colored_formatter)
+ console_handler.setLevel(self._getLevel(level))
+ self.logger.addHandler(console_handler)
+
+ # 创建日志文件记录器
+ @_check_init
+ def _create_file_handler(self, log_path, level="debug"):
+ file_handler = logging.FileHandler(log_path)
+ formatter = logging.Formatter("%(name)s %(levelname)s [%(asctime)s] %(message)s")
+ file_handler.setFormatter(formatter)
+ file_handler.setLevel(self._getLevel(level))
+ self.logger.addHandler(file_handler)
+
+ def setOutput(self, log_path=None, level="debug"):
+ """
+ 设置日志文件的存储位置。
+
+ Parameters:
+ log_path (str): 日志文件路径。
+ level (str): 日志级别,可以是 "debug", "info", "warning", "error", 或 "critical".
+ """
+ file_handler = self.logger.handlers[1] # Assuming file handler is the second handler
+ self.logger.removeHandler(file_handler)
+ self._create_file_handler(log_path, level)
+
+ def setConsoleLevel(self, level):
+ """
+ 设置控制台输出的日志级别。
+
+ Parameters:
+ level (str): 日志级别,可以是 "debug", "info", "warning", "error", 或 "critical".
+ """
+ console_handler = self.logger.handlers[0]
+ console_handler.setLevel(self._getLevel(level))
+
+ def setFileLevel(self, level):
+ """
+ 设置写入日志文件的日志级别。
+
+ Parameters:
+ level (str): 日志级别,可以是 "debug", "info", "warning", "error", 或 "critical".
+ """
+ file_handler = self.logger.handlers[1]
+ file_handler.setLevel(self._getLevel(level))
+
+ def setLevel(self, level):
+ """
+ 设置日志级别。
+
+ Parameters:
+ level (str): 日志级别,可以是 "debug", "info", "warning", "error", 或 "critical".
+ """
+ self.logger.setLevel(self._getLevel(level))
+
+ # 获取对应等级的logging对象
+ def _getLevel(self, level):
+ if level.lower() in self.__levels:
+ return self.__levels.get(level.lower())
+ else:
+ raise KeyError(
+ "Invalid log level: %s, level must be one of ['debug', 'info', 'warning', 'error', 'critical']" % level
+ )
+
+ # 发送调试消息
+ def debug(self, message):
+ self.logger.debug(message)
+
+ # 发送通知
+ def info(self, message):
+ self.logger.info(message)
+
+ # 发生警告
+ def warning(self, message):
+ self.logger.warning(message)
+
+ # 发生错误
+ def error(self, message):
+ self.logger.error(message)
+
+ # 致命错误
+ def critical(self, message):
+ self.logger.critical(message)
+
+ def reset_console(self):
+ """重置控制台记录器"""
+ self.__consoler.reset()
+ self.__consoler = None
diff --git a/swanlab/server/__init__.py b/swanlab/server/__init__.py
index 8a1f2ad1..de367740 100644
--- a/swanlab/server/__init__.py
+++ b/swanlab/server/__init__.py
@@ -7,4 +7,12 @@
@Description:
在此处引出swanlab的web服务器框架,名为SwanWeb以及一些神奇的路由配置,以完成在库最外层的函数式调用
"""
+from swanlab.env import swc
+from swanlab.log import swanlog as swl
+
+# 先初始化配置文件和日志对象
+swc.init(swc.getcwd(), "server")
+swl.init(swc.output, level="debug")
+
+# 导出app对象
from .router import app
diff --git a/swanlab/server/api/experiment.py b/swanlab/server/api/experiment.py
index 3859162d..a84329db 100644
--- a/swanlab/server/api/experiment.py
+++ b/swanlab/server/api/experiment.py
@@ -8,11 +8,11 @@
实验相关api,前缀:/experiment
"""
from fastapi import APIRouter
-from ..utils import ResponseBody
-from ...env import SWANLAB_LOGS_FOLDER
-from ...database.project import ProjectTable
+from ..module.resp import SUCCESS_200, NOT_FOUND_404
+from ...env import swc
import os
import ujson
+from ...utils import DEFAULT_COLOR
# from ...utils import create_time
from urllib.parse import unquote # 转码路径参数
@@ -21,8 +21,6 @@
router = APIRouter()
-CONFIG_PATH = ProjectTable.path
-
# ---------------------------------- 工具函数 ----------------------------------
def __find_experiment(experiment_id: int) -> dict:
@@ -38,7 +36,7 @@ def __find_experiment(experiment_id: int) -> dict:
dict
实验信息
"""
- with get_a_lock(CONFIG_PATH, "r") as f:
+ with get_a_lock(swc.project, "r") as f:
experiments: list = ujson.load(f)["experiments"]
for experiment in experiments:
if experiment["experiment_id"] == experiment_id:
@@ -89,7 +87,6 @@ def __list_subdirectories(folder_path: str) -> List[str]:
# ---------------------------------- 业务路由 ----------------------------------
-# 获取当前实验信息
@router.get("/{experiment_id}")
async def get_experiment(experiment_id: int):
"""获取当前实验的信息
@@ -100,22 +97,24 @@ async def get_experiment(experiment_id: int):
实验唯一id,路径传参
"""
# 读取 project.json 文件内容
- with get_a_lock(CONFIG_PATH, "r") as f:
+ with get_a_lock(swc.project, "r") as f:
experiments: list = ujson.load(f)["experiments"]
- f.close()
- # 在experiments列表中查找对应实验的信息
- experiment = None
- for ex in experiments:
- if ex["experiment_id"] == experiment_id:
- experiment = ex
- break
- # 生成实验存储路径
- path = os.path.join(SWANLAB_LOGS_FOLDER, experiment["name"])
- experiment["tags"] = __list_subdirectories(path)
- return ResponseBody(0, data=experiment)
-
-
-# 获取某个实验的表单数据
+ # 在experiments列表中查找对应实验的信息
+ experiment = None
+ for ex in experiments:
+ if ex["experiment_id"] == experiment_id:
+ experiment = ex
+ break
+ # 如果没有找到,即实验不存在
+ if experiment is None:
+ return NOT_FOUND_404()
+ # 生成实验存储路径
+ path = os.path.join(swc.root, experiment["name"], "logs")
+ experiment["tags"] = __list_subdirectories(path)
+ experiment["default_color"] = DEFAULT_COLOR
+ return SUCCESS_200(experiment)
+
+
@router.get("/{experiment_id}/tag/{tag}")
async def get_tag_data(experiment_id: int, tag: str):
"""获取表单数据
@@ -132,12 +131,15 @@ async def get_tag_data(experiment_id: int, tag: str):
# num=None: 返回所有数据, num=10: 返回最新的10条数据, num=-1: 返回最后一条数据
num = None
# 在experiments列表中查找对应实验的信息
- experiment_name = __find_experiment(experiment_id)["name"]
+ try:
+ experiment_name = __find_experiment(experiment_id)["name"]
+ except KeyError as e:
+ return NOT_FOUND_404("experiment not found")
# ---------------------------------- 前置处理 ----------------------------------
# 获取tag对应的存储目录
- tag_path: str = os.path.join(SWANLAB_LOGS_FOLDER, experiment_name, tag)
+ tag_path: str = os.path.join(swc.root, experiment_name, "logs", tag)
if not os.path.exists(tag_path):
- raise KeyError(f'tag "{tag}" not found')
+ return NOT_FOUND_404("tag not found")
# 获取目录下存储的所有数据
# 降序排列,最新的数据在最前面
files: list = os.listdir(tag_path)
@@ -158,7 +160,6 @@ async def get_tag_data(experiment_id: int, tag: str):
# 倒数第二个文件可能不存在
count = files[-2].split(".")[0] if len(files) > 1 else 0
count = int(count) + len(tag_json["data"])
- # print(f"count={count}")
# 读取完毕,文件解锁
# ---------------------------------- tag=-1的情况:返回最后一条数据 ----------------------------------
@@ -184,7 +185,7 @@ async def get_tag_data(experiment_id: int, tag: str):
tag_data.extend(data)
tag_data.extend(tag_json["data"])
# 返回数据
- return ResponseBody(0, data={"sum": len(tag_data), "list": tag_data})
+ return SUCCESS_200(data={"sum": len(tag_data), "list": tag_data})
else:
# TODO 采样读取数据
raise NotImplementedError("采样读取数据")
@@ -200,7 +201,7 @@ async def get_experiment_status(experiment_id: int):
实验唯一id,路径传参
"""
status = __find_experiment(experiment_id)["status"]
- return ResponseBody(0, data={"status": status})
+ return SUCCESS_200(data={"status": status})
@router.get("/{experiment_id}/summary")
@@ -217,7 +218,7 @@ async def get_experiment_summary(experiment_id: int):
array
每个tag的最后一个数据
"""
- experiment_path: str = os.path.join(SWANLAB_LOGS_FOLDER, __find_experiment(experiment_id)["name"])
+ experiment_path: str = os.path.join(swc.root, __find_experiment(experiment_id)["name"], "logs")
tags = [f for f in os.listdir(experiment_path) if os.path.isdir(os.path.join(experiment_path, f))]
summaries = []
for tag in tags:
@@ -226,4 +227,4 @@ async def get_experiment_summary(experiment_id: int):
with get_a_lock(os.path.join(tag_path, logs[-1]), mode="r") as f:
data = ujson.load(f)
summaries.append([tag, data["data"][-1]["data"]])
- return ResponseBody(0, data={"summaries": summaries})
+ return SUCCESS_200(data={"summaries": summaries})
diff --git a/swanlab/server/api/project.py b/swanlab/server/api/project.py
index d22507a8..a09123c2 100644
--- a/swanlab/server/api/project.py
+++ b/swanlab/server/api/project.py
@@ -8,12 +8,8 @@
项目相关的api,前缀:/project
"""
from fastapi import APIRouter
-from ..utils import ResponseBody
-from ...database import PT
-
-# from ...database import
-import os
-import ujson
+from ..module.resp import SUCCESS_200, DATA_ERROR_500
+from ..module import PT
router = APIRouter()
@@ -24,5 +20,8 @@ async def _():
"""
获取项目信息,列出当前项目下的所有实验
"""
- pt = PT()
- return ResponseBody(0, data=pt.get())
+ try:
+ pt = PT()
+ return SUCCESS_200(data=pt.get())
+ except Exception:
+ return DATA_ERROR_500("project data error")
diff --git a/swanlab/server/api/test.py b/swanlab/server/api/test.py
deleted file mode 100644
index 480dbef1..00000000
--- a/swanlab/server/api/test.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-r"""
-@DATE: 2023-11-30 20:47:18
-@File: swanlab\server\api\data.py
-@IDE: vscode
-@Description:
- 本文件用于处理数据相关的请求,包括获取数据,新建图表等
-"""
-import random
-from fastapi import APIRouter
-from ..utils import ResponseBody
-
-router = APIRouter()
-
-
-# 测试路由,每次请求返回一个0到30的随机数
-@router.get("/test")
-async def _():
- # 生成一个 0 到 30 之间的随机整数
- random_number = random.randint(0, 30)
- data = {"data": random_number}
- return ResponseBody(0)
diff --git a/swanlab/server/module/__init__.py b/swanlab/server/module/__init__.py
new file mode 100644
index 00000000..ed941881
--- /dev/null
+++ b/swanlab/server/module/__init__.py
@@ -0,0 +1 @@
+from .models.project import PT
diff --git a/swanlab/server/module/models/experiment.py b/swanlab/server/module/models/experiment.py
new file mode 100644
index 00000000..e69de29b
diff --git a/swanlab/server/module/models/project.py b/swanlab/server/module/models/project.py
new file mode 100644
index 00000000..d1835853
--- /dev/null
+++ b/swanlab/server/module/models/project.py
@@ -0,0 +1,14 @@
+from ....utils import lock_file
+from io import TextIOWrapper
+from ....env import swc
+import os
+import ujson
+
+
+class PT(object):
+ """后端层面上的项目管理类,适配后端的项目管理接口,提供项目管理的相关功能"""
+
+ @lock_file(file_path=swc.project, mode="r")
+ def get(self, file: TextIOWrapper):
+ """获取实验信息"""
+ return ujson.load(file)
diff --git a/swanlab/server/module/models/tag.py b/swanlab/server/module/models/tag.py
new file mode 100644
index 00000000..e69de29b
diff --git a/swanlab/server/module/resp.py b/swanlab/server/module/resp.py
new file mode 100644
index 00000000..c0c8e00f
--- /dev/null
+++ b/swanlab/server/module/resp.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2023-12-13 19:46:18
+@File: swanlab/server/module/resp.py
+@IDE: vscode
+@Description:
+ 定义响应结构体
+ 说是结构体,实际上是各种函数,用于返回各种响应
+ 结构体名称结构为:[错误描述]_[HTTP状态码]
+"""
+from fastapi.responses import JSONResponse as _JSONResponse
+
+# ---------------------------------- 错误码 ----------------------------------
+
+
+_SUCCESS_200 = 0
+"""一切正常,成功,期望的HTTP状态码为200"""
+
+_PARAMS_ERROR_422 = 3422
+"""参数错误,通常对应着前端传输的参数无法通过校验,这在中间件中处理,期望的HTTP状态码为422"""
+
+_NOT_FOUND_404 = 3404
+"""资源不存在,通常对应着路径不存在,期望的HTTP状态码为404"""
+
+_DATA_ERROR_500 = 3500
+"""服务端存储的数据格式错误,这通常意味着指定资源无法解析为期望格式,期望的HTTP状态码为500"""
+
+_UNEXCEPTED_ERROR_500 = 3555
+"""未知错误,通常对应着未知的异常,期望的HTTP状态码为500"""
+
+# ---------------------------------- 定义响应结构体 ----------------------------------
+
+
+def _ResponseBody(code: int, message: str = None, data: dict = None):
+ """构造响应,返回一个字典,包含响应码,响应信息和响应数据
+
+ Parameters
+ ----------
+ code : int
+ 响应码,0表示成功,非0表示失败
+ message : str, optional
+ 错误信息,如果code为0,错误信息强制为success,如果code不为0,错误信息必须提供
+ data : dict, optional
+ 响应数据,如果传入,必须为字典类型
+ """
+ # 如果code为0,错误信息强制为success
+ message = "success" if code == 0 else message
+ # 如果code不为0,错误信息必须提供
+ assert code == 0 or message is not None and len(message) > 0
+ # 如果传入了响应数据,必须为字典类型
+ assert data is None or isinstance(data, dict)
+ # 构造响应
+ if data is None:
+ return {
+ "code": code,
+ "message": message,
+ }
+ else:
+ return {
+ "code": code,
+ "message": message,
+ "data": data,
+ }
+
+
+def SUCCESS_200(data: dict):
+ """成功响应"""
+ return _JSONResponse(
+ status_code=200,
+ content=_ResponseBody(_SUCCESS_200, data=data),
+ )
+
+
+def PARAMS_ERROR_422(message: str = "params error"):
+ """请求参数错误"""
+ return _JSONResponse(
+ status_code=422,
+ content=_ResponseBody(_PARAMS_ERROR_422, message=message),
+ )
+
+
+def NOT_FOUND_404(message: str = "NotFound"):
+ """资源不存在"""
+ return _JSONResponse(
+ status_code=404,
+ content=_ResponseBody(_NOT_FOUND_404, message=message),
+ )
+
+
+def DATA_ERROR_500(message: str = "data error"):
+ """服务端存储的数据格式错误"""
+ return _JSONResponse(
+ status_code=500,
+ content=_ResponseBody(_DATA_ERROR_500, message=message),
+ )
+
+
+def UNEXPECTED_ERROR_500(message: str = "unexpected error"):
+ """未知错误"""
+ return _JSONResponse(
+ status_code=500,
+ content=_ResponseBody(_UNEXCEPTED_ERROR_500, message=message),
+ )
diff --git a/swanlab/server/router.py b/swanlab/server/router.py
index bb65e159..03a31c27 100644
--- a/swanlab/server/router.py
+++ b/swanlab/server/router.py
@@ -8,10 +8,12 @@
综合服务 api
"""
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import time
+from .module.resp import UNEXPECTED_ERROR_500, PARAMS_ERROR_422
+from ..log import swanlog as swl
# 响应路径
from ..env import INDEX, ASSETS
@@ -25,12 +27,21 @@
static = StaticFiles(directory=ASSETS)
app.mount(static_path, static)
+# 将uvicorn的日志输出handler删除
+import logging
+
+# 删除 uvicorn logger
+uvicorn_error = logging.getLogger("uvicorn.error")
+uvicorn_error.disabled = True
+uvicorn_access = logging.getLogger("uvicorn.access")
+uvicorn_access.disabled = True
+
# ---------------------------------- 在此处注册中间件 ----------------------------------
@app.middleware("http")
-async def resp_api(request, call_next):
+async def resp_base(request, call_next):
"""基础中间件,调整响应结果,添加处理时间等信息"""
# 如果请求路径不以'/api'开头,说明并不是后端服务的请求,直接返回
if not request.url.path.startswith("/api"):
@@ -51,7 +62,7 @@ async def resp_api(request, call_next):
@app.middleware("http")
async def resp_static(request, call_next):
- """资源中间件,此时所有与api相关的内容不会传入此中间件"""
+ """资源中间件,此时所有与api相关的内容不会在此中间件中处理"""
if request.url.path.startswith(static_path):
# 如果是请求静态资源,直接返回
return await call_next(request)
@@ -64,16 +75,76 @@ async def resp_static(request, call_next):
return HTMLResponse(content=html_content, status_code=200)
+@app.middleware("http")
+async def catch_error(request: Request, call_next):
+ """异常中间件,捕获异常,重构异常信息"""
+ if not request.url.path.startswith("/api"):
+ # 如果不是请求api,直接返回
+ return await call_next(request)
+ try:
+ return await call_next(request)
+ except Exception as e:
+ return UNEXPECTED_ERROR_500(str(e))
+
+
+@app.middleware("http")
+async def log_print(request: Request, call_next):
+ """日志打印中间件"""
+ swl.debug("[" + request.method + "] from " + request.base_url._url)
+ resp = await call_next(request)
+ # 拿到状态码
+ status = str(resp.status_code)
+ if not request.url.path.startswith("/api"):
+ # 如果不是请求api,直接返回
+ swl.debug("[" + str(resp.status_code) + "] " + request.method + " assets: " + request.url.path)
+ else:
+ content = "[" + str(resp.status_code) + "] " + request.method + " api: " + request.url.path
+ if status.startswith("2"):
+ swl.info(content)
+ else:
+ swl.error(content)
+ return resp
+
+
+@app.middleware("http")
+async def resp_params(request: Request, call_next):
+ """参数中间件,处理api请求中的参数校验问题,重新结构化校验错误结果"""
+ if not request.url.path.startswith("/api"):
+ # 如果不是请求api,直接返回
+ return await call_next(request)
+ # print("请求api")
+ resp = await call_next(request)
+ # 拿到状态码
+ status = resp.status_code
+ if status == 422:
+ # 参数校验错误,重构错误信息
+ # 拿到响应体
+ import json
+
+ body = [chunk async for chunk in resp.body_iterator][0].decode()
+ body = json.loads(body)
+ detail = body["detail"][0]
+ msg = detail["msg"].split(" ")[1:]
+ loc = detail["loc"][1]
+ msg.insert(0, loc)
+ msg = " ".join(msg)
+ return PARAMS_ERROR_422(msg)
+ """
+ 参数校验错误并不会影响其他情况的响应结果
+ 此外由于参数校验错误在绝大多数情况应该是开发时的错误
+ 所以不会影响正式版本的性能
+ """
+ return resp
+
+
# ---------------------------------- 在此处注册相关路由 ----------------------------------
# 导入数据相关的路由
-from .api.test import router as test
from .api.project import router as project
from .api.experiment import router as experiment
# 使用配置列表,统一导入
prefix = "/api/v1"
-app.include_router(test, prefix=prefix)
app.include_router(project, prefix=prefix + "/project")
app.include_router(experiment, prefix=prefix + "/experiment")
diff --git a/swanlab/server/utils.py b/swanlab/server/utils.py
deleted file mode 100644
index 1753c80d..00000000
--- a/swanlab/server/utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-r"""
-@DATE: 2023-12-03 01:25:31
-@File: swanlab\server\api\utils.py
-@IDE: vscode
-@Description:
- 工具文件,重复使用的逻辑的封装
-"""
-from typing import Callable
-from fastapi import Response
-import ujson
-
-
-def ResponseBody(code: int, message: str = None, data: dict = None):
- """构造响应,返回一个字典,包含响应码,响应信息和响应数据
-
- Parameters
- ----------
- code : int
- 响应码,0表示成功,非0表示失败
- message : str, optional
- 错误信息,如果code为0,错误信息强制为success,如果code不为0,错误信息必须提供
- data : dict, optional
- 响应数据,如果传入,必须为字典类型
- """
- # 如果code为0,错误信息强制为success
- message = "success" if code == 0 else message
- # 如果code不为0,错误信息必须提供
- assert code == 0 or message is not None and len(message) > 0
- # 如果传入了响应数据,必须为字典类型
- assert data is None or isinstance(data, dict)
- # 构造响应
- if data is None:
- return {
- "code": code,
- "message": message,
- }
- else:
- return {
- "code": code,
- "message": message,
- "data": data,
- }
-
-
-async def get_response_body(response: Response, callback: Callable[[dict], dict] = None) -> Response:
- """从响应对象中获取响应体,并转换为字典,进行回调处理
-
- Parameters
- ----------
- response : Response
- 响应对象
- callback : Callable[[dict], dict], optional
- 回调函数,由于本函数本身异步,建议传入的回调函数是同步的, by default None
-
- Returns
- -------
- Response
- 新的响应对象
- """
- response_body = b""
- async for chunk in response.body_iterator:
- response_body += chunk
- body = ujson.loads(response_body.decode("utf-8"))
- if callback is not None:
- body = callback(body)
- return Response(
- content=ujson.dumps(body, ensure_ascii=False),
- status_code=response.status_code,
- headers=dict(response.headers),
- media_type=response.media_type,
- )
diff --git a/swanlab/utils/__init__.py b/swanlab/utils/__init__.py
index dac5fe5b..b6bd7984 100644
--- a/swanlab/utils/__init__.py
+++ b/swanlab/utils/__init__.py
@@ -1,3 +1,3 @@
-from .color import generate_color
+from .color import generate_color, DEFAULT_COLOR
from .time import create_time
from .file import lock_file, get_a_lock
diff --git a/swanlab/utils/color.py b/swanlab/utils/color.py
index 4b3ef171..f1d64588 100644
--- a/swanlab/utils/color.py
+++ b/swanlab/utils/color.py
@@ -7,48 +7,10 @@
@Description:
颜色处理工具
"""
-import random
-def rgb_to_hex(rgb_color: tuple):
- """将RGB转为十六进制颜色字符串
-
- Returns
- -------
- str
- 颜色字符串,以#开头的十六进制字符串,如#FFFFFF
- 字符串字母大写
- """
- r, g, b = rgb_color
-
- # 将rgb转为颜色字符串
- hex_color = "#{:02X}{:02X}{:02X}".format(r, g, b)
-
- return hex_color
-
-
-def hex_to_rgb(hex_color: str):
- """将十六进制颜色字符串转为RGB
-
- Returns
- -------
- tuple
- 包含RGB的元组,每个元素都是0-255之间的整数,如(255, 255, 255)
- """
-
- # 去除可能包含的 '#' 符号
- hex_color = hex_color.lstrip("#")
-
- # 将十六进制颜色代码分成红、绿和蓝部分
- r = int(hex_color[0:2], 16)
- g = int(hex_color[2:4], 16)
- b = int(hex_color[4:6], 16)
-
- return (r, g, b)
-
-
-def generate_color() -> str:
- """生成十六进制颜色字符串
+def generate_color(number: int = 0) -> str:
+ """输入数字,在设定好顺序的颜色列表中返回十六进制颜色字符串
Returns
-------
@@ -58,43 +20,52 @@ def generate_color() -> str:
"""
# 生成 RGB 随机变化值
- r_random = random.randint(0, 10)
- g_random = random.randint(0, 10)
- b_random = random.randint(0, 10)
+ # r_random = random.randint(0, 10)
+ # g_random = random.randint(0, 10)
+ # b_random = random.randint(0, 10)
# 生成随机数, 用于在颜色列表中选择一个随机颜色
- random_number = random.randint(0, 15)
+ # random_number = random.randint(0, 15)
color_list = [
- "#528d59",
- "#9cbe5d",
- "#dfb142",
- "#d0703c",
- "#e3b292",
- "#c24d46",
- "#892d58",
- "#d47694",
- "#8cc5b7",
- "#40877c",
- "#6ebad3",
- "#587ad2",
- "#6d4ba4",
- "#b15fbb",
- "#905f4a",
- "#989fa3",
+ "#528d59", # 绿色
+ "#587ad2", # 蓝色
+ "#c24d46", # 红色
+ "#9cbe5d", # 青绿色
+ "#6ebad3", # 天蓝色
+ "#dfb142", # 橙色
+ "#6d4ba4", # 紫色
+ "#8cc5b7", # 淡青绿色
+ "#892d58", # 紫红色
+ "#40877c", # 深青绿色
+ "#d0703c", # 深橙色
+ "#d47694", # 粉红色
+ "#e3b292", # 淡橙色
+ "#b15fbb", # 浅紫红色
+ "#905f4a", # 棕色
+ "#989fa3", # 灰色
]
# 将随机选择的十六进制字符串转为RGB
- r, g, b = hex_to_rgb(color_list[random_number])
+ # r, g, b = hex_to_rgb(color_list[random_number])
+
+ # # 在RGB通道增加随机波动
+ # r = min(r + r_random, 255)
+ # g = min(g + g_random, 255)
+ # b = min(b + b_random, 255)
+
+ if number % 16 == 0:
+ number = 16
+ else:
+ number = number % 16
+
+ return color_list[number - 1]
- # 在RGB通道增加随机波动
- r = min(r + r_random, 255)
- g = min(g + g_random, 255)
- b = min(b + b_random, 255)
- # 将RGB转为十六进制字符串,然后返回
- return rgb_to_hex(rgb_color=(r, g, b))
+# 默认颜色,也就是前端单实验内容显示的颜色
+DEFAULT_COLOR = generate_color(1)
if __name__ == "__main__":
- print(generate_color())
+ print(generate_color(1))
+ print(DEFAULT_COLOR)
diff --git a/test/cil_test.py b/test/cil_test.py
new file mode 100644
index 00000000..695a4dfa
--- /dev/null
+++ b/test/cil_test.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2023-12-17 15:26:00
+@File: test/cil_test.py
+@IDE: vscode
+@Description:
+ 测试cil模块
+"""
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+from swanlab.cli import cli
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/test/create_experiment.py b/test/create_experiment.py
index 2dfafe43..45b39df3 100644
--- a/test/create_experiment.py
+++ b/test/create_experiment.py
@@ -12,12 +12,11 @@
import time
# 迭代次数
-epochs = 2000
+epochs = 200
# 学习率
lr = 0.01
# 随机偏移量
offset = random.random() / 5
-
# 创建一个实验
sw.init(
description="this is a test experiment",
@@ -27,6 +26,8 @@
},
)
+print("start training")
+
# 模拟训练过程
for epoch in range(2, epochs):
acc = 1 - 2**-epoch - random.random() / epoch - offset
@@ -34,3 +35,5 @@
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
sw.log({"loss": loss, "accuracy": acc})
time.sleep(0.1)
+ if epoch % 10 == 0:
+ raise Exception("error")
diff --git a/test/start_server.py b/test/start_server.py
index 6e136d69..2f490a51 100644
--- a/test/start_server.py
+++ b/test/start_server.py
@@ -11,6 +11,10 @@
"""
from swanlab.server.router import app
import uvicorn
+from swanlab.server import swl
+
if __name__ == "__main__":
- uvicorn.run("start_server:app", host="127.0.0.1", port=6092, reload=True)
+ swl.info("start server")
+ uvicorn.run("start_server:app", host="0.0.0.0", port=6092, reload=True, log_level="critical")
+ # swl.info("hello")
diff --git a/test/test_catch_error.py b/test/test_catch_error.py
new file mode 100644
index 00000000..ea4229e4
--- /dev/null
+++ b/test/test_catch_error.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+r"""
+@DATE: 2023-12-16 14:35:04
+@File: test/test_catch_error.py
+@IDE: vscode
+@Description:
+ 测试异常抛出
+"""
+# excepthook.py
+import sys, traceback
+from datetime import datetime
+import atexit
+
+fError = open("except_error.log", "a")
+
+
+def UserExceptHook(tp, val, tb):
+ traceList = traceback.format_tb(tb)
+ html = repr(tp) + "\n"
+ html += repr(val) + "\n"
+ for line in traceList:
+ html += line + "\n"
+ print(html, file=sys.stderr)
+ print(datetime.now(), file=fError)
+ print(html, file=fError)
+ fError.close()
+ import time
+
+ print("执行异常回调函数")
+ time.sleep(10)
+
+
+def close():
+ print("执行程序结束的回调.")
+
+
+def main():
+ sFirst = input("First number:")
+ sSecond = input("Second number:")
+ try:
+ fResult = int(sFirst) / int(sSecond)
+ except Exception:
+ print("发现异常,但我不处理,抛出去.")
+ raise
+ else:
+ print(sFirst, "/", sSecond, "=", fResult)
+
+
+atexit.register(close)
+
+sys.excepthook = UserExceptHook
+main()
+fError.close()
diff --git a/test/test_consoler.py b/test/test_consoler.py
new file mode 100644
index 00000000..c448c12f
--- /dev/null
+++ b/test/test_consoler.py
@@ -0,0 +1,18 @@
+import swanlab as sw
+
+# 迭代次数
+epochs = 200
+# 学习率
+lr = 0.01
+
+# 创建一个实验
+sw.init(
+ description="this is a test experiment",
+ config={
+ "learning_rate": lr,
+ "epochs": epochs,
+ },
+)
+
+print("test myconsoler")
+print("nihao swanlab")
diff --git a/test/test_database_create.py b/test/test_database_create.py
index 827863c7..d952cbec 100644
--- a/test/test_database_create.py
+++ b/test/test_database_create.py
@@ -12,7 +12,7 @@
import time
# 迭代次数
-epochs = 2000
+epochs = 200
# 学习率
lr = 0.01
# 随机偏移量
@@ -27,6 +27,8 @@
},
)
+print("test logs")
+
# 模拟训练过程
for epoch in range(2, epochs):
acc = 1 - 2**-epoch - random.random() / epoch - offset
diff --git a/test/test_get_console.py b/test/test_get_console.py
new file mode 100644
index 00000000..bcdff605
--- /dev/null
+++ b/test/test_get_console.py
@@ -0,0 +1,25 @@
+from io import StringIO
+import sys
+
+
+class ConsoleCapture:
+ def __enter__(self):
+ self.original_stdout = sys.stdout
+ sys.stdout = self._stdout = StringIO()
+ return self
+
+ def __exit__(self, *args):
+ sys.stdout = self.original_stdout
+
+ def get_captured_text(self):
+ return self._stdout.getvalue()
+
+
+# 用法示例
+with ConsoleCapture() as cc:
+ print("Hello, this will be captured.")
+ print("So will this.")
+
+captured_text = cc.get_captured_text()
+print("Captured text:")
+print(captured_text)
diff --git a/test/test_logging/__init__.py b/test/test_logging/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_logging/app.py b/test/test_logging/app.py
new file mode 100644
index 00000000..096f623d
--- /dev/null
+++ b/test/test_logging/app.py
@@ -0,0 +1,33 @@
+from test1 import test1
+from swanlab.log import swanlog as sl
+from swanlab.env import swc
+import swanlab as sw
+
+swc.init(swc.getcwd(), "train")
+
+# 迭代次数
+epochs = 200
+# 学习率
+lr = 0.01
+
+# 创建一个实验
+sw.init(
+ description="this is a test experiment",
+ config={
+ "learning_rate": lr,
+ "epochs": epochs,
+ },
+)
+sl.init("output.log", "debug")
+
+# sl.setLevel("error")
+# sl.setOutput()
+print(sl.isRunning)
+sl.debug("Watch out!")
+sl.info("I told you so")
+sl.warning("I told you so")
+sl.error("I told you so")
+sl.critical("I told you so")
+test1()
+
+# sl.setSuccess()
diff --git a/test/test_logging/test1.py b/test/test_logging/test1.py
new file mode 100644
index 00000000..2875ea80
--- /dev/null
+++ b/test/test_logging/test1.py
@@ -0,0 +1,13 @@
+from swanlab.log import swanlog as sl
+import logging
+
+test_string = "test1"
+
+
+def test1():
+ # sl.setLevel("error")
+ sl.debug(test_string)
+ sl.info(test_string)
+ sl.warning(test_string)
+ sl.error(test_string)
+ sl.critical(test_string)
diff --git a/test/test_logging/test2.py b/test/test_logging/test2.py
new file mode 100644
index 00000000..0eb28b5d
--- /dev/null
+++ b/test/test_logging/test2.py
@@ -0,0 +1,39 @@
+import logging
+import logging.config
+
+LOGGING_CONFIG = {
+ "version": 1,
+ "formatters": {
+ "default": {
+ "format": "%(asctime)s %(filename)s %(lineno)s %(levelname)s %(message)s",
+ },
+ },
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "level": "DEBUG", # 输出所有级别的日志到控制台
+ "formatter": "default",
+ },
+ "file": {
+ "class": "logging.FileHandler",
+ "level": "INFO", # 输出 INFO 及以上级别的日志到文件
+ "filename": "./log.txt",
+ "formatter": "default",
+ },
+ },
+ "loggers": {
+ "root": {
+ "handlers": ["console", "file"],
+ "level": "DEBUG", # 设置根记录器的级别为 DEBUG,以确保所有级别的日志都会传播到根记录器
+ },
+ },
+ "disable_existing_loggers": False, # 不禁用现有的记录器
+}
+
+logging.config.dictConfig(LOGGING_CONFIG)
+logger = logging.getLogger("root")
+logger.debug("debug message")
+logger.info("info message")
+logger.warning("warning message")
+logger.error("error message")
+logger.critical("critical message")
diff --git a/test/test_proxy.py b/test/test_proxy.py
deleted file mode 100644
index 1e5a9aa8..00000000
--- a/test/test_proxy.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import random
-from fastapi.responses import JSONResponse
-from fastapi import FastAPI
-import uvicorn
-
-app = FastAPI()
-
-
-@app.get("/api/test")
-async def root():
- random_number = random.randint(1, 30)
- return JSONResponse({"data": random_number}, status_code=200)
-
-
-if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=10101)
diff --git a/vite.config.js b/vite.config.js
index 9c922c0a..49db574b 100644
--- a/vite.config.js
+++ b/vite.config.js
@@ -49,7 +49,16 @@ export default defineConfig(({ mode }) => {
// 标明编译后存放的位置
build: {
outDir: path.resolve(__dirname, 'swanlab/template'),
- emptyOutDir: true
+ emptyOutDir: true,
+ minify: 'terser',
+ // 根据模式应用不同的 terser 配置
+ terserOptions: {
+ // 生产环境移除console
+ compress: {
+ drop_console: mode === 'release',
+ drop_debugger: mode === 'release'
+ }
+ }
},
// 服务配置
server: {
diff --git a/vue/src/App.vue b/vue/src/App.vue
index 8796fc9b..5b15175f 100644
--- a/vue/src/App.vue
+++ b/vue/src/App.vue
@@ -3,24 +3,52 @@
{{ $t('error.wrap.reason-title') }}
++ {{ $t('error.wrap.message[0]') }} + {{ $t('error.wrap.message[1]') }} + 。 +
+{{ $t('error.wrap.error-code') + error_code }}
+{{ $t('error.wrap.time', { time }) }}
+{{ $t('error.404.message') }}
+ +