Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

対応デバイスが分かるAPIエンドポイントを追加 #299

Merged
merged 4 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ParseKanaError,
Speaker,
SpeakerInfo,
SupportedDevicesInfo,
)
from voicevox_engine.morphing import synthesis_morphing
from voicevox_engine.morphing import (
Expand Down Expand Up @@ -535,6 +536,18 @@ def speaker_info(speaker_uuid: str, core_version: Optional[str] = None):
ret_data = {"policy": policy, "portrait": portrait, "style_infos": style_infos}
return ret_data

# GET /supported_devices: report which devices the selected core supports.
#
# The engine chosen by ``core_version`` exposes ``supported_devices`` as an
# already-serialized JSON string, or ``None`` when that information is not
# available. The string is sent back verbatim as ``application/json``; a
# ``None`` value is answered with HTTP 422.
#
# NOTE(review): this is documented with ``#`` comments instead of a docstring
# because a handler docstring is presumably published as the endpoint
# description in the generated API docs -- confirm before converting.
@app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"])
def supported_devices(
    core_version: Optional[str] = None,
):
    devices_json = get_engine(core_version).supported_devices
    if devices_json is not None:
        # Pass the pre-serialized JSON through untouched.
        return Response(content=devices_json, media_type="application/json")
    raise HTTPException(status_code=422, detail="非対応の機能です。")

return app


Expand Down
2 changes: 1 addition & 1 deletion test/test_mock_synthesis_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def setUp(self):
pause_mora=None,
),
]
self.engine = MockSynthesisEngine(speakers="")
self.engine = MockSynthesisEngine(speakers="", supported_devices="")
PickledChair marked this conversation as resolved.
Show resolved Hide resolved

def test_replace_phoneme_length(self):
self.assertEqual(
Expand Down
2 changes: 2 additions & 0 deletions voicevox_engine/dev/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
decode_forward,
initialize,
metas,
supported_devices,
yukarin_s_forward,
yukarin_sa_forward,
)
Expand All @@ -12,4 +13,5 @@
"yukarin_s_forward",
"yukarin_sa_forward",
"metas",
"supported_devices",
]
9 changes: 9 additions & 0 deletions voicevox_engine/dev/core/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,12 @@ def metas() -> str:
},
]
)


def supported_devices() -> str:
    """Return the mock core's device-support table as a JSON string.

    The mock reports CPU support only: ``cpu`` is true, ``cuda`` is false.
    """
    device_flags = {"cpu": True, "cuda": False}
    return json.dumps(device_flags)
13 changes: 11 additions & 2 deletions voicevox_engine/dev/synthesis_engine/mock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import numpy as np
from pyopenjtalk import tts
Expand All @@ -21,9 +21,18 @@ def __init__(self, **kwargs):
"""
super().__init__()

self.speakers = kwargs["speakers"]
self._speakers = kwargs["speakers"]
self._supported_devices = kwargs["supported_devices"]
self.default_sampling_rate = 24000

@property
def speakers(self) -> str:
    """Speaker data exactly as passed to the constructor via ``speakers``."""
    return self._speakers

@property
def supported_devices(self) -> Optional[str]:
    """Device-support data exactly as passed to the constructor.

    This is whatever was supplied under the ``supported_devices`` keyword;
    the annotation allows ``None``.
    """
    return self._supported_devices

def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
) -> List[AccentPhrase]:
Expand Down
9 changes: 9 additions & 0 deletions voicevox_engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,12 @@ class SpeakerInfo(BaseModel):
policy: str = Field(title="policy.md")
portrait: str = Field(title="portrait.pngをbase64エンコードしたもの")
style_infos: List[StyleInfo] = Field(title="スタイルの追加情報")


class SupportedDevicesInfo(BaseModel):
    """
    対応しているデバイスの情報
    """

    # English summary: information about which devices are supported.
    # NOTE(review): the docstring and the ``title`` strings are kept in their
    # original wording because they are presumably emitted in the generated
    # schema at runtime -- confirm before translating them.

    # Whether the CPU is supported.
    cpu: bool = Field(
        title="CPUに対応しているか",
    )
    # Whether CUDA (GPU) is supported.
    cuda: bool = Field(
        title="CUDA(GPU)に対応しているか",
    )
10 changes: 9 additions & 1 deletion voicevox_engine/synthesis_engine/make_synthesis_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,16 @@ def make_synthesis_engines(
file=sys.stderr,
)
continue
try:
supported_devices = core.supported_devices()
except NameError:
supported_devices = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch版にはこの関数がないのでこのエラーが出るので、何がサポートされているかわからないという意味のNoneを代入みたいな、意図のコメントがあったほうが良いかなと思いました。

synthesis_engines[core_version] = SynthesisEngine(
yukarin_s_forwarder=core.yukarin_s_forward,
yukarin_sa_forwarder=core.yukarin_sa_forward,
decode_forwarder=core.decode_forward,
speakers=core.metas(),
supported_devices=supported_devices,
)
except Exception:
if not enable_mock:
Expand All @@ -96,9 +101,12 @@ def make_synthesis_engines(
file=sys.stderr,
)
from ..dev.core import metas as mock_metas
from ..dev.core import supported_devices as mock_supported_devices
from ..dev.synthesis_engine import MockSynthesisEngine

if "0.0.0" not in synthesis_engines:
synthesis_engines["0.0.0"] = MockSynthesisEngine(speakers=mock_metas())
synthesis_engines["0.0.0"] = MockSynthesisEngine(
speakers=mock_metas(), supported_devices=mock_supported_devices()
)

return synthesis_engines
14 changes: 13 additions & 1 deletion voicevox_engine/synthesis_engine/synthesis_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
yukarin_sa_forwarder,
decode_forwarder,
speakers: str,
supported_devices: Optional[str] = None,
):
"""
yukarin_s_forwarder: 音素列から、音素ごとの長さを求める関数
Expand Down Expand Up @@ -160,15 +161,26 @@ def __init__(
return: 音声波形

speakers: coreから取得したspeakersに関するjsonデータの文字列

supported_devices: coreから取得した対応デバイスに関するjsonデータの文字列
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

おそらくNoneのときに意味を持っている設計だと思うので、その説明があると良いのかなと思いました。
(Noneのときはすべてのデバイスでサポートしているか不明です、みたいな)

"""
super().__init__()
self.yukarin_s_forwarder = yukarin_s_forwarder
self.yukarin_sa_forwarder = yukarin_sa_forwarder
self.decode_forwarder = decode_forwarder

self.speakers = speakers
self._speakers = speakers
self._supported_devices = supported_devices
self.default_sampling_rate = 24000

@property
def speakers(self) -> str:
    """Speaker JSON string that was handed to the constructor as ``speakers``."""
    return self._speakers

@property
def supported_devices(self) -> Optional[str]:
    """Supported-device JSON string handed to the constructor, or ``None``.

    ``None`` is the constructor's default when no device information was
    supplied.
    """
    return self._supported_devices

def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
) -> List[AccentPhrase]:
Expand Down
12 changes: 11 additions & 1 deletion voicevox_engine/synthesis_engine/synthesis_engine_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from abc import ABCMeta, abstractmethod
from typing import List
from typing import List, Optional

from .. import full_context_label
from ..full_context_label import extract_full_context_label
Expand Down Expand Up @@ -78,6 +78,16 @@ def full_context_label_moras_to_moras(


class SynthesisEngineBase(metaclass=ABCMeta):
@property
@abstractmethod
def speakers(self) -> str:
    """Abstract read-only property; concrete engines must override it.

    The base implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError()

@property
@abstractmethod
def supported_devices(self) -> Optional[str]:
    """Abstract read-only property; concrete engines must override it.

    The annotation allows implementations to return ``None``. The base
    implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError()

@abstractmethod
def replace_phoneme_length(
self, accent_phrases: List[AccentPhrase], speaker_id: int
Expand Down