Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat name limit #158

Merged
merged 12 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ classifiers = [ # 项目分类
swanlab = "swanlab.cli:cli" # 项目命令行工具(可添加多个)

[project.urls] # 项目链接
"Website" = "https://swanhub.co"
"Homepage" = "https://swanhub.co"
"Source" = "https://github.com/SwanHubX/SwanLab"
"Bug Reports" = "https://github.com/SwanHubX/SwanLab/issues"
"Documentation" = "https://geektechstudio.feishu.cn/wiki/space/7310593325374013444?ccm_open_type=lark_wiki_spaceLink&open_tab_from=wiki_home"
Expand Down
22 changes: 14 additions & 8 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from ..settings import SwanDataSettings, get_runtime_project
from ...log import register, swanlog
from ..system import get_system_info
from .utils import get_a_lock, check_name_format, get_package_version, create_time, generate_color
from .utils import (
get_a_lock,
check_exp_name_format,
check_desc_format,
get_package_version,
create_time,
generate_color,
)
from datetime import datetime
import sys, os
import random
Expand Down Expand Up @@ -150,7 +157,7 @@ def __get_exp_name(self, experiment_name: str = None, project_path: str = None,
"""
max_len = 20
cut = experiment_name is not None and len(experiment_name) > max_len
experiment_name = "exp" if experiment_name is None else check_name_format(experiment_name)
experiment_name = "exp" if experiment_name is None else check_exp_name_format(experiment_name)
# 为实验名称添加后缀,格式为yyyy-mm-dd_HH-MM-SS
if suffix is not None and suffix.lower() != "timestamp":
suffix = "timestamp"
Expand Down Expand Up @@ -226,15 +233,14 @@ def __check_log_level(self, log_level: str) -> str:
else:
return "info"

def __check_description(self, description: str, max_len: int = 120) -> str:
def __check_description(self, description: str) -> str:
"""检查实验描述是否合法"""
if description is None:
return ""
if not isinstance(description, str):
raise TypeError(f"description: {description} is not a string")
if len(description) > max_len:
swanlog.warning(f"The description you provided is too long, it has been truncated to {max_len} characters.")
return description[:max_len]
desc = check_desc_format(description)
if desc != description:
swanlog.warning("The description has been truncated automatically.")
return desc

def __check_config(self, config: dict) -> dict:
"""检查实验配置是否合法"""
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/run/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
@Description:
运行时工具函数
"""
from ...utils.file import check_key_format, get_a_lock, check_name_format
from ...utils.file import check_key_format, get_a_lock, check_exp_name_format, check_desc_format
from ...utils import get_package_version, create_time, generate_color
13 changes: 11 additions & 2 deletions swanlab/server/api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from datetime import datetime
import shutil
from fastapi import APIRouter, Request

from ...utils.file import check_exp_name_format, check_desc_format
from ..module.resp import SUCCESS_200, NOT_FOUND_404, PARAMS_ERROR_422, Conflict_409
import os
import ujson
Expand Down Expand Up @@ -353,7 +355,7 @@ async def stop_experiment(experiment_id: int):
return SUCCESS_200({"update_time": create_time()})


@router.patch("/{experiment_id}/update")
@router.patch("/{experiment_id}")
async def update_experiment_config(experiment_id: int, request: Request):
"""修改实验的元信息

Expand All @@ -372,11 +374,16 @@ async def update_experiment_config(experiment_id: int, request: Request):
object
"""
body: dict = await request.json()
# 校验参数
check_exp_name_format(body["name"], False)
SAKURA-CAT marked this conversation as resolved.
Show resolved Hide resolved
body["description"] = check_desc_format(body["description"], False)

with open(PROJECT_PATH, "r") as f:
project = ujson.load(f)
experiment = __find_experiment(experiment_id)
# 寻找实验在列表中对应的 index
experiment_index = project["experiments"].index(experiment)

# 修改实验名称
if not experiment["name"] == body["name"]:
# 修改实验名称
Expand All @@ -390,15 +397,17 @@ async def update_experiment_config(experiment_id: int, request: Request):
old_path = os.path.join(SWANLOG_DIR, experiment["name"])
new_path = os.path.join(SWANLOG_DIR, body["name"])
os.rename(old_path, new_path)

# 修改实验描述
if not experiment["description"] == body["description"]:
project["experiments"][experiment_index]["description"] = body["description"]
with get_a_lock(PROJECT_PATH, "w") as f:
ujson.dump(project, f, indent=4, ensure_ascii=False)

return SUCCESS_200({"experiment": project["experiments"][experiment_index]})


@router.delete("/{experiment_id}/delete")
@router.delete("/{experiment_id}")
async def delete_experiment(experiment_id: int):
"""删除实验

Expand Down
4 changes: 4 additions & 0 deletions swanlab/server/api/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi import APIRouter, Request

from ...utils import get_a_lock, create_time
from ...utils.file import check_desc_format
from ..module.resp import SUCCESS_200, DATA_ERROR_500, Conflict_409
from ..module import PT
from swanlab.env import get_swanlog_dir
Expand Down Expand Up @@ -89,6 +90,9 @@ async def update(request: Request):
返回 project.json 的所有内容,目的是方便前端在修改信息后重置 pinia 的状态
"""
body = await request.json()
# 检查格式
body["description"] = check_desc_format(body["description"], False)

with open(PROJECT_PATH, "r") as f:
project = ujson.load(f)
# 检查名字
Expand Down
103 changes: 97 additions & 6 deletions swanlab/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,121 @@ def check_key_format(key: str) -> str:
return key


def check_name_format(name: str, max_len: int = 20) -> str:
"""检查name字符串格式,必须是0-9a-zA-Z _-和/或者中文字符组成的字符串,并且开头必须是0-9a-zA-Z或者中文字符
最大长度为max_len个字符,一个中文字符算一个字符,如果超出长度,将被截断
def check_exp_name_format(name: str, auto_cut: bool = True) -> str:
"""检查实验名格式,必须是0-9a-zA-Z和连字符(_-),并且不能以连字符(_-)开头或结尾
最大长度为100个字符,一个中文字符算一个字符

Parameters
----------
name : str
待检查的字符串
auto_cut : bool, optional
如果超出长度,是否自动截断,默认为True
如果为False,则超出长度会抛出异常

Returns
-------
str
检查后的字符串

Raises
------
TypeError
name不是字符串,或者name为空字符串
ValueError
name不符合规定格式
IndexError
name超出长度
"""
max_len = 100
if not isinstance(name, str) or name == "":
raise TypeError(f"name: {name} is not a string")
# 定义正则表达式
pattern = re.compile(r"^[0-9a-zA-Z][0-9a-zA-Z_-]*[0-9a-zA-Z]$")
# 检查 name 是否符合规定格式
if not pattern.match(name):
raise ValueError(
f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters"
)
# 检查长度
if auto_cut and len(name) > max_len:
name = name[:max_len]
elif not auto_cut and len(name) > max_len:
raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters")
return name


def check_desc_format(description: str, auto_cut: bool = True):
"""检查实验描述
不能超过255个字符,可以包含任何字符

Parameters
----------
description : str
需要检查和处理的描述信息
auto_cut : bool
如果超出长度,是否裁剪并抛弃多余部分

Returns
-------
str
检查后的字符串,同时会去除字符串头尾的空格

Raises
------
IndexError
name超出长度
"""
max_length = 255
description = description.strip()

if len(description) > max_length:
if auto_cut:
return description[:max_length]
else:
raise IndexError(f"description too long that exceeds {max_length} characters.")
return description


def check_proj_name_format(name: str, auto_cut: bool = True) -> str:
"""检查项目名格式,必须是0-9a-zA-Z和中文以及连字符(_-),并且不能以连字符(_-)开头或结尾
最大长度为100个字符,一个中文字符算一个字符

Parameters
----------
name : str
待检查的字符串
auto_cut : bool, optional
如果超出长度,是否自动截断,默认为True
如果为False,则超出长度会抛出异常

Returns
-------
str
检查后的字符串

Raises
------
TypeError
name不是字符串,或者name为空字符串
ValueError
name不符合规定格式
IndexError
name超出长度
"""
if not isinstance(name, str):
max_len = 100
if not isinstance(name, str) or name == "":
raise TypeError(f"name: {name} is not a string")
# 定义正则表达式
pattern = re.compile("^[0-9a-zA-Z\u4e00-\u9fa5][0-9a-zA-Z\u4e00-\u9fa5_/-]*$")
pattern = re.compile(r"^[0-9a-zA-Z\u4e00-\u9fa5]+[0-9a-zA-Z\u4e00-\u9fa5_-]*[0-9a-zA-Z\u4e00-\u9fa5]$")
# 检查 name 是否符合规定格式
if not pattern.match(name):
raise ValueError(
f"name: {name} is not a valid string, which must be composed of 0-9a-zA-Z _- and / or Chinese characters, and the first character must be 0-9a-zA-Z or Chinese characters"
)
# 检查长度
if len(name) > max_len:
if auto_cut and len(name) > max_len:
name = name[:max_len]
elif not auto_cut and len(name) > max_len:
raise IndexError(f"name: {name} is too long, which must be less than {max_len} characters")
return name
2 changes: 1 addition & 1 deletion test/create_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_chart_type(self) -> str:
offset = random.random() / 5
# 创建一个实验
sw.init(
description="this is a test experiment",
description=" this is a test experiment",
config={
"learning_rate": lr,
"epochs": epochs,
Expand Down
Loading