Skip to content

Commit

Permalink
chore: opt code
Browse files Browse the repository at this point in the history
chore: opt code

chore: add test
  • Loading branch information
SAKURA-CAT committed Jan 16, 2025
1 parent 8421d95 commit 7a16dfd
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 25 deletions.
22 changes: 8 additions & 14 deletions swanlab/data/callback_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
@Description:
云端回调
"""
import io
import json
import os
import sys
Expand All @@ -23,7 +22,7 @@
from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel
from swanlab.data.cloud import ThreadPool
from swanlab.data.cloud import UploadType
from swanlab.env import in_jupyter, SwanLabEnv
from swanlab.env import in_jupyter, SwanLabEnv, is_interactive
from swanlab.error import KeyFileError
from swanlab.log import swanlog
from swanlab.package import (
Expand Down Expand Up @@ -152,15 +151,11 @@ def create_login_info(cls):
try:
key = get_key()
except KeyFileError:
try:
fd = sys.stdin.fileno()
# 不是标准终端,且非jupyter环境,无法控制其回显
if not os.isatty(fd) and not in_jupyter():
raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ")
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
# 这种情况下为用户自定义情况
except io.UnsupportedOperation:
pass
pass
if key is None and not is_interactive():
raise KeyFileError(
"api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`."
)
return terminal_login(key)

@staticmethod
Expand Down Expand Up @@ -208,12 +203,11 @@ def __str__(self):

def on_init(self, project: str, workspace: str, logdir: str = None, **kwargs) -> int:
super(CloudRunCallback, self).on_init(project, workspace, logdir)
# 检测是否有最新的版本
self._get_package_latest_version()
if self.login_info is None:
swanlog.debug("Login info is None, get login info.")
self.login_info = self.create_login_info()

# 检测是否有最新的版本
self._get_package_latest_version()
http = create_http(self.login_info)
return http.mount_project(project, workspace, self.public).history_exp_count

Expand Down
17 changes: 10 additions & 7 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from swankit.log import FONT

from swanlab.api import code_login, terminal_login
from swanlab.env import SwanLabEnv
from swanlab.env import SwanLabEnv, is_interactive
from swanlab.log import swanlog
from .callback_cloud import CloudRunCallback
from .callback_local import LocalRunCallback
Expand All @@ -29,7 +29,7 @@
)
from .run.helper import SwanLabRunOperator
from ..error import KeyFileError
from ..package import get_key, get_host_web, get_user_setting_path
from ..package import get_key, get_host_web


def _check_proj_name(name: str) -> str:
Expand Down Expand Up @@ -265,11 +265,16 @@ def _init_mode(mode: str = None):
login_info = None
if mode == "cloud" and no_api_key:
# 判断当前进程是否在交互模式下
if os.isatty(0) and (os.isatty(1) or os.isatty(2)):
if is_interactive():
swanlog.info(
"Using SwanLab to track your experiments. Please refer to https://docs.swanlab.cn for more information."
)
swanlog.info("(1) Create a SwanLab account.")
swanlog.info("(2) Use an existing SwanLab account.")
swanlog.info("(3) Don't visualize my results.")
tip = FONT.swanlab("Enter your choice:")

# 交互选择
tip = FONT.swanlab("Enter your choice: ")
code = input(tip)
while code not in ["1", "2", "3"]:
swanlog.warning("Invalid choice, please enter again.")
Expand All @@ -278,13 +283,11 @@ def _init_mode(mode: str = None):
mode = "local"
elif code == "2":
swanlog.info("You chose 'Create a swanlab account'")
swanlog.info("Create a SwanLab account here: " + FONT.yellow(get_host_web() + "/login"))
swanlog.info("You can find your API key in your browser here: " + FONT.yellow(get_user_setting_path()))
swanlog.info("Create a SwanLab account here: " + get_host_web() + "/login")
login_info = terminal_login()
elif code == "1":
swanlog.info("You chose 'Use an existing swanlab account'")
swanlog.info("Logging into " + get_host_web())
swanlog.info("You can find your API key in your browser here: " + FONT.yellow(get_user_setting_path()))
login_info = terminal_login()
else:
raise ValueError("Invalid choice")
Expand Down
16 changes: 16 additions & 0 deletions swanlab/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
除了utils和error模块,其他模块都可以使用这个模块
"""
import enum
import io
import os
import sys
from typing import List

import swankit.env as E
Expand Down Expand Up @@ -127,3 +129,17 @@ def in_jupyter() -> bool:
return True
except NameError:
return False


def is_interactive():
"""
是否为可交互式环境(输入连接tty设备)
特殊的环境:jupyter notebook
"""
try:
fd = sys.stdin.fileno()
return os.isatty(fd) or in_jupyter()
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
# 多为测试情况,可交互
except io.UnsupportedOperation:
return True
3 changes: 1 addition & 2 deletions test/unit/data/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def test_init_cloud_with_no_api_key(self, monkeypatch):
"""
api_key = os.environ[SwanLabEnv.API_KEY.value]
del os.environ[SwanLabEnv.API_KEY.value]
# 在测试时默认不在交互模式下,因此不会做任何输入选择交互
S._init_mode("cloud")
# 在测试时默认会在交互模式下
# 接下来需要模拟终端连接,使用monkeypatch
# 模拟 os.isatty(0) 返回 True
monkeypatch.setattr(os, "isatty", lambda fd: True)
Expand Down
10 changes: 8 additions & 2 deletions test/unit/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
@Description:
测试swanlab.env模块
"""
import os

import pytest

from swanlab.env import SwanLabEnv
import swanlab
import os
from swanlab.env import SwanLabEnv, is_interactive


def test_default():
Expand All @@ -37,3 +38,8 @@ def test_check():
os.environ[SwanLabEnv.RUNTIME.value] = "124"
with pytest.raises(ValueError):
SwanLabEnv.check()


def test_is_interactive():
# 测试时默认返回true
assert is_interactive() == True

0 comments on commit 7a16dfd

Please sign in to comment.