Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: CLI 引数を CLIArgs クラスで型付け #1401

Merged
merged 7 commits into from
Jun 20, 2024
103 changes: 58 additions & 45 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import os
import sys
import warnings
from dataclasses import dataclass
from io import TextIOWrapper
from pathlib import Path
from typing import TextIO, TypeVar

import uvicorn
from pydantic import TypeAdapter

from voicevox_engine.app.application import generate_app
from voicevox_engine.cancellable_engine import CancellableEngine
Expand Down Expand Up @@ -109,15 +111,31 @@ def select_first_not_none_or_none(candidates: list[S | None]) -> S | None:
return None


def main() -> None:
"""VOICEVOX ENGINE を実行する"""
@dataclass(frozen=True)
class CLIArgs:
host: str
port: int
use_gpu: bool
voicevox_dir: Path | None
voicelib_dir: list[Path] | None
runtime_dir: list[Path] | None
tarepan marked this conversation as resolved.
Show resolved Hide resolved
enable_mock: bool
enable_cancellable_synthesis: bool
init_processes: int
load_all_models: bool
cpu_num_threads: int | None
output_log_utf8: bool
cors_policy_mode: CorsPolicyMode | None
allow_origin: list[str] | None
setting_file: Path
preset_file: Path | None
disable_mutable_api: bool

multiprocessing.freeze_support()

output_log_utf8 = decide_boolean_from_env("VV_OUTPUT_LOG_UTF8")
if output_log_utf8:
set_output_log_utf8()
_cli_args_adapter = TypeAdapter(CLIArgs)


def read_cli_arguments() -> CLIArgs:
parser = argparse.ArgumentParser(description="VOICEVOX のエンジンです。")
# Uvicorn でバインドするアドレスを "localhost" にすることで IPv4 (127.0.0.1) と IPv6 ([::1]) の両方でリッスンできます.
# これは Uvicorn のドキュメントに記載されていない挙動です; 将来のアップデートにより動作しなくなる可能性があります.
Expand Down Expand Up @@ -247,52 +265,47 @@ def main() -> None:
),
)

args = parser.parse_args()
args = _cli_args_adapter.validate_python(vars(parser.parse_args()))

# NOTE: 型検査のため Any 値に対して明示的に型を付ける
arg_cors_policy_mode: CorsPolicyMode | None = args.cors_policy_mode
arg_allow_origin: list[str] | None = args.allow_origin
arg_preset_path: Path | None = args.preset_file
arg_disable_mutable_api: bool = args.disable_mutable_api
return args

if args.output_log_utf8:

def main() -> None:
"""VOICEVOX ENGINE を実行する"""

multiprocessing.freeze_support()

output_log_utf8 = decide_boolean_from_env("VV_OUTPUT_LOG_UTF8")
if output_log_utf8:
set_output_log_utf8()

# Synthesis Engine
use_gpu: bool = args.use_gpu
voicevox_dir: Path | None = args.voicevox_dir
voicelib_dirs: list[Path] | None = args.voicelib_dir
runtime_dirs: list[Path] | None = args.runtime_dir
enable_mock: bool = args.enable_mock
cpu_num_threads: int | None = args.cpu_num_threads
load_all_models: bool = args.load_all_models
args = read_cli_arguments()

if args.output_log_utf8:
set_output_log_utf8()

core_manager = initialize_cores(
use_gpu=use_gpu,
voicelib_dirs=voicelib_dirs,
voicevox_dir=voicevox_dir,
runtime_dirs=runtime_dirs,
cpu_num_threads=cpu_num_threads,
enable_mock=enable_mock,
load_all_models=load_all_models,
use_gpu=args.use_gpu,
voicelib_dirs=args.voicelib_dir,
voicevox_dir=args.voicevox_dir,
runtime_dirs=args.runtime_dir,
cpu_num_threads=args.cpu_num_threads,
enable_mock=args.enable_mock,
load_all_models=args.load_all_models,
)
tts_engines = make_tts_engines_from_cores(core_manager)
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"

# Cancellable Engine
enable_cancellable_synthesis: bool = args.enable_cancellable_synthesis
init_processes: int = args.init_processes

cancellable_engine: CancellableEngine | None = None
if enable_cancellable_synthesis:
if args.enable_cancellable_synthesis:
cancellable_engine = CancellableEngine(
init_processes=init_processes,
use_gpu=use_gpu,
voicelib_dirs=voicelib_dirs,
voicevox_dir=voicevox_dir,
runtime_dirs=runtime_dirs,
cpu_num_threads=cpu_num_threads,
enable_mock=enable_mock,
init_processes=args.init_processes,
use_gpu=args.use_gpu,
voicelib_dirs=args.voicelib_dir,
voicevox_dir=args.voicevox_dir,
runtime_dirs=args.runtime_dir,
cpu_num_threads=args.cpu_num_threads,
enable_mock=args.enable_mock,
)

setting_loader = SettingHandler(args.setting_file)
Expand All @@ -301,14 +314,14 @@ def main() -> None:
# 複数方式で指定可能な場合、優先度は上から「引数」「環境変数」「設定ファイル」「デフォルト値」

cors_policy_mode = select_first_not_none(
[arg_cors_policy_mode, settings.cors_policy_mode]
[args.cors_policy_mode, settings.cors_policy_mode]
)

setting_allow_origin = None
if settings.allow_origin is not None:
setting_allow_origin = settings.allow_origin.split(" ")
allow_origin = select_first_not_none_or_none(
[arg_allow_origin, setting_allow_origin]
[args.allow_origin, setting_allow_origin]
)

env_preset_path_str = os.getenv("VV_PRESET_FILE")
Expand All @@ -318,7 +331,7 @@ def main() -> None:
env_preset_path = None
root_preset_path = engine_root() / "presets.yaml"
preset_path = select_first_not_none(
[arg_preset_path, env_preset_path, root_preset_path]
[args.preset_file, env_preset_path, root_preset_path]
)
# ファイルの存在に関わらず指定されたパスをプリセットファイルとして使用する
preset_manager = PresetManager(preset_path)
Expand All @@ -335,12 +348,12 @@ def main() -> None:
engine_manifest.uuid,
)

if arg_disable_mutable_api:
if args.disable_mutable_api:
disable_mutable_api = True
else:
disable_mutable_api = decide_boolean_from_env("VV_DISABLE_MUTABLE_API")

root_dir = select_first_not_none([voicevox_dir, engine_root()])
root_dir = select_first_not_none([args.voicevox_dir, engine_root()])
speaker_info_dir = root_dir / "resources" / "character_info"
# NOTE: ENGINE v0.19 以前向けに後方互換性を確保する
if not speaker_info_dir.exists():
Expand Down