diff --git a/pyproject.toml b/pyproject.toml index 99a8efba..c3f6b4ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ classifiers = [ # 项目分类 swanlab = "swanlab.cli:cli" # 项目命令行工具(可添加多个) [project.urls] # 项目链接 -"Website" = "https://swanhub.co" +"Homepage" = "https://swanhub.co" "Source" = "https://github.com/SwanHubX/SwanLab" "Bug Reports" = "https://github.com/SwanHubX/SwanLab/issues" "Documentation" = "https://geektechstudio.feishu.cn/wiki/space/7310593325374013444?ccm_open_type=lark_wiki_spaceLink&open_tab_from=wiki_home" diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 863fee0d..81aab9f7 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -10,7 +10,14 @@ from ..settings import SwanDataSettings, get_runtime_project from ...log import register, swanlog from ..system import get_system_info -from .utils import get_a_lock, check_name_format, get_package_version, create_time, generate_color +from .utils import ( + get_a_lock, + check_exp_name_format, + check_desc_format, + get_package_version, + create_time, + generate_color, +) from datetime import datetime import sys, os import random @@ -150,7 +157,7 @@ def __get_exp_name(self, experiment_name: str = None, project_path: str = None, """ max_len = 20 cut = experiment_name is not None and len(experiment_name) > max_len - experiment_name = "exp" if experiment_name is None else check_name_format(experiment_name) + experiment_name = "exp" if experiment_name is None else check_exp_name_format(experiment_name) # 为实验名称添加后缀,格式为yyyy-mm-dd_HH-MM-SS if suffix is not None and suffix.lower() != "timestamp": suffix = "timestamp" @@ -226,15 +233,14 @@ def __check_log_level(self, log_level: str) -> str: else: return "info" - def __check_description(self, description: str, max_len: int = 120) -> str: + def __check_description(self, description: str) -> str: """检查实验描述是否合法""" if description is None: return "" - if not isinstance(description, str): - raise TypeError(f"description: {description} is not a string") - if len(description) > max_len: - swanlog.warning(f"The description you provided is too long, it has been truncated to {max_len} characters.") - return description[:max_len] + desc = check_desc_format(description) + if desc != description: + swanlog.warning("The description has been truncated automatically.") + return desc def __check_config(self, config: dict) -> dict: """检查实验配置是否合法""" diff --git a/swanlab/data/run/utils.py b/swanlab/data/run/utils.py index c1df2848..0755c5c7 100644 --- a/swanlab/data/run/utils.py +++ b/swanlab/data/run/utils.py @@ -7,5 +7,5 @@ @Description: 运行时工具函数 """ -from ...utils.file import check_key_format, get_a_lock, check_name_format +from ...utils.file import check_key_format, get_a_lock, check_exp_name_format, check_desc_format from ...utils import get_package_version, create_time, generate_color diff --git a/swanlab/server/api/experiment.py b/swanlab/server/api/experiment.py index af5084de..35229fc5 100644 --- a/swanlab/server/api/experiment.py +++ b/swanlab/server/api/experiment.py @@ -10,6 +10,8 @@ from datetime import datetime import shutil from fastapi import APIRouter, Request + +from ...utils.file import check_exp_name_format, check_desc_format from ..module.resp import SUCCESS_200, NOT_FOUND_404, PARAMS_ERROR_422, Conflict_409 import os import ujson @@ -353,7 +355,7 @@ async def stop_experiment(experiment_id: int): return SUCCESS_200({"update_time": create_time()}) -@router.patch("/{experiment_id}/update") +@router.patch("/{experiment_id}") async def update_experiment_config(experiment_id: int, request: Request): """修改实验的元信息 @@ -372,11 +374,16 @@ async def update_experiment_config(experiment_id: int, request: Request): object """ body: dict = await request.json() + # 校验参数 + check_exp_name_format(body["name"], False) + body["description"] = check_desc_format(body["description"], False) + with open(PROJECT_PATH, "r") as f: project = ujson.load(f) experiment = __find_experiment(experiment_id) # 寻找实验在列表中对应的 index experiment_index = project["experiments"].index(experiment) + # 修改实验名称 if not experiment["name"] == body["name"]: # 修改实验名称 @@ -390,15 +397,17 @@ async def update_experiment_config(experiment_id: int, request: Request): old_path = os.path.join(SWANLOG_DIR, experiment["name"]) new_path = os.path.join(SWANLOG_DIR, body["name"]) os.rename(old_path, new_path) + # 修改实验描述 if not experiment["description"] == body["description"]: project["experiments"][experiment_index]["description"] = body["description"] with get_a_lock(PROJECT_PATH, "w") as f: ujson.dump(project, f, indent=4, ensure_ascii=False) + return SUCCESS_200({"experiment": project["experiments"][experiment_index]}) -@router.delete("/{experiment_id}/delete") +@router.delete("/{experiment_id}") async def delete_experiment(experiment_id: int): """删除实验 diff --git a/swanlab/server/api/project.py b/swanlab/server/api/project.py index 2dfc6282..60f65143 100644 --- a/swanlab/server/api/project.py +++ b/swanlab/server/api/project.py @@ -12,6 +12,7 @@ from fastapi import APIRouter, Request from ...utils import get_a_lock, create_time +from ...utils.file import check_desc_format from ..module.resp import SUCCESS_200, DATA_ERROR_500, Conflict_409 from ..module import PT from swanlab.env import get_swanlog_dir @@ -89,6 +90,9 @@ async def update(request: Request): 返回 project.json 的所有内容,目的是方便前端在修改信息后重置 pinia 的状态 """ body = await request.json() + # 检查格式 + body["description"] = check_desc_format(body["description"], False) + with open(PROJECT_PATH, "r") as f: project = ujson.load(f) # 检查名字 diff --git a/swanlab/utils/file.py b/swanlab/utils/file.py index 66d15964..5ebd2490 100644 --- a/swanlab/utils/file.py +++ b/swanlab/utils/file.py @@ -78,30 +78,121 @@ def check_key_format(key: str) -> str: return key -def check_name_format(name: str, max_len: int = 20) -> str: - """检查name字符串格式,必须是0-9a-zA-Z _-和/或者中文字符组成的字符串,并且开头必须是0-9a-zA-Z或者中文字符 - 最大长度为max_len个字符,一个中文字符算一个字符,如果超出长度,将被截断 +def check_exp_name_format(name: str, auto_cut: bool = True) -> str: + """检查实验名格式,必须是0-9a-zA-Z和连字符(_-),并且不能以连字符(_-)开头或结尾 + 最大长度为100个字符,一个中文字符算一个字符 Parameters ---------- name : str 待检查的字符串 + auto_cut : bool, optional + 如果超出长度,是否自动截断,默认为True + 如果为False,则超出长度会抛出异常 Returns ------- str 检查后的字符串 + + Raises + ------ + TypeError + name不是字符串,或者name为空字符串 + ValueError + name不符合规定格式 + IndexError + name超出长度 + """ + max_len = 100 + if not isinstance(name, str) or name == "": + raise TypeError(f"name: {name} is not a string") + # 定义正则表达式 + pattern = re.compile(r"^[0-9a-zA-Z][0-9a-zA-Z_-]*[0-9a-zA-Z]$") + # 检查 name 是否符合规定格式 + if not pattern.match(name): + raise ValueError( + f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters" + ) + # 检查长度 + if auto_cut and len(name) > max_len: + name = name[:max_len] + elif not auto_cut and len(name) > max_len: + raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters") + return name + + +def check_desc_format(description: str, auto_cut: bool = True): + """检查实验描述 + 不能超过255个字符,可以包含任何字符 + + Parameters + ---------- + description : str + 需要检查和处理的描述信息 + auto_cut : bool + 如果超出长度,是否裁剪并抛弃多余部分 + + Returns + ------- + str + 检查后的字符串,同时会去除字符串头尾的空格 + + Raises + ------ + IndexError + name超出长度 + """ + max_length = 255 + description = description.strip() + + if len(description) > max_length: + if auto_cut: + return description[:max_length] + else: + raise IndexError(f"description too long that exceeds {max_length} characters.") + return description + + +def check_proj_name_format(name: str, auto_cut: bool = True) -> str: + """检查项目名格式,必须是0-9a-zA-Z和中文以及连字符(_-),并且不能以连字符(_-)开头或结尾 + 最大长度为100个字符,一个中文字符算一个字符 + + Parameters + ---------- + name : str + 待检查的字符串 + auto_cut : bool, optional + 如果超出长度,是否自动截断,默认为True + 如果为False,则超出长度会抛出异常 + + Returns + ------- + str + 检查后的字符串 + + Raises + ------ + TypeError + name不是字符串,或者name为空字符串 + ValueError + name不符合规定格式 + IndexError + name超出长度 """ - if not isinstance(name, str): + max_len = 100 + if not isinstance(name, str) or name == "": raise TypeError(f"name: {name} is not a string") # 定义正则表达式 - pattern = re.compile("^[0-9a-zA-Z\u4e00-\u9fa5][0-9a-zA-Z\u4e00-\u9fa5_/-]*$") + pattern = re.compile(r"^[0-9a-zA-Z\u4e00-\u9fa5]+[0-9a-zA-Z\u4e00-\u9fa5_-]*[0-9a-zA-Z\u4e00-\u9fa5]$") # 检查 name 是否符合规定格式 if not pattern.match(name): raise ValueError( f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters" ) # 检查长度 - if len(name) > max_len: + if auto_cut and len(name) > max_len: name = name[:max_len] + elif not auto_cut and len(name) > max_len: + raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters") return name diff --git a/test/create_experiment.py b/test/create_experiment.py index a68f8827..6feecf3d 100644 --- a/test/create_experiment.py +++ b/test/create_experiment.py @@ -36,7 +36,7 @@ def get_chart_type(self) -> str: offset = random.random() / 5 # 创建一个实验 sw.init( - description="this is a test experiment", + description=" this is a test experiment", config={ "learning_rate": lr, "epochs": epochs, diff --git a/vue/src/components/config-editor/EditorWrap.vue b/vue/src/components/config-editor/EditorWrap.vue index 593fd015..fe12f2d2 100644 --- a/vue/src/components/config-editor/EditorWrap.vue +++ b/vue/src/components/config-editor/EditorWrap.vue @@ -10,11 +10,19 @@ class="input" v-model="info.name" :placeholder="`edit your ${type} name here`" + :maxlength="max_name_len" pattern="[a-zA-Z0-9_\-\u4e00-\u9fa5]*" required /> + +
+ {{ `${info.name.length} / ${max_name_len}` }} +
- {{ errors.name }} + {{ errors.name }}+ {{ `${info.description ? info.description.length : 0} / ${max_description_len}` }} +
- {{ errors.description }} + {{ errors.description }}{{ experimentStore.description }}
{{ $t('experiment.index.header.stop.modal.text') }}
-