From 300178b673a8ac5008bd565db75bd1db6715cba0 Mon Sep 17 00:00:00 2001 From: Zirui Cai <74649535+Feudalman@users.noreply.github.com> Date: Sun, 31 Dec 2023 00:08:21 +0800 Subject: [PATCH] feat: stop experiment on dashboard (#108) * fix #107 * fix #36 --- swanlab/server/api/experiment.py | 27 +++++++- vue/src/components/SLStatusLabel.vue | 2 +- vue/src/i18n/en-US/experiment.json | 7 +++ vue/src/store/modules/experiment.js | 7 ++- .../index/components/ExperimentHeader.vue | 9 ++- .../pages/index/components/StopButton.vue | 61 +++++++++++++++++++ 6 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 vue/src/views/experiment/pages/index/components/StopButton.vue diff --git a/swanlab/server/api/experiment.py b/swanlab/server/api/experiment.py index 196b2a9e..1d257cbb 100644 --- a/swanlab/server/api/experiment.py +++ b/swanlab/server/api/experiment.py @@ -18,7 +18,7 @@ # from ...utils import create_time from urllib.parse import quote, unquote # 转码路径参数 from typing import List, Dict -from ...utils import get_a_lock +from ...utils import get_a_lock, create_time from ...log import swanlog as swl router = APIRouter() @@ -329,3 +329,28 @@ async def get_experimet_charts(experiment_id: int): chart_path: str = os.path.join(swc.root, __find_experiment(experiment_id)["name"], "chart.json") chart = __get_charts(chart_path) return SUCCESS_200(chart) + + +@router.get("/{experiment_id}/stop") +async def get_stop_charts(experiment_id: int): + """停止实验 + + Parameters + ---------- + experiment_id : int + 实验唯一ID + """ + config_path: str = os.path.join(swc.root, "project.json") + with open(config_path, mode="r", encoding="utf-8") as f: + config = ujson.load(f) + # 获取需要停止的实验在配置中的索引 + index = next((index for index, d in enumerate(config["experiments"]) if d["experiment_id"] == experiment_id), None) + # 修改对应实验的状态 + if not config["experiments"][index]["status"] == 0: + # 不在运行中的状态不予修改 + return Exception("Experiment status is not running") + config["experiments"][index]["status"] = -1 + config["experiments"][index]["update_time"] = create_time() + with get_a_lock(config_path, "w") as f: + ujson.dump(config, f, ensure_ascii=False, indent=4) + return SUCCESS_200({"update_time": create_time()}) diff --git a/vue/src/components/SLStatusLabel.vue b/vue/src/components/SLStatusLabel.vue index 2f1528b4..7edf2cd5 100644 --- a/vue/src/components/SLStatusLabel.vue +++ b/vue/src/components/SLStatusLabel.vue @@ -67,7 +67,7 @@ const handleClick = () => {