From 4e1d0d409dace820beda3b9c157b77274ec0b50a Mon Sep 17 00:00:00 2001 From: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:38:06 +0800 Subject: [PATCH] alpha/launch (#652) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 更新环境变量配置 * alpha版launch * version: 0.3.15a0 --- .vscode/launch.json | 14 +- CONTRIBUTING.md | 45 ++-- requirements.txt | 1 + swanlab/__init__.py | 22 ++ swanlab/api/auth/login.py | 8 +- swanlab/api/http.py | 10 + swanlab/cli/commands/__init__.py | 1 + swanlab/cli/commands/task/__init__.py | 26 ++ swanlab/cli/commands/task/launch.py | 242 ++++++++++++++++++ swanlab/cli/commands/task/list.py | 199 ++++++++++++++ swanlab/cli/commands/task/utils.py | 23 ++ swanlab/cli/main.py | 3 + swanlab/data/callback_cloud.py | 17 +- swanlab/data/sdk.py | 36 ++- swanlab/env.py | 47 +++- swanlab/package.json | 2 +- swanlab/package.py | 50 +++- test/unit/api/auth/test_login.py | 22 +- test/unit/api/test_http.py | 6 +- test/unit/cli/test_cli_login.py | 9 +- test/unit/cli/test_cli_logout.py | 10 +- test/unit/data/{pytest_sdk.py => test_sdk.py} | 68 ++++- test/unit/test_env.py | 39 +++ test/unit/test_package.py | 27 +- tutils/__init__.py | 19 +- tutils/check.py | 50 ++-- tutils/config.py | 19 +- 27 files changed, 854 insertions(+), 161 deletions(-) create mode 100644 swanlab/cli/commands/task/__init__.py create mode 100644 swanlab/cli/commands/task/launch.py create mode 100644 swanlab/cli/commands/task/list.py create mode 100644 swanlab/cli/commands/task/utils.py rename test/unit/data/{pytest_sdk.py => test_sdk.py} (76%) create mode 100644 test/unit/test_env.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 32753039a..02238d78c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,7 +37,9 @@ "type": "debugpy", "request": "launch", "module": "pytest", - "args": ["${file}"], + "args": [ + "${file}" + ], "console": "integratedTerminal" }, { @@ -45,10 +47,12 @@ "type": "debugpy", "request": "launch", "module": "pytest", - "args": ["test/unit"], + "args": [ + "test/unit" + ], "console": "integratedTerminal", "env": { - "TEST_CLOUD_SKIP": "true" + "is_skip_cloud_test": "true" } }, { @@ -56,7 +60,9 @@ "type": "debugpy", "request": "launch", "module": "pytest", - "args": ["test/unit"], + "args": [ + "test/unit" + ], "console": "integratedTerminal" }, // 打包命令 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 960e368ff..7847956ac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,24 +6,24 @@ - [标准开发流程](#标准开发流程) - [本地调试](#本地调试) - - [IDE 与插件](#IDE与插件) - - [配置 Python 环境](#配置python环境) - - [调试脚本](#调试脚本) + - [IDE 与插件](#IDE与插件) + - [配置 Python 环境](#配置python环境) + - [调试脚本](#调试脚本) - [本地测试](#本地测试) - - [python 脚本调试](#python-脚本调试) - - [单元测试](#单元测试) + - [python 脚本调试](#python-脚本调试) + - [单元测试](#单元测试) ## 标准开发流程 1. 浏览 GitHub 上的[Issues](https://github.com/SwanHubX/SwanLab/issues),查看你愿意添加的功能或修复的错误,以及它们是否已被 Pull Request。 - - 如果没有,请创建一个[新 Issue](https://github.com/SwanHubX/SwanLab/issues/new/choose)——这将帮助项目跟踪功能请求和错误报告,并确保不重复工作。 + - 如果没有,请创建一个[新 Issue](https://github.com/SwanHubX/SwanLab/issues/new/choose)——这将帮助项目跟踪功能请求和错误报告,并确保不重复工作。 2. 如果你是第一次为开源项目贡献代码,请转到 [本项目首页](https://github.com/SwanHubX/SwanLab) 并单击右上角的"Fork" 按钮。这将创建你用于开发的仓库的个人副本。 - - 将 Fork 的项目克隆到你的计算机,并添加指向`swanlab`项目的远程链接: + - 将 Fork 的项目克隆到你的计算机,并添加指向`swanlab`项目的远程链接: ```bash git clone https://github.com//swanlab.git @@ -33,20 +33,20 @@ 3. 开发你的贡献 - - 确保您的 Fork 与主存储库同步: + - 确保您的 Fork 与主存储库同步: ```bash git checkout main git pull upstream main ``` - - 创建一个`git`分支,您将在其中发展您的贡献。为分支使用合理的名称,例如: + - 创建一个`git`分支,您将在其中发展您的贡献。为分支使用合理的名称,例如: ```bash git checkout -b / ``` - - 当你取得进展时,在本地提交你的改动,例如: + - 当你取得进展时,在本地提交你的改动,例如: ```bash git add changed-file.py tests/test-changed-file.py @@ -55,19 +55,19 @@ 4. 发起贡献: - - [Github Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) - - 当您的贡献准备就绪后,将您的分支推送到 GitHub: + - [Github Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) + - 当您的贡献准备就绪后,将您的分支推送到 GitHub: ```bash git push origin / ``` - - 分支上传后, `GitHub`将打印一个 URL,用于将您的贡献作为拉取请求提交。在浏览器中打开该 URL,为您的拉取请求编写信息丰富的标题和详细描述,然后提交。 + - 分支上传后, `GitHub`将打印一个 URL,用于将您的贡献作为拉取请求提交。在浏览器中打开该 URL,为您的拉取请求编写信息丰富的标题和详细描述,然后提交。 - - 请将相关 Issue(现有 Issue 或您创建的 Issue)链接到您的 PR。请参阅 PR 页面的右栏。或者,在 PR - 描述中提及“修复问题链接” - GitHub 将自动进行链接。 + - 请将相关 Issue(现有 Issue 或您创建的 Issue)链接到您的 PR。请参阅 PR 页面的右栏。或者,在 PR + 描述中提及“修复问题链接” - GitHub 将自动进行链接。 - - 我们将审查您的贡献并提供反馈。要合并审阅者建议的更改,请将编辑提交到您的分支,然后再次推送到分支(无需重新创建拉取请求,它将自动跟踪对分支的修改),例如: + - 我们将审查您的贡献并提供反馈。要合并审阅者建议的更改,请将编辑提交到您的分支,然后再次推送到分支(无需重新创建拉取请求,它将自动跟踪对分支的修改),例如: ```bash git add tests/test-changed-file.py @@ -75,7 +75,7 @@ git push origin / ``` - - 一旦您的拉取请求被审阅者批准,它将被合并到存储库的主分支中。 + - 一旦您的拉取请求被审阅者批准,它将被合并到存储库的主分支中。 ## 本地调试 @@ -141,7 +141,8 @@ Ps: 如果你不想使用 VSCode 进行开发,可以前往`.vscode/launch.json ### python 脚本调试 -在完成你的改动后,可以将你用于测试的 python 脚本放到根目录或`test`文件夹下,然后通过[VSCode 脚本](#调试脚本)中的"运行当前文件"来运行你的 Python 测试脚本, 这样你的脚本运行将使用到已改动后的 swanlab。 +在完成你的改动后,可以将你用于测试的 python 脚本放到根目录或`test`文件夹下,然后通过[VSCode 脚本](#调试脚本)中的" +运行当前文件"来运行你的 Python 测试脚本, 这样你的脚本运行将使用到已改动后的 swanlab。 ### 单元测试 @@ -155,11 +156,7 @@ export PYTHONPATH=. && pytest test/unit 针对这种情况,请在本地根目录下创建`.env`文件,并填写如下环境变量配置: ```dotenv -TEST_CLOUD_SKIP=true +SWANLAB_RUNTIME=test-no-cloud ``` -这样就可以跳过云端测试,只进行本地的部分功能测试。 - -> 事实上`TEST_CLOUD_SKIP`环境变量可以是任意值,只要存在即可跳过云端测试。 - -如果想进行完整的测试,请联系项目维护者,我们会提供测试环境的配置。 +这样就可以跳过云端测试,只进行本地的部分功能测试。 如果想进行完整的测试,请联系项目维护者,我们会提供测试环境的配置。 diff --git a/requirements.txt b/requirements.txt index 03cec0f17..33934c268 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ click pyyaml psutil pynvml +rich diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 099406cd9..3cdaea31a 100755 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -15,5 +15,27 @@ from .data.run.main import config from .package import get_package_version +from .env import SwanLabEnv + +# 设置默认环境变量 +SwanLabEnv.set_default() +# 检查当前需要检查的环境变量 +SwanLabEnv.check() __version__ = get_package_version() + +__all__ = [ + "login", + "init", + "log", + "finish", + "Audio", + "Image", + "Text", + "Run", + "State", + "get_run", + "get_config", + "config", + "__version__", +] diff --git a/swanlab/api/auth/login.py b/swanlab/api/auth/login.py index 5c85bd7fb..42302daeb 100644 --- a/swanlab/api/auth/login.py +++ b/swanlab/api/auth/login.py @@ -76,11 +76,9 @@ def input_api_key( def code_login(api_key: str) -> LoginInfo: """ 代码内登录,此时会覆盖本地token文件 - - Parameters - ---------- - api_key : str - 用户api_key + :param api_key: 用户的api_key + :return: 登录信息 + :raises ValidationError: 登录失败 """ tip = "Waiting for the swanlab cloud response." login_info: LoginInfo = FONT.loading(tip, login_by_key, args=(api_key,), interval=0.5) diff --git a/swanlab/api/http.py b/swanlab/api/http.py index 428230f43..f8c3475f4 100644 --- a/swanlab/api/http.py +++ b/swanlab/api/http.py @@ -152,9 +152,19 @@ def get(self, url: str, params: dict = None) -> dict: get请求 """ url = self.base_url + url + self.__before_request() resp = self.__session.get(url, params=params) return decode_response(resp) + def patch(self, url: str, data: dict = None) -> Union[dict, str]: + """ + patch请求 + """ + url = self.base_url + url + self.__before_request() + resp = self.__session.patch(url, json=data) + return decode_response(resp) + def __get_cos(self): cos = self.get(f"/project/{self.groupname}/{self.projname}/runs/{self.exp_id}/sts") self.__cos = CosClient(cos) diff --git a/swanlab/cli/commands/__init__.py b/swanlab/cli/commands/__init__.py index ec3909d19..5e3433032 100644 --- a/swanlab/cli/commands/__init__.py +++ b/swanlab/cli/commands/__init__.py @@ -10,3 +10,4 @@ from .auth import login, logout from .dashboard import watch from .converter import convert +from .task import task diff --git a/swanlab/cli/commands/task/__init__.py b/swanlab/cli/commands/task/__init__.py new file mode 100644 index 000000000..c9bf1f77d --- /dev/null +++ b/swanlab/cli/commands/task/__init__.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/17 17:16 +@File: __init__.py.py +@IDE: pycharm +@Description: + 启动! + beta版 +""" +from .launch import launch +from .list import list +import click + +__all__ = ["task"] + + +@click.group() +def task(): + pass + + +# noinspection PyTypeChecker +task.add_command(launch) +# noinspection PyTypeChecker +task.add_command(list) diff --git a/swanlab/cli/commands/task/launch.py b/swanlab/cli/commands/task/launch.py new file mode 100644 index 000000000..cd89c1025 --- /dev/null +++ b/swanlab/cli/commands/task/launch.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/17 19:30 +@File: task.py +@IDE: pycharm +@Description: + 打包、上传、开启任务 +""" +import click +from .utils import login_init_sid +from swanlab.api import get_http +# noinspection PyPackageRequirements +from qcloud_cos import CosConfig, CosS3Client +from swanlab.log import swanlog +from swankit.log import FONT +import zipfile +import threading +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) +from datetime import datetime +import time +import io +import os + + +@click.command() +@click.argument( + "path", + type=click.Path( + exists=True, + dir_okay=True, + file_okay=False, + resolve_path=True, + readable=True, + ), + default=".", + nargs=1, + required=True, +) +@click.option( + "--entry", + "-e", + default="main.py", + nargs=1, + type=click.Path( + exists=True, + dir_okay=False, + file_okay=True, + resolve_path=True, + readable=True, + ), + help="The entry file of the task, default by main.py", +) +@click.option( + "--python", + default="python3.10", + nargs=1, + type=click.Choice(["python3.8", "python3.9", "python3.10"]), + help="The python version of the task, default by python3.10", +) +@click.option( + "--name", + "-n", + default="Task_{}".format(datetime.now().strftime("%b%d_%H-%M-%S")), + nargs=1, + type=str, + help="The name of the task, default by Task_{current_time}", +) +def launch(path: str, entry: str, python: str, name: str): + if not entry.startswith(path): + raise ValueError(f"Error: Entry file '{entry}' must be in directory '{path}'") + entry = entry[len(path):] + # 获取访问凭证,生成http会话对象 + login_info = login_init_sid() + print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!") + # 上传文件 + text = f"The target folder {FONT.yellow(path)} will be packaged and uploaded, " + text += f"and you have specified {FONT.yellow(entry)} as the task entry point. " + swanlog.info(text) + # click.confirm(FONT.swanlab("Do you wish to proceed?")) + # 压缩文件夹 + memory_file = zip_folder(path) + # 上传文件 + src = upload_memory_file(memory_file) + # 发布任务 + ctm = CreateTaskModel(login_info.username, src, login_info.api_key, python, name, entry) + ctm.create() + swanlog.info(f"Task launched successfully. You can use {FONT.yellow('swanlab task list')} to view the task.") + + +def zip_folder(dirpath: str) -> io.BytesIO: + """ + 压缩文件夹 + :param dirpath: 传入文件夹路径 + """ + memory_file = io.BytesIO() + z = zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) + fs = [x for x in os.walk(dirpath)] + + progress = Progress( + TextColumn("{task.description}", justify="left"), + BarColumn(), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + TimeRemainingColumn(), + ) + with progress: + for i in progress.track(range(len(fs)), description=FONT.swanlab("Packing... ")): + root, dirs, files = fs[i] + for file in files: + # 构建文件的完整路径 + file_path = os.path.join(root, file) + # 构建在压缩文件中的路径 + arc_name = os.path.relpath(file_path.__str__(), start=dirpath) + # 将文件添加到压缩文件中 + z.write(file_path.__str__(), arc_name) + memory_file.seek(0) + return memory_file + + +class CosClientForTask: + def __init__(self, sts): + region = sts["region"] + self.bucket = sts["bucket"] + token = sts["credentials"]["sessionToken"] + secret_id = sts["credentials"]["tmpSecretId"] + secret_key = sts["credentials"]["tmpSecretKey"] + config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key, Token=token, Scheme="https") + self.client = CosS3Client(config) + self.key = sts["prefix"] + "/tasks/" + f"{int(time.time() * 1000)}.zip" + + def upload(self, buffer: io.BytesIO): + return self.client.upload_file_from_buffer( + Bucket=self.bucket, + Key=self.key, + Body=buffer, + MAXThread=5, + MaxBufferSize=5, + PartSize=1 + ) + + +class TaskBytesIO(io.BytesIO): + + def __init__(self, read_callback, *args, **kwargs): + super().__init__(*args, **kwargs) + self.read_callback = read_callback + + def read(self, *args): + self.read_callback(*args) + return super().read(*args) + + +class TaskProgressBar: + def __init__(self, total_size: int): + """ + :param total_size: 总大小(bytes) + """ + self.total_size = total_size + self.current = 0 + self.progress = Progress( + TextColumn("{task.description}", justify="left"), + BarColumn(), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + "•", + TimeRemainingColumn(), + ) + + def update(self, *args): + self.current += args[0] + + def start(self): + with self.progress as progress: + for i in progress.track(range(self.total_size), description=FONT.swanlab("Uploading...")): + if self.current > i: + continue + time.sleep(0.5) + while True: + if self.current > i: + break + + +def upload_memory_file(memory_file: io.BytesIO) -> str: + """ + 上传内存文件 + :returns 上传成功后的文件路径 + """ + sts = get_http().get("/user/codes/sts") + cos = CosClientForTask(sts) + val = memory_file.getvalue() + progress = TaskProgressBar(len(val)) + buffer = TaskBytesIO(progress.update, val) + t = threading.Thread(target=progress.start) + t.start() + cos.upload(buffer) + t.join() + return cos.key + + +class CreateTaskModel: + def __init__(self, username, src, key, python, name, index): + """ + :param username: 用户username + :param key: 用户的api_key + :param src: 任务zip文件路径 + :param python: 任务的python版本 + :param name: 任务名称 + :param index: 任务入口文件 + """ + self.username = username + self.src = src + self.key = key + self.python = python + self.name = name + self.index = index + + def __dict__(self): + return { + "username": self.username, + "src": self.src, + "index": self.index, + "python": self.python, + "conf": {"key": self.key}, + "name": self.name + } + + def create(self): + """ + 创建任务 + """ + get_http().post("/task", self.__dict__()) diff --git a/swanlab/cli/commands/task/list.py b/swanlab/cli/commands/task/list.py new file mode 100644 index 000000000..f80980a8a --- /dev/null +++ b/swanlab/cli/commands/task/list.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/19 14:09 +@File: status.py +@IDE: pycharm +@Description: + 列出任务状态 +""" +import time + +import click +from typing import List +from .utils import login_init_sid +from rich.layout import Layout +from datetime import datetime +from rich.panel import Panel +from rich.table import Table +from rich.live import Live +from swanlab.api import get_http +from swanlab.package import get_experiment_url + + +@click.command() +@click.option( + "--max-num", + "-m", + default=10, + nargs=1, + type=click.IntRange(1, 100), + help="The maximum number of tasks to display, default by 10, maximum by 100", +) +def list(max_num: int): # noqa + # 获取访问凭证,生成http会话对象 + login_info = login_init_sid() + # 获取任务列表 + ltm = ListTasksModel(num=max_num, username=login_info.username) + layout = ListTaskLayout(ltm) + layout.show() + + +class ListTasksModel: + class TaskListModel: + """ + 获取到的任务列表模型 + """ + + def __init__(self, username: str, task: dict, ): + self.username = username + self.name = task["name"] + """ + 任务名称 + """ + self.python = task["python"] + """ + 任务的python版本 + """ + self.project_name = task.get("pName", None) + """ + 项目名称 + """ + self.experiment_id = task.get("eId", None) + """ + 实验ID + """ + self.created_at = task["createdAt"] + self.started_at = task.get("startedAt", None) + self.finished_at = task.get("finishedAt", None) + self.status = task["status"] + self.msg = task.get("msg", None) + + @property + def url(self): + if self.project_name is None or self.experiment_id is None: + return None + return get_experiment_url(self.username, self.project_name, self.experiment_id) + + def __init__(self, num: int, username: str): + """ + :param num: 最大显示的任务数 + """ + self.num = num + self.username = username + self.http = get_http() + + def __dict__(self): + return {"num": self.num} + + def list(self) -> List[TaskListModel]: + tasks = self.http.get("/task", self.__dict__()) + return [self.TaskListModel(self.username, task) for task in tasks] + + def table(self): + st = Table( + expand=True, + show_header=True, + header_style="bold", + title="[magenta][b]Now Tasks![/b]", + highlight=True, + border_style="magenta", + ) + st.add_column("Task Name", justify="right") + st.add_column("Status", justify="center") + st.add_column("URL", justify="center") + st.add_column("Python Version", justify="center"), + st.add_column("Created Time", justify="center") + st.add_column("Started Time", justify="center") + st.add_column("Finished Time", justify="center") + for tlm in self.list(): + st.add_row( + tlm.name, + tlm.status, + tlm.url, + tlm.python, + tlm.created_at, + tlm.started_at, + tlm.finished_at, + ) + return st + + +class ListTaskHeader: + """ + Display header with clock. + """ + + @staticmethod + def __rich__() -> Panel: + grid = Table.grid(expand=True) + grid.add_column(justify="center", ratio=1) + grid.add_column(justify="right") + grid.add_row( + "[b]SwanLab[/b] task dashboard", + datetime.now().ctime().replace(":", "[blink]:[/]"), + ) + return Panel(grid, style="red on black") + + +class ListTaskLayout: + """ + 任务列表展示 + """ + + def __init__(self, ltm: ListTasksModel): + self.event = [] + self.add_event(f"👏Welcome, [b]{ltm.username}[/b].") + self.add_event("⌛️Task board is loading...") + self.layout = Layout() + self.layout.split( + Layout(name="header", size=3), + Layout(name="main") + ) + self.layout["main"].split_row( + Layout(name="task_table", ratio=5), + Layout(name="term_output", ratio=2, ) + ) + self.layout["header"].update(ListTaskHeader()) + self.layout["task_table"].update(Panel(ltm.table(), border_style="magenta")) + self.ltm = ltm + self.add_event("🍺Task board is loaded.") + self.redraw_term_output() + + @property + def term_output(self): + to = Table( + expand=True, + show_header=False, + header_style="bold", + title="[blue][b]Log Messages[/b]", + highlight=True, + border_style="blue", + ) + to.add_column("Log Output") + return to + + def redraw_term_output(self, ): + term_output = self.term_output + for row in self.event: + term_output.add_row(row) + self.layout["term_output"].update(Panel(term_output, border_style="blue")) + + def add_event(self, info: str, max_length=15): + # 事件格式:yyyy-mm-dd hh:mm:ss - info + self.event.append(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {info}") + while len(self.event) > max_length: + self.event.pop(0) + + def show(self): + with Live(self.layout, refresh_per_second=10, screen=True) as live: + now = time.time() + while True: + time.sleep(1) + self.layout["header"].update(ListTaskHeader()) + if time.time() - now > 5: + now = time.time() + self.add_event("🔍Searching for new tasks...") + self.layout["task_table"].update(Panel(self.ltm.table(), border_style="magenta")) + self.redraw_term_output() + live.refresh() diff --git a/swanlab/cli/commands/task/utils.py b/swanlab/cli/commands/task/utils.py new file mode 100644 index 000000000..472d5e7e6 --- /dev/null +++ b/swanlab/cli/commands/task/utils.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/19 14:30 +@File: utils.py +@IDE: pycharm +@Description: + 任务相关工具函数 +""" +from swanlab.package import get_key +from swanlab.api import terminal_login, create_http, LoginInfo +from swanlab.error import KeyFileError + + +def login_init_sid() -> LoginInfo: + key = None + try: + key = get_key() + except KeyFileError: + pass + login_info = terminal_login(key) + create_http(login_info) + return login_info diff --git a/swanlab/cli/main.py b/swanlab/cli/main.py index 3806a4cca..73c9a3e8a 100644 --- a/swanlab/cli/main.py +++ b/swanlab/cli/main.py @@ -32,5 +32,8 @@ def cli(): # noinspection PyTypeChecker cli.add_command(C.convert) # 转换命令,用于转换其他实验跟踪工具 +# noinspection PyTypeChecker +cli.add_command(C.task) # 任务式作业 + if __name__ == "__main__": cli() diff --git a/swanlab/data/callback_cloud.py b/swanlab/data/callback_cloud.py index 5b8a4962a..97598353a 100644 --- a/swanlab/data/callback_cloud.py +++ b/swanlab/data/callback_cloud.py @@ -9,14 +9,13 @@ """ from swankit.callback import RuntimeInfo, MetricInfo, ColumnInfo from swankit.core import SwanLabSharedSettings - from swanlab.data.cloud import UploadType from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel from swanlab.api import LoginInfo, create_http, terminal_login from swanlab.api.upload import upload_logs from swanlab.log import swanlog from swanlab.api import get_http -from swanlab.env import in_jupyter +from swanlab.env import in_jupyter, SwanLabEnv from swanlab.package import get_host_web, get_key from swanlab.error import KeyFileError from .callback_local import LocalRunCallback, get_run, SwanLabRunState @@ -203,8 +202,9 @@ def before_run(self, settings: SwanLabSharedSettings): def on_run(self): swanlog.install(self.settings.console_dir) + http = get_http() # 注册实验信息 - get_http().mount_exp( + http.mount_exp( exp_name=self.settings.exp_name, colors=self.settings.exp_colors, description=self.settings.description, @@ -228,6 +228,17 @@ def _write_call_call(message): if in_jupyter(): show_button_html(experiment_url) + # task环境下,同步实验信息回调 + if SwanLabEnv.RUNTIME.value == "task": + cuid = os.environ["SWANLAB_TASK_ID"] + info = { + "cuid": cuid, + "pId": http.proj_id, + "eId": http.exp_id, + "pName": http.projname + } + http.patch("/task/experiment", info) + def on_runtime_info_update(self, r: RuntimeInfo): # 执行local逻辑,保存文件到本地 super(CloudRunCallback, self).on_runtime_info_update(r) diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index 4df26776e..9c07433c8 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -56,12 +56,13 @@ def login(api_key: str = None): """ Login to SwanLab Cloud. If you already have logged in, you can use this function to relogin. Every time you call this function, the previous login information will be overwritten. + [Note that] this function should be called before `init`. - Parameters - ---------- - api_key : str + :param api_key: str, optional authentication key, if not provided, the key will be read from the key file. + + :return: LoginInfo """ if SwanLabRun.is_started(): raise RuntimeError("You must call swanlab.login() before using init()") @@ -69,16 +70,16 @@ def login(api_key: str = None): def init( - project: str = None, - workspace: str = None, - experiment_name: str = None, - description: str = None, - config: Union[dict, str] = None, - logdir: str = None, - suffix: Union[str, None, bool] = "default", - mode: Literal["disabled", "cloud", "local"] = None, - load: str = None, - **kwargs, + project: str = None, + workspace: str = None, + experiment_name: str = None, + description: str = None, + config: Union[dict, str] = None, + logdir: str = None, + suffix: Union[str, None, bool] = "default", + mode: Literal["disabled", "cloud", "local"] = None, + load: str = None, + **kwargs, ) -> SwanLabRun: """ Start a new run to track and log. Once you have called this function, you can use 'swanlab.log' to log data to @@ -144,13 +145,6 @@ def init( swanlog.warning("You have already initialized a run, the init function will be ignored") return get_run() # ---------------------------------- 一些变量、格式检查 ---------------------------------- - # TODO 下个版本删除 - if "cloud" in kwargs: - swanlog.warning( - "The `cloud` parameter in swanlab.init is deprecated and will be removed in the future" - "please use `mode='cloud'` instead." - ) - mode = "cloud" if kwargs["cloud"] else mode if load: load_data = check_load_json_yaml(load, load) experiment_name = _load_data(load_data, "experiment_name", experiment_name) @@ -248,7 +242,7 @@ def _init_mode(mode: str = None): :raise ValueError: mode参数不合法 """ allowed = [m.value for m in SwanLabMode] - mode_key = SwanLabEnv.SWANLAB_MODE.value + mode_key = SwanLabEnv.MODE.value mode_value = os.environ.get(mode_key) if mode_value is not None and mode is not None: swanlog.warning(f"The environment variable {mode_key} will be overwritten by the parameter mode") diff --git a/swanlab/env.py b/swanlab/env.py index a134c8ad3..5f2b2b46f 100644 --- a/swanlab/env.py +++ b/swanlab/env.py @@ -11,6 +11,7 @@ import swankit.env as E from swankit.env import SwanLabSharedEnv import enum +import os # ---------------------------------- 环境变量枚举类 ---------------------------------- @@ -28,10 +29,10 @@ class SwanLabEnv(enum.Enum): """ swanlab解析日志文件保存的路径,默认为当前运行目录的swanlog文件夹 """ - SWANLAB_MODE = SwanLabSharedEnv.SWANLAB_MODE.value + MODE = SwanLabSharedEnv.SWANLAB_MODE.value """ swanlab的解析模式,涉及操作员注册的回调,目前有三种:local、cloud、disabled,默认为cloud - 大小写不敏感 + 大小写敏感 """ SWANBOARD_PROT = "SWANLAB_BOARD_PORT" """ @@ -41,18 +42,52 @@ class SwanLabEnv(enum.Enum): """ cli swanboard 服务地址 """ - SWANLAB_WEB_HOST = "SWANLAB_WEB_HOST" + WEB_HOST = "SWANLAB_WEB_HOST" """ swanlab云端环境的web地址 """ - SWANLAB_API_HOST = "SWANLAB_API_HOST" + API_HOST = "SWANLAB_API_HOST" """ swanlab云端环境的api地址 """ - SWANLAB_VERSION = "SWANLAB_VERSION" + RUNTIME = "SWANLAB_RUNTIME" """ - swanlab的版本号,主要用于开发者调试 + swanlab的运行时环境,"user" "develop" "test" "test-no-cloud" "task" """ + API_KEY = "SWANLAB_API_KEY" + """ + 云端api key,登录时会首先查找此环境变量,如果不存在,判断用户是否已登录,未登录则进入登录流程 + + * 如果login接口传入字符串,此环境变量无效,此时相当于绕过 get_key 接口 + * 如果用户已登录,此环境变量的优先级高于本地存储登录信息 + """ + + @classmethod + def set_default(cls): + """ + 设置默认的环境变量值 + """ + envs = { + cls.WEB_HOST.value: "https://swanlab.cn", + cls.API_HOST.value: "https://api.swanlab.cn/api", + cls.RUNTIME.value: "user", + } + for k, v in envs.items(): + os.environ.setdefault(k, v) + + @classmethod + def check(cls): + """ + 检查环境变量的值是否为预期值中的一个 + :raises ValueError: 如果环境变量的值不在预期值中 + """ + envs = { + cls.MODE.value: ["local", "cloud", "disabled"], + cls.RUNTIME.value: ["user", "develop", "test", "test-no-cloud", "task"], + } + for k, vs in envs.items(): + if k in os.environ and os.environ[k] not in vs: + raise ValueError(f"Unknown value for {k}: {os.environ[k]}") @classmethod def list(cls) -> List[str]: diff --git a/swanlab/package.json b/swanlab/package.json index f7fac0536..65cb45889 100644 --- a/swanlab/package.json +++ b/swanlab/package.json @@ -1,6 +1,6 @@ { "name": "swanlab", - "version": "0.3.14", + "version": "0.3.15-alpha", "description": "", "python": "true" } diff --git a/swanlab/package.py b/swanlab/package.py index f015e2517..4f1e2749a 100644 --- a/swanlab/package.py +++ b/swanlab/package.py @@ -24,8 +24,6 @@ def get_package_version() -> str: """获取swanlab的版本号 :return: swanlab的版本号 """ - if SwanLabEnv.SWANLAB_VERSION.value in os.environ: - return os.environ[SwanLabEnv.SWANLAB_VERSION.value] # 读取package.json文件 with open(package_path, "r") as f: return json.load(f)["version"] @@ -57,14 +55,14 @@ def get_host_web() -> str: """获取swanlab网站网址 :return: swanlab网站的网址 """ - return os.getenv(SwanLabEnv.SWANLAB_WEB_HOST.value, "https://swanlab.cn") + return os.getenv(SwanLabEnv.WEB_HOST.value, "https://swanlab.cn") def get_host_api() -> str: """获取swanlab网站api网址 :return: swanlab网站的api网址 """ - return os.getenv(SwanLabEnv.SWANLAB_API_HOST.value, "https://swanlab.cn/api") + return os.getenv(SwanLabEnv.API_HOST.value, "https://api.swanlab.cn/api") def get_user_setting_path() -> str: @@ -101,6 +99,9 @@ def get_key(): :raise KeyFileError: 文件不存在或者host不存在 :return: token """ + env_key = os.getenv(SwanLabEnv.API_KEY.value) + if env_key is not None: + return env_key path = os.path.join(get_save_dir(), ".netrc") host = get_host_api() if not os.path.exists(path): @@ -132,13 +133,44 @@ def save_key(username: str, password: str, host: str = None): f.write(nrc.__repr__()) +class LoginCheckContext: + """ + 进入上下文时,会删除环境变量中的api key,退出上下文时会恢复原来的值 + """ + + def __init__(self): + self.__tmp_key = None + """ + 临时保存的key + """ + self.is_login = False + """ + 标注是否已经登录 + """ + + def __enter__(self): + self.__tmp_key = os.environ.get(SwanLabEnv.API_KEY.value) + if self.__tmp_key is not None: + del os.environ[SwanLabEnv.API_KEY.value] + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # 恢复原来的值 + if self.__tmp_key is not None: + os.environ[SwanLabEnv.API_KEY.value] = self.__tmp_key + if exc_type is KeyFileError: # 未登录 + return True + elif exc_type is not None: # 其他错误 + return False + self.is_login = True + return True + + def is_login() -> bool: - """判断是否已经登录,与当前的host相关 + """判断是否已经登录,与当前的host相关,与get_key不同,不考虑环境变量的因素 但不会检查key的有效性 :return: 是否已经登录 """ - try: + with LoginCheckContext() as checker: _ = get_key() - return True - except KeyFileError: - return False + return checker.is_login diff --git a/test/unit/api/auth/test_login.py b/test/unit/api/auth/test_login.py index d4836defd..fe879306d 100644 --- a/test/unit/api/auth/test_login.py +++ b/test/unit/api/auth/test_login.py @@ -20,21 +20,21 @@ def get_password(prompt: str): if "Paste" in prompt: return generate() else: - return T.TEST_CLOUD_KEY + return T.API_KEY -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_login_success(): """ 测试登录成功 """ - login_info = login_by_key(T.TEST_CLOUD_KEY, save=False) + login_info = login_by_key(T.API_KEY, save=False) assert not login_info.is_fail - assert login_info.api_key == T.TEST_CLOUD_KEY + assert login_info.api_key == T.API_KEY assert login_info.__str__() == "Login success" -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_login_error_key(): """ 测试登录失败, 错误的key @@ -45,27 +45,27 @@ def test_login_error_key(): assert login_info.__str__() == "Error api key" -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_terminal_login(monkeypatch): """ 测试终端登录 """ monkeypatch.setattr("getpass.getpass", get_password) - login_info = terminal_login(T.TEST_CLOUD_KEY) + login_info = terminal_login(T.API_KEY) assert not login_info.is_fail - assert login_info.api_key == T.TEST_CLOUD_KEY + assert login_info.api_key == T.API_KEY assert login_info.__str__() == "Login success" assert is_login() -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_code_login(): """ 测试code登录 """ - login_info = code_login(T.TEST_CLOUD_KEY) + login_info = code_login(T.API_KEY) assert not login_info.is_fail - assert login_info.api_key == T.TEST_CLOUD_KEY + assert login_info.api_key == T.API_KEY assert login_info.__str__() == "Login success" with pytest.raises(ValidationError): _ = code_login("wrong-key") diff --git a/test/unit/api/test_http.py b/test/unit/api/test_http.py index 281c88a89..d4f598d1c 100644 --- a/test/unit/api/test_http.py +++ b/test/unit/api/test_http.py @@ -14,13 +14,13 @@ from swanlab.api.http import create_http, HTTP, CosClient from swanlab.api.auth.login import login_by_key from swanlab.data.modules import MediaBuffer -from tutils import TEST_CLOUD_KEY, TEMP_PATH, TEST_CLOUD_SKIP +from tutils import API_KEY, TEMP_PATH, is_skip_cloud_test import pytest alphabet = "abcdefghijklmnopqrstuvwxyz" -@pytest.mark.skipif(TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(is_skip_cloud_test, reason="skip cloud test") class TestCosSuite: http: HTTP = None project_name = nanoid.generate(alphabet) @@ -33,7 +33,7 @@ class TestCosSuite: def setup_class(cls): CosClient.REFRESH_TIME = cls.now_refresh_time # 这里不测试保存token的功能 - login_info = login_by_key(TEST_CLOUD_KEY, save=False) + login_info = login_by_key(API_KEY, save=False) cls.http = create_http(login_info) cls.http.mount_project(cls.project_name) cls.http.mount_exp(cls.experiment_name, ('#ffffff', '#ffffff')) diff --git a/test/unit/cli/test_cli_login.py b/test/unit/cli/test_cli_login.py index f1bbfd02b..4e1a853ca 100644 --- a/test/unit/cli/test_cli_login.py +++ b/test/unit/cli/test_cli_login.py @@ -10,23 +10,22 @@ from swanlab.package import get_key from click.testing import CliRunner from swanlab.cli.main import cli -from tutils import TEST_CLOUD_KEY from swanlab.error import ValidationError import tutils as T import pytest # noinspection PyTypeChecker -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_login_ok(): runner = CliRunner() - result = runner.invoke(cli, ["login", "--api-key", TEST_CLOUD_KEY]) + result = runner.invoke(cli, ["login", "--api-key", T.API_KEY]) assert result.exit_code == 0 - assert get_key() == TEST_CLOUD_KEY + assert get_key() == T.API_KEY # noinspection PyTypeChecker -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_login_fail(): runner = CliRunner() result = runner.invoke(cli, ["login", "--api-key", "123"]) diff --git a/test/unit/cli/test_cli_logout.py b/test/unit/cli/test_cli_logout.py index b15a009c9..d98096add 100644 --- a/test/unit/cli/test_cli_logout.py +++ b/test/unit/cli/test_cli_logout.py @@ -14,29 +14,29 @@ # noinspection PyTypeChecker -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_logout_ok(monkeypatch): runner = CliRunner() # 先登录 - runner.invoke(cli, ["login", "--api-key", T.TEST_CLOUD_KEY]) + runner.invoke(cli, ["login", "--api-key", T.API_KEY]) monkeypatch.setattr("builtins.input", lambda x: "y") result = runner.invoke(cli, ["logout"]) assert result.exit_code == 0 # noinspection PyTypeChecker -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_logout_cancel(monkeypatch): runner = CliRunner() # 先登录 - runner.invoke(cli, ["login", "--api-key", T.TEST_CLOUD_KEY]) + runner.invoke(cli, ["login", "--api-key", T.API_KEY]) monkeypatch.setattr("builtins.input", lambda x: "n") result = runner.invoke(cli, ["logout"]) assert result.exit_code == 0 # noinspection PyTypeChecker -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_logout_no_login(): runner = CliRunner() result = runner.invoke(cli, ["logout"]) diff --git a/test/unit/data/pytest_sdk.py b/test/unit/data/test_sdk.py similarity index 76% rename from test/unit/data/pytest_sdk.py rename to test/unit/data/test_sdk.py index a138a748d..7ca7984cf 100644 --- a/test/unit/data/pytest_sdk.py +++ b/test/unit/data/test_sdk.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- r""" @DATE: 2024/4/26 16:03 -@File: pytest_sdk.py +@File: test_sdk.py @IDE: pycharm @Description: 测试sdk的一些api @@ -31,7 +31,38 @@ def setup_function(): swanlog.enable_log() -MODE = SwanLabEnv.SWANLAB_MODE.value +MODE = SwanLabEnv.MODE.value + + +class TestInitModeFunc: + + def test_init_error_mode(self): + """ + 初始化时mode参数错误 + """ + with pytest.raises(ValueError): + S._init_mode("123456") # noqa + + @pytest.mark.parametrize("mode", ["disabled", "local", "cloud"]) + def test_init_mode(self, mode): + """ + 初始化时mode参数正确 + """ + S._init_mode(mode) + assert os.environ[MODE] == mode + del os.environ[MODE] + # # 大写 + # S._init_mode(mode.upper()) + # assert os.environ[MODE] == mode + + @pytest.mark.parametrize("mode", ["disabled", "local", "cloud"]) + def test_overwrite_mode(self, mode): + """ + 初始化时mode参数正确,覆盖环境变量 + """ + os.environ[MODE] = "123456" + S._init_mode(mode) + assert os.environ[MODE] == mode class TestInitMode: @@ -40,7 +71,9 @@ class TestInitMode: """ def test_init_disabled(self): - run = S.init(mode="disabled", logdir=generate()) + logdir = os.path.join(T.TEMP_PATH, generate()).__str__() + run = S.init(mode="disabled", logdir=logdir) + assert not os.path.exists(logdir) assert os.environ[MODE] == "disabled" run.log({"TestInitMode": 1}) # 不会报错 a = run.settings.run_dir @@ -53,9 +86,9 @@ def test_init_local(self): run.log({"TestInitMode": 1}) # 不会报错 assert get_run() is not None - @pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") + @pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_init_cloud(self): - S.login(T.TEST_CLOUD_KEY) + S.login(T.is_skip_cloud_test) run = S.init(mode="cloud") assert os.environ[MODE] == "cloud" run.log({"TestInitMode": 1}) # 不会报错 @@ -83,10 +116,10 @@ def test_init_local_env(self): assert os.environ[MODE] == "local" run.log({"TestInitMode": 1}) - @pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") + @pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_init_cloud_env(self): os.environ[MODE] = "cloud" - S.login(T.TEST_CLOUD_KEY) + S.login(T.is_skip_cloud_test) run = S.init() assert os.environ[MODE] == "cloud" run.log({"TestInitMode": 1}) @@ -175,7 +208,7 @@ def test_init_logdir_env(self): assert run.settings.swanlog_dir == logdir -@pytest.mark.skipif(T.TEST_CLOUD_SKIP, reason="skip cloud test") +@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") class TestLogin: """ 测试通过sdk封装的login函数登录 @@ -188,7 +221,7 @@ def get_password(prompt: str): if "Paste" in prompt: return generate() else: - return T.TEST_CLOUD_KEY + return T.is_skip_cloud_test def test_use_home_key(self, monkeypatch): """ @@ -208,5 +241,20 @@ def test_use_input_key(self, monkeypatch): key = generate() with pytest.raises(Err.ValidationError): S.login(api_key=key) - key = T.TEST_CLOUD_KEY + key = T.API_KEY S.login(api_key=key) + + def test_use_env_key(self, monkeypatch): + """ + 测试code登录,使用环境变量key + """ + + def _(): + raise RuntimeError("this function should not be called") + + monkeypatch.setattr("getpass.getpass", _) + os.environ[SwanLabEnv.API_KEY.value] = "1234" + with pytest.raises(Err.ValidationError): + S.login() + os.environ[SwanLabEnv.API_KEY.value] = T.API_KEY + S.login() diff --git a/test/unit/test_env.py b/test/unit/test_env.py new file mode 100644 index 000000000..5827631dd --- /dev/null +++ b/test/unit/test_env.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/7/18 15:28 +@File: test_env.py +@IDE: pycharm +@Description: + 测试swanlab.env模块 +""" +import pytest + +from swanlab.env import SwanLabEnv +import swanlab +import os + + +def test_default(): + """ + 测试获取默认的环境变量 + """ + del os.environ[SwanLabEnv.WEB_HOST.value] + del os.environ[SwanLabEnv.API_HOST.value] + del os.environ[SwanLabEnv.RUNTIME.value] + swanlab.env.SwanLabEnv.set_default() + assert swanlab.package.get_host_web() == "https://swanlab.cn" + assert swanlab.package.get_host_api() == "https://api.swanlab.cn/api" + assert os.getenv(SwanLabEnv.RUNTIME.value) == "user" + + +def test_check(): + """ + 测试检查环境变量 + """ + os.environ[SwanLabEnv.MODE.value] = "124345" + with pytest.raises(ValueError): + SwanLabEnv.check() + os.environ[SwanLabEnv.RUNTIME.value] = "124" + with pytest.raises(ValueError): + SwanLabEnv.check() diff --git a/test/unit/test_package.py b/test/unit/test_package.py index 3a5e74df8..7d4e7e5c2 100644 --- a/test/unit/test_package.py +++ b/test/unit/test_package.py @@ -24,8 +24,6 @@ def test_get_package_version(): """ 测试获取版本号 """ - assert P.get_package_version() == os.getenv(SwanLabEnv.SWANLAB_VERSION.value) - del os.environ[SwanLabEnv.SWANLAB_VERSION.value] assert P.get_package_version() == package_data["version"] @@ -33,16 +31,16 @@ def test_get_host_web_env(): """ 通过环境变量指定web地址 """ - os.environ[SwanLabEnv.SWANLAB_WEB_HOST.value] = nanoid.generate() - assert P.get_host_web() == os.environ[SwanLabEnv.SWANLAB_WEB_HOST.value] + os.environ[SwanLabEnv.WEB_HOST.value] = nanoid.generate() + assert P.get_host_web() == os.environ[SwanLabEnv.WEB_HOST.value] def test_get_host_api_env(): """ 通过环境变量指定api地址 """ - os.environ[SwanLabEnv.SWANLAB_API_HOST.value] = nanoid.generate() - assert P.get_host_api() == os.environ[SwanLabEnv.SWANLAB_API_HOST.value] + os.environ[SwanLabEnv.API_HOST.value] = nanoid.generate() + assert P.get_host_api() == os.environ[SwanLabEnv.API_HOST.value] def test_get_user_setting_path(): @@ -82,6 +80,7 @@ def test_ok(self): """ 获取key成功 """ + del os.environ[SwanLabEnv.API_KEY.value] # 首先需要登录 file = os.path.join(get_save_dir(), ".netrc") with open(file, "w"): @@ -97,6 +96,7 @@ def test_no_file(self): """ 文件不存在 """ + del os.environ[SwanLabEnv.API_KEY.value] from swanlab.error import KeyFileError with pytest.raises(KeyFileError) as e: P.get_key() @@ -104,14 +104,23 @@ def test_no_file(self): def test_no_host(self): from swanlab.error import KeyFileError - self.test_ok() + self.test_ok() # 此时删除了环境变量 host = nanoid.generate() - os.environ[SwanLabEnv.SWANLAB_API_HOST.value] = host + os.environ[SwanLabEnv.API_HOST.value] = host assert P.get_host_api() == host with pytest.raises(KeyFileError) as e: P.get_key() assert str(e.value) == f"The host {host} does not exist" + def test_use_env(self): + """ + 使用环境变量,优先级高于本地文件 + """ + self.test_ok() + key = nanoid.generate() + os.environ[SwanLabEnv.API_KEY.value] = key + assert P.get_key() == key + class TestSaveKey: @@ -162,5 +171,5 @@ def test_wrong_host(self): host不匹配 """ self.login() - os.environ[SwanLabEnv.SWANLAB_API_HOST.value] = nanoid.generate() + os.environ[SwanLabEnv.API_HOST.value] = nanoid.generate() assert not P.is_login() diff --git a/tutils/__init__.py b/tutils/__init__.py index e9f585984..657fa560a 100644 --- a/tutils/__init__.py +++ b/tutils/__init__.py @@ -7,21 +7,24 @@ @Description: tutils模块的初始化文件 """ -from swanlab.env import SwanLabEnv from .check import * from .config import * +from swanlab.env import SwanLabEnv -api = os.getenv("SWANLAB_API_HOST") -web = os.getenv("SWANLAB_WEB_HOST") +API_HOST = os.getenv(SwanLabEnv.API_HOST.value) +WEB_HOST = os.getenv(SwanLabEnv.WEB_HOST.value) +API_KEY = os.getenv(SwanLabEnv.API_KEY.value) def reset_some_env(): - os.environ[SwanLabEnv.SWANLAB_VERSION.value] = "development" os.environ[SwanLabEnv.SWANLOG_FOLDER.value] = SWANLOG_FOLDER os.environ[SwanLabEnv.SWANLAB_FOLDER.value] = SWANLAB_FOLDER - if not TEST_CLOUD_SKIP: - os.environ[SwanLabEnv.SWANLAB_API_HOST.value] = api - os.environ[SwanLabEnv.SWANLAB_WEB_HOST.value] = web + SwanLabEnv.set_default() + SwanLabEnv.check() + if not is_skip_cloud_test: + os.environ[SwanLabEnv.API_HOST.value] = API_HOST + os.environ[SwanLabEnv.WEB_HOST.value] = WEB_HOST + os.environ[SwanLabEnv.API_KEY.value] = API_KEY if not os.path.exists(TEMP_PATH): @@ -35,4 +38,4 @@ def open_dev_mode() -> str: 在上层config部分已经执行了环境变量注入 :return: api-key """ - return TEST_CLOUD_KEY + return API_KEY diff --git a/tutils/check.py b/tutils/check.py index 1f0b8e024..2c9e340ef 100644 --- a/tutils/check.py +++ b/tutils/check.py @@ -26,28 +26,38 @@ packages = [i for i in packages if "swanboard" in i or "swankit" in i] for i in packages: if "swanboard" in i and swanboard_version not in i: - raise Exception(f"swanboard过时,运行 pip install -r requirements.txt 进行更新.", file=sys.stderr) + raise Exception(f"swanboard过时,运行 pip install -r requirements.txt 进行更新.") if "swankit" in i and swankit_version not in i: - raise Exception(f"swankit过时,运行 pip install -r requirements.txt 进行更新.", file=sys.stderr) + raise Exception(f"swankit过时,运行 pip install -r requirements.txt 进行更新.") -# ---------------------------------- 检查是否跳过云测试,如果没跳过,相关环境变量需要指定---------------------------------- +# ---------------------------------- 检查是否跳过云测试 ---------------------------------- load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env")) - +runtime = os.getenv("SWANLAB_RUNTIME") +# pytest测试环境 is_pytest_env = "PYTEST_VERSION" in os.environ -is_skip_test = os.getenv("TEST_CLOUD_SKIP") is not None -is_cloud_dev_env = os.getenv("SWANLAB_API_HOST") is not None and os.getenv("SWANLAB_WEB_HOST") is not None -if not is_cloud_dev_env: - # 测试环境 - if is_pytest_env and not is_skip_test: - print("请设置开发云服务环境变量,或者设置环境变量TEST_CLOUD_SKIP以跳过云测试", file=sys.stderr) - """ - 可以根据不同版本选择需要的命令 - WINDOWS CMD COMMAND: set TEST_CLOUD_SKIP=1 - WINDOWS POWERSHELL COMMAND: $env:TEST_CLOUD_SKIP="1" - MAC & LINUX COMMAND: export TEST_CLOUD_SKIP=1 - """ - sys.exit(2) - # 开发环境 - elif not is_pytest_env: - print("请设置开发云服务环境变量以运行开发测试脚本", file=sys.stderr) +# 是否跳过部分云端测试 +is_skip_cloud_test = runtime == 'test-no-cloud' +# 是否为测试环境 +is_test_runtime = os.getenv("SWANLAB_RUNTIME") in ['test', 'test-no-cloud'] +# 如果为pytest测试环境,环境变量SWANLAB_RUNTIME必须为['test', 'test-no-cloud']之一 +# 如果没有跳过部分云端测试,必须设置SWANLAB_WEB_HOST、SWANLAB_API_HOST、SWANLAB_API_KEY +""" +* 推荐在项目根目录下设置.env文件完成环境变量的设置,具体代码为: + + SWANLAB_RUNTIME=test-no-cloud + +* 如果使用终端,也可以根据不同操作系统版本选择需要的命令 + + WINDOWS CMD COMMAND: set SWANLAB_RUNTIME="test-no-cloud" + WINDOWS POWERSHELL COMMAND: $env:SWANLAB_RUNTIME="test-no-cloud" + MAC & LINUX COMMAND: export SWANLAB_RUNTIME="test-no-cloud" +""" +if is_pytest_env: + if not is_test_runtime: + print("请设置SWANLAB_RUNTIME环境变量为 test 或 test-no-cloud 以运行云测试", file=sys.stderr) sys.exit(2) + if not is_skip_cloud_test: + envs = ["SWANLAB_WEB_HOST", "SWANLAB_API_HOST", "SWANLAB_API_KEY"] + if not all([os.getenv(i) for i in envs]): + print("请设置云测试相关环境变量以运行云测试", file=sys.stderr) + sys.exit(2) diff --git a/tutils/config.py b/tutils/config.py index 9be3c35b1..cdb43c8e4 100644 --- a/tutils/config.py +++ b/tutils/config.py @@ -28,32 +28,17 @@ SWANLOG_FOLDER = os.path.join(TEMP_PATH, "swanlog") """ -默认情况下,swanlog保存的文件夹 +默认情况下,swanlog保存的文件夹,对应SWANLAB_SAVE_DIR环境变量 """ SWANLAB_FOLDER = os.path.join(TEMP_PATH, ".swanlab") """ -默认情况下,系统信息保存的文件夹 +默认情况下,系统信息保存的文件夹,对应SWANLAB_LOG_DIR环境变量 """ - -# ---------------------------------- 测试用变量 ---------------------------------- - -TEST_CLOUD_SKIP = os.getenv("TEST_CLOUD_SKIP") is not None -""" -是否跳过云测试 -""" - -TEST_CLOUD_KEY = os.getenv("TEST_CLOUD_KEY") -""" -云测试的key -""" - # ---------------------------------- 导出 ---------------------------------- __all__ = [ "TEMP_PATH", "nanoid", - "TEST_CLOUD_SKIP", "SWANLOG_FOLDER", "SWANLAB_FOLDER", - "TEST_CLOUD_KEY", ]