diff --git a/swanlab/data/callback_cloud.py b/swanlab/data/callback_cloud.py index 81f3fec7..2235d947 100644 --- a/swanlab/data/callback_cloud.py +++ b/swanlab/data/callback_cloud.py @@ -7,7 +7,6 @@ @Description: 云端回调 """ -import io import json import os import sys @@ -23,7 +22,7 @@ from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel from swanlab.data.cloud import ThreadPool from swanlab.data.cloud import UploadType -from swanlab.env import in_jupyter, SwanLabEnv +from swanlab.env import in_jupyter, SwanLabEnv, is_interactive from swanlab.error import KeyFileError from swanlab.log import swanlog from swanlab.package import ( @@ -152,15 +151,11 @@ def create_login_info(cls): try: key = get_key() except KeyFileError: - try: - fd = sys.stdin.fileno() - # 不是标准终端,且非jupyter环境,无法控制其回显 - if not os.isatty(fd) and not in_jupyter(): - raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ") - # 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation - # 这种情况下为用户自定义情况 - except io.UnsupportedOperation: - pass + pass + if key is None and not is_interactive(): + raise KeyFileError( + "api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`." + ) return terminal_login(key) @staticmethod @@ -208,12 +203,11 @@ def __str__(self): def on_init(self, project: str, workspace: str, logdir: str = None, **kwargs) -> int: super(CloudRunCallback, self).on_init(project, workspace, logdir) - # 检测是否有最新的版本 - self._get_package_latest_version() if self.login_info is None: swanlog.debug("Login info is None, get login info.") self.login_info = self.create_login_info() - + # 检测是否有最新的版本 + self._get_package_latest_version() http = create_http(self.login_info) return http.mount_project(project, workspace, self.public).history_exp_count diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index cda40253..895c991b 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -15,7 +15,7 @@ from swankit.log import FONT from swanlab.api import code_login, terminal_login -from swanlab.env import SwanLabEnv +from swanlab.env import SwanLabEnv, is_interactive from swanlab.log import swanlog from .callback_cloud import CloudRunCallback from .callback_local import LocalRunCallback @@ -29,7 +29,7 @@ ) from .run.helper import SwanLabRunOperator from ..error import KeyFileError -from ..package import get_key, get_host_web, get_user_setting_path +from ..package import get_key, get_host_web def _check_proj_name(name: str) -> str: @@ -265,11 +265,16 @@ def _init_mode(mode: str = None): login_info = None if mode == "cloud" and no_api_key: # 判断当前进程是否在交互模式下 - if os.isatty(0) and (os.isatty(1) or os.isatty(2)): + if is_interactive(): + swanlog.info( + "Using SwanLab to track your experiments. Please refer to https://docs.swanlab.cn for more information." + ) swanlog.info("(1) Create a SwanLab account.") swanlog.info("(2) Use an existing SwanLab account.") swanlog.info("(3) Don't visualize my results.") - tip = FONT.swanlab("Enter your choice:") + + # 交互选择 + tip = FONT.swanlab("Enter your choice: ") code = input(tip) while code not in ["1", "2", "3"]: swanlog.warning("Invalid choice, please enter again.") @@ -278,13 +283,11 @@ def _init_mode(mode: str = None): mode = "local" elif code == "2": swanlog.info("You chose 'Create a swanlab account'") - swanlog.info("Create a SwanLab account here: " + FONT.yellow(get_host_web() + "/login")) - swanlog.info("You can find your API key in your browser here: " + FONT.yellow(get_user_setting_path())) + swanlog.info("Create a SwanLab account here: " + get_host_web() + "/login") login_info = terminal_login() elif code == "1": swanlog.info("You chose 'Use an existing swanlab account'") swanlog.info("Logging into " + get_host_web()) - swanlog.info("You can find your API key in your browser here: " + FONT.yellow(get_user_setting_path())) login_info = terminal_login() else: raise ValueError("Invalid choice") diff --git a/swanlab/env.py b/swanlab/env.py index 0472c293..a51c8f80 100644 --- a/swanlab/env.py +++ b/swanlab/env.py @@ -8,7 +8,9 @@ 除了utils和error模块,其他模块都可以使用这个模块 """ import enum +import io import os +import sys from typing import List import swankit.env as E @@ -127,3 +129,17 @@ def in_jupyter() -> bool: return True except NameError: return False + + +def is_interactive(): + """ + 是否为可交互式环境(输入连接tty设备) + 特殊的环境:jupyter notebook + """ + try: + fd = sys.stdin.fileno() + return os.isatty(fd) or in_jupyter() + # 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation + # 多为测试情况,可交互 + except io.UnsupportedOperation: + return True diff --git a/test/unit/data/test_sdk.py b/test/unit/data/test_sdk.py index 6f78a110..0fdfef05 100644 --- a/test/unit/data/test_sdk.py +++ b/test/unit/data/test_sdk.py @@ -72,8 +72,7 @@ def test_init_cloud_with_no_api_key(self, monkeypatch): """ api_key = os.environ[SwanLabEnv.API_KEY.value] del os.environ[SwanLabEnv.API_KEY.value] - # 在测试时默认不在交互模式下,因此不会做任何输入选择交互 - S._init_mode("cloud") + # 在测试时默认会在交互模式下 # 接下来需要模拟终端连接,使用monkeypatch # 模拟 os.isatty(0) 返回 True monkeypatch.setattr(os, "isatty", lambda fd: True) diff --git a/test/unit/test_env.py b/test/unit/test_env.py index 5827631d..3db74234 100644 --- a/test/unit/test_env.py +++ b/test/unit/test_env.py @@ -7,11 +7,12 @@ @Description: 测试swanlab.env模块 """ +import os + import pytest -from swanlab.env import SwanLabEnv import swanlab -import os +from swanlab.env import SwanLabEnv, is_interactive def test_default(): @@ -37,3 +38,8 @@ def test_check(): os.environ[SwanLabEnv.RUNTIME.value] = "124" with pytest.raises(ValueError): SwanLabEnv.check() + + +def test_is_interactive(): + # 测试时默认返回true + assert is_interactive() == True