diff --git a/pyproject.toml b/pyproject.toml index 99a8efba..c3f6b4ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ classifiers = [ # 项目分类 swanlab = "swanlab.cli:cli" # 项目命令行工具(可添加多个) [project.urls] # 项目链接 -"Website" = "https://swanhub.co" +"Homepage" = "https://swanhub.co" "Source" = "https://github.com/SwanHubX/SwanLab" "Bug Reports" = "https://github.com/SwanHubX/SwanLab/issues" "Documentation" = "https://geektechstudio.feishu.cn/wiki/space/7310593325374013444?ccm_open_type=lark_wiki_spaceLink&open_tab_from=wiki_home" diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 863fee0d..81aab9f7 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -10,7 +10,14 @@ from ..settings import SwanDataSettings, get_runtime_project from ...log import register, swanlog from ..system import get_system_info -from .utils import get_a_lock, check_name_format, get_package_version, create_time, generate_color +from .utils import ( + get_a_lock, + check_exp_name_format, + check_desc_format, + get_package_version, + create_time, + generate_color, +) from datetime import datetime import sys, os import random @@ -150,7 +157,7 @@ def __get_exp_name(self, experiment_name: str = None, project_path: str = None, """ max_len = 20 cut = experiment_name is not None and len(experiment_name) > max_len - experiment_name = "exp" if experiment_name is None else check_name_format(experiment_name) + experiment_name = "exp" if experiment_name is None else check_exp_name_format(experiment_name) # 为实验名称添加后缀,格式为yyyy-mm-dd_HH-MM-SS if suffix is not None and suffix.lower() != "timestamp": suffix = "timestamp" @@ -226,15 +233,14 @@ def __check_log_level(self, log_level: str) -> str: else: return "info" - def __check_description(self, description: str, max_len: int = 120) -> str: + def __check_description(self, description: str) -> str: """检查实验描述是否合法""" if description is None: return "" - if not isinstance(description, str): - raise TypeError(f"description: {description} is not a string") - if len(description) > max_len: - swanlog.warning(f"The description you provided is too long, it has been truncated to {max_len} characters.") - return description[:max_len] + desc = check_desc_format(description) + if desc != description: + swanlog.warning("The description has been truncated automatically.") + return desc def __check_config(self, config: dict) -> dict: """检查实验配置是否合法""" diff --git a/swanlab/data/run/utils.py b/swanlab/data/run/utils.py index c1df2848..0755c5c7 100644 --- a/swanlab/data/run/utils.py +++ b/swanlab/data/run/utils.py @@ -7,5 +7,5 @@ @Description: 运行时工具函数 """ -from ...utils.file import check_key_format, get_a_lock, check_name_format +from ...utils.file import check_key_format, get_a_lock, check_exp_name_format, check_desc_format from ...utils import get_package_version, create_time, generate_color diff --git a/swanlab/server/api/experiment.py b/swanlab/server/api/experiment.py index af5084de..35229fc5 100644 --- a/swanlab/server/api/experiment.py +++ b/swanlab/server/api/experiment.py @@ -10,6 +10,8 @@ from datetime import datetime import shutil from fastapi import APIRouter, Request + +from ...utils.file import check_exp_name_format, check_desc_format from ..module.resp import SUCCESS_200, NOT_FOUND_404, PARAMS_ERROR_422, Conflict_409 import os import ujson @@ -353,7 +355,7 @@ async def stop_experiment(experiment_id: int): return SUCCESS_200({"update_time": create_time()}) -@router.patch("/{experiment_id}/update") +@router.patch("/{experiment_id}") async def update_experiment_config(experiment_id: int, request: Request): """修改实验的元信息 @@ -372,11 +374,16 @@ async def update_experiment_config(experiment_id: int, request: Request): object """ body: dict = await request.json() + # 校验参数 + check_exp_name_format(body["name"], False) + body["description"] = check_desc_format(body["description"], False) + with open(PROJECT_PATH, "r") as f: project = ujson.load(f) experiment = __find_experiment(experiment_id) # 寻找实验在列表中对应的 index experiment_index = project["experiments"].index(experiment) + # 修改实验名称 if not experiment["name"] == body["name"]: # 修改实验名称 @@ -390,15 +397,17 @@ async def update_experiment_config(experiment_id: int, request: Request): old_path = os.path.join(SWANLOG_DIR, experiment["name"]) new_path = os.path.join(SWANLOG_DIR, body["name"]) os.rename(old_path, new_path) + # 修改实验描述 if not experiment["description"] == body["description"]: project["experiments"][experiment_index]["description"] = body["description"] with get_a_lock(PROJECT_PATH, "w") as f: ujson.dump(project, f, indent=4, ensure_ascii=False) + return SUCCESS_200({"experiment": project["experiments"][experiment_index]}) -@router.delete("/{experiment_id}/delete") +@router.delete("/{experiment_id}") async def delete_experiment(experiment_id: int): """删除实验 diff --git a/swanlab/server/api/project.py b/swanlab/server/api/project.py index 2dfc6282..60f65143 100644 --- a/swanlab/server/api/project.py +++ b/swanlab/server/api/project.py @@ -12,6 +12,7 @@ from fastapi import APIRouter, Request from ...utils import get_a_lock, create_time +from ...utils.file import check_desc_format from ..module.resp import SUCCESS_200, DATA_ERROR_500, Conflict_409 from ..module import PT from swanlab.env import get_swanlog_dir @@ -89,6 +90,9 @@ async def update(request: Request): 返回 project.json 的所有内容,目的是方便前端在修改信息后重置 pinia 的状态 """ body = await request.json() + # 检查格式 + body["description"] = check_desc_format(body["description"], False) + with open(PROJECT_PATH, "r") as f: project = ujson.load(f) # 检查名字 diff --git a/swanlab/utils/file.py b/swanlab/utils/file.py index 66d15964..5ebd2490 100644 --- a/swanlab/utils/file.py +++ b/swanlab/utils/file.py @@ -78,30 +78,121 @@ def check_key_format(key: str) -> str: return key -def check_name_format(name: str, max_len: int = 20) -> str: - """检查name字符串格式,必须是0-9a-zA-Z _-和/或者中文字符组成的字符串,并且开头必须是0-9a-zA-Z或者中文字符 - 最大长度为max_len个字符,一个中文字符算一个字符,如果超出长度,将被截断 +def check_exp_name_format(name: str, auto_cut: bool = True) -> str: + """检查实验名格式,必须是0-9a-zA-Z和连字符(_-),并且不能以连字符(_-)开头或结尾 + 最大长度为100个字符,一个中文字符算一个字符 Parameters ---------- name : str 待检查的字符串 + auto_cut : bool, optional + 如果超出长度,是否自动截断,默认为True + 如果为False,则超出长度会抛出异常 Returns ------- str 检查后的字符串 + + Raises + ------ + TypeError + name不是字符串,或者name为空字符串 + ValueError + name不符合规定格式 + IndexError + name超出长度 + """ + max_len = 100 + if not isinstance(name, str) or name == "": + raise TypeError(f"name: {name} is not a string") + # 定义正则表达式 + pattern = re.compile(r"^[0-9a-zA-Z][0-9a-zA-Z_-]*[0-9a-zA-Z]$") + # 检查 name 是否符合规定格式 + if not pattern.match(name): + raise ValueError( + f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters" + ) + # 检查长度 + if auto_cut and len(name) > max_len: + name = name[:max_len] + elif not auto_cut and len(name) > max_len: + raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters") + return name + + +def check_desc_format(description: str, auto_cut: bool = True): + """检查实验描述 + 不能超过255个字符,可以包含任何字符 + + Parameters + ---------- + description : str + 需要检查和处理的描述信息 + auto_cut : bool + 如果超出长度,是否裁剪并抛弃多余部分 + + Returns + ------- + str + 检查后的字符串,同时会去除字符串头尾的空格 + + Raises + ------ + IndexError + name超出长度 + """ + max_length = 255 + description = description.strip() + + if len(description) > max_length: + if auto_cut: + return description[:max_length] + else: + raise IndexError(f"description too long that exceeds {max_length} characters.") + return description + + +def check_proj_name_format(name: str, auto_cut: bool = True) -> str: + """检查项目名格式,必须是0-9a-zA-Z和中文以及连字符(_-),并且不能以连字符(_-)开头或结尾 + 最大长度为100个字符,一个中文字符算一个字符 + + Parameters + ---------- + name : str + 待检查的字符串 + auto_cut : bool, optional + 如果超出长度,是否自动截断,默认为True + 如果为False,则超出长度会抛出异常 + + Returns + ------- + str + 检查后的字符串 + + Raises + ------ + TypeError + name不是字符串,或者name为空字符串 + ValueError + name不符合规定格式 + IndexError + name超出长度 """ - if not isinstance(name, str): + max_len = 100 + if not isinstance(name, str) or name == "": raise TypeError(f"name: {name} is not a string") # 定义正则表达式 - pattern = re.compile("^[0-9a-zA-Z\u4e00-\u9fa5][0-9a-zA-Z\u4e00-\u9fa5_/-]*$") + pattern = re.compile(r"^[0-9a-zA-Z\u4e00-\u9fa5]+[0-9a-zA-Z\u4e00-\u9fa5_-]*[0-9a-zA-Z\u4e00-\u9fa5]$") # 检查 name 是否符合规定格式 if not pattern.match(name): raise ValueError( f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters" ) # 检查长度 - if len(name) > max_len: + if auto_cut and len(name) > max_len: name = name[:max_len] + elif not auto_cut and len(name) > max_len: + raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters") return name diff --git a/test/create_experiment.py b/test/create_experiment.py index a68f8827..6feecf3d 100644 --- a/test/create_experiment.py +++ b/test/create_experiment.py @@ -36,7 +36,7 @@ def get_chart_type(self) -> str: offset = random.random() / 5 # 创建一个实验 sw.init( - description="this is a test experiment", + description=" this is a test experiment", config={ "learning_rate": lr, "epochs": epochs, diff --git a/vue/src/components/config-editor/EditorWrap.vue b/vue/src/components/config-editor/EditorWrap.vue index 593fd015..fe12f2d2 100644 --- a/vue/src/components/config-editor/EditorWrap.vue +++ b/vue/src/components/config-editor/EditorWrap.vue @@ -10,11 +10,19 @@ class="input" v-model="info.name" :placeholder="`edit your ${type} name here`" + :maxlength="max_name_len" pattern="[a-zA-Z0-9_\-\u4e00-\u9fa5]*" required /> + +

+ {{ `${info.name.length} / ${max_name_len}` }} +

- {{ errors.name }} + {{ errors.name }}
@@ -24,9 +32,17 @@ rows="10" v-model="info.description" :placeholder="`edit your ${type} description here`" + :maxlength="max_description_len" > + +

+ {{ `${info.description ? info.description.length : 0} / ${max_description_len}` }} +

- {{ errors.description }} + {{ errors.description }}
@@ -55,6 +71,7 @@ import { ref } from 'vue' import SLLoading from '../SLLoading.vue' import { message } from '@swanlab-vue/components/message' import { t } from '@swanlab-vue/i18n' +import { computed } from 'vue' const projectStore = useProjectStore() const experimentStore = useExperimentStroe() @@ -79,6 +96,66 @@ const errors = ref({ description: '' }) +// ---------------------------------- 参数限制 ---------------------------------- + +const max_name_len = 100 +const max_description_len = 255 + +/** + * 校验项目名称: + * 项目名不能超过100个字符, 只能包含字母、中文、数字、连字符(_和-),不能以连字符开头或结尾 + */ +const checkProjectName = computed(() => { + const name = info.value.name + // 判断字符串长度是否超过100个字符 + if (name.length > max_name_len) return false + + // 判断是否以连字符开头或结尾 + if (name.startsWith('-') || name.endsWith('-') || name.startsWith('_') || name.endsWith('_')) { + return false + } + + // 判断是否包含除字母、中文、数字、连字符(_和-)之外的字符 + const pattern = /^[a-zA-Z0-9_\-\u4e00-\u9fa5]+$/ + return pattern.test(name) +}) + +/** + * 校验实验名称 + * 实验名不能超过100字符,且只能包含字母、数字、连字符(_和-),不能以连字符开头或结尾 + */ +const checkExperimentName = computed(() => { + const name = info.value.name + // 判断字符串长度是否超过100个字符 + if (name.length > max_name_len) { + return false + } + + // 判断是否以连字符或下划线开头或结尾 + if (/^[_-]|[_-]$/.test(name)) { + return false + } + + // 判断是否包含除字母、数字、连字符(_和-)之外的字符 + if (/[^a-zA-Z0-9_-]/.test(name)) { + return false + } + + return true +}) + +/** + * 描述校验 + * 描述不能超过255个字符,可以包含任何字符,前后空格自动去除 + */ +const checkDescription = computed(() => { + if (!info.value.description) return true + if (info.value.description?.length > max_description_len) { + return false + } + return true +}) + // ---------------------------------- 重新设置 ---------------------------------- // 是否在处理中 @@ -93,11 +170,26 @@ const handling = ref(false) * 4. 触发确认函数并显示处理中 */ const save = async () => { + // 清空错误信息 errors.value = { name: '', description: '' } + // 检查格式要求 + if (props.type === 'project' && !checkProjectName.value) { + message.warning('invalid project name') + return (errors.value.name = 'too long or invalid characters') + } else if (!checkExperimentName.value) { + message.warning('invalid experiment name') + return (errors.value.name = 'too long or invalid characters') + } + if (!checkDescription.value) { + message.warning('invalid description') + return (errors.value.description = 'invalid description') + } + + // 判断是否一点没变 if (info.value.name === projectStore.name && info.value.description === projectStore.description) { return (errors.value.name = 'nothing changed in project config') } else if (info.value.name === experimentStore.name && info.value.description === experimentStore.description) { @@ -130,4 +222,8 @@ const save = async () => { .input { @apply w-full p-2 text-sm outline-none border rounded-lg bg-higher; } + +.tip { + @apply absolute bottom-[-20px] left-0 text-xs; +} diff --git a/vue/src/i18n/en-US/common.json b/vue/src/i18n/en-US/common.json index b77af72d..cd1b828c 100644 --- a/vue/src/i18n/en-US/common.json +++ b/vue/src/i18n/en-US/common.json @@ -7,8 +7,8 @@ "button": "Edit", "not-allowed": "Not Allowed Since Is Running", "title": { - "project": "Modify Project Information", - "experiment": "Modify Experiment Information" + "project": "Project Information", + "experiment": "Experiment Information" }, "sub-title": { "project": { diff --git a/vue/src/i18n/en-US/experiment.json b/vue/src/i18n/en-US/experiment.json index f883f5e8..6ee8f64f 100644 --- a/vue/src/i18n/en-US/experiment.json +++ b/vue/src/i18n/en-US/experiment.json @@ -6,7 +6,7 @@ }, "status": { "-1": "Crashed", - "0": "In Progress", + "0": "Running", "1": "Completed" }, "index": { diff --git a/vue/src/i18n/zh-CN/experiment.json b/vue/src/i18n/zh-CN/experiment.json index db302f1d..c3135588 100644 --- a/vue/src/i18n/zh-CN/experiment.json +++ b/vue/src/i18n/zh-CN/experiment.json @@ -27,6 +27,13 @@ "hostname": "Hostname", "os": "操作系统", "python": "Python版本" + }, + "stop": { + "button": "Confirm", + "modal": { + "title": "Attention!", + "text": "Should the experiment be stopped as it is irreversible and may result in unpredictable consequences?" + } } }, "config": { diff --git a/vue/src/views/experiment/pages/index/components/ExperimentHeader.vue b/vue/src/views/experiment/pages/index/components/ExperimentHeader.vue index d75f7f06..06c2abcc 100644 --- a/vue/src/views/experiment/pages/index/components/ExperimentHeader.vue +++ b/vue/src/views/experiment/pages/index/components/ExperimentHeader.vue @@ -14,8 +14,8 @@
-
- {{ experimentStore.description }} +
+

{{ experimentStore.description }}

@@ -26,7 +26,7 @@
{{ $t(`experiment.index.header.experiment_infos.status`) }}
- +
@@ -64,7 +64,7 @@ import ConfigEditor from '@swanlab-vue/components/config-editor/ConfigEditor.vue import { useRouter } from 'vue-router' import { inject } from 'vue' import { message } from '@swanlab-vue/components/message' -// import StopButton from './StopButton.vue' +import StopButton from './StopButton.vue' const experimentStore = useExperimentStroe() const experiment = ref(experimentStore.experiment) @@ -183,7 +183,7 @@ const duration = computed(() => { const modifyExperiment = async (newV, hideModal) => { const id = experimentStore.id - const { data } = await http.patch(`/experiment/${id}/update`, newV) + const { data } = await http.patch(`/experiment/${id}`, newV) experimentStore.setExperiment(data.experiment) projectStore.setExperimentInfo(id, newV) hideModal() @@ -196,7 +196,7 @@ const modifyExperiment = async (newV, hideModal) => { */ const deleteExperiment = () => { http - .delete(`/experiment/${experimentStore.id}/delete`) + .delete(`/experiment/${experimentStore.id}`) .then(({ data }) => { projectStore.setProject(data.project) router.replace('/').then(() => { diff --git a/vue/src/views/experiment/pages/index/components/StopButton.vue b/vue/src/views/experiment/pages/index/components/StopButton.vue index 31c30663..2ed8f225 100644 --- a/vue/src/views/experiment/pages/index/components/StopButton.vue +++ b/vue/src/views/experiment/pages/index/components/StopButton.vue @@ -1,24 +1,16 @@