Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ユーザー辞書機能:単語の更新/削除機能を追加 #338

Merged
merged 11 commits into from
Feb 25, 2022
57 changes: 56 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
from voicevox_engine.synthesis_engine import SynthesisEngineBase, make_synthesis_engines
from voicevox_engine.user_dict import (
apply_word,
delete_word,
read_dict,
rewrite_word,
user_dict_startup_processing,
)
from voicevox_engine.utility import (
Expand Down Expand Up @@ -559,7 +561,7 @@ def get_user_dict_words():
traceback.print_exc()
raise HTTPException(status_code=422, detail="辞書の読み込みに失敗しました。")

@app.post("/user_dict", status_code=204, tags=["ユーザー辞書"])
@app.post("/user_dict_word", status_code=204, tags=["ユーザー辞書"])
def add_user_dict_word(surface: str, pronunciation: str, accent_type: int):
"""
ユーザ辞書に言葉を追加します。
Expand All @@ -584,6 +586,59 @@ def add_user_dict_word(surface: str, pronunciation: str, accent_type: int):
traceback.print_exc()
raise HTTPException(status_code=422, detail="ユーザ辞書への追加に失敗しました。")

@app.put("/user_dict_word/{id}", status_code=204, tags=["ユーザー辞書"])
def rewrite_user_dict_word(
    surface: str, pronunciation: str, accent_type: int, id: int
):
    """
    ユーザ辞書に登録されている言葉を更新します。

    Parameters
    ----------
    surface : str
        言葉の表層形
    pronunciation: str
        言葉の発音(カタカナ)
    accent_type: int
        アクセント型(音が下がる場所を指す)
    id: int
        更新する言葉のID
    """
    # NOTE(review): the docstring above is left untouched because FastAPI is
    # expected to publish it as the operation description -- confirm before
    # translating it.
    #
    # English summary: replace the user dictionary word stored under `id`
    # with a new surface form, katakana pronunciation and accent type
    # (the position where the pitch drops). Answers 204 on success.
    new_fields = {
        "surface": surface,
        "pronunciation": pronunciation,
        "accent_type": accent_type,
    }
    try:
        rewrite_word(id=id, **new_fields)
        return Response(status_code=204)
    except HTTPException:
        # Already an HTTP error (rewrite_word raises one for an unknown ID);
        # pass it through unchanged.
        raise
    except ValidationError as e:
        # Report the validation details back to the caller.
        raise HTTPException(status_code=422, detail=f"パラメータに誤りがあります。\n{e}")
    except Exception:
        # Anything else: print the traceback and answer with a generic 422.
        traceback.print_exc()
        raise HTTPException(status_code=422, detail="ユーザ辞書の更新に失敗しました。")

@app.delete("/user_dict_word/{id}", status_code=204, tags=["ユーザー辞書"])
def delete_user_dict_word(id: int):
    """
    ユーザ辞書に登録されている言葉を削除します。

    Parameters
    ----------
    id: int
        削除する言葉のID
    """
    # NOTE(review): the docstring above is left untouched because FastAPI is
    # expected to publish it as the operation description -- confirm before
    # translating it.
    #
    # English summary: delete the user dictionary word stored under `id`.
    # Answers 204 on success.
    try:
        delete_word(id=id)
    except HTTPException:
        # Already an HTTP error (delete_word raises one for an unknown ID);
        # pass it through unchanged.
        raise
    except Exception:
        # Anything else: print the traceback and answer with a generic 422.
        traceback.print_exc()
        raise HTTPException(status_code=422, detail="ユーザ辞書の更新に失敗しました。")
    return Response(status_code=204)

@app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"])
def supported_devices(
core_version: Optional[str] = None,
Expand Down
178 changes: 178 additions & 0 deletions test/test_user_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import TestCase

from fastapi import HTTPException
from pyopenjtalk import unset_user_dict

from voicevox_engine.model import UserDictJson, UserDictWord
from voicevox_engine.user_dict import (
apply_word,
create_word,
delete_word,
read_dict,
rewrite_word,
)

valid_dict_dict = {
"next_id": 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

将来辞書エディターみたいなのが出たりしたとき、next_idの数値と、word idがずれてしまう未来が見えました。
uuidにしても良いかも?

"words": {
"0": {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDはintなんでしたっけ、stringなんでしたっけ

"surface": "test",
"cost": 8600,
"part_of_speech": "名詞",
"part_of_speech_detail_1": "固有名詞",
"part_of_speech_detail_2": "一般",
"part_of_speech_detail_3": "*",
"inflectional_type": "*",
"inflectional_form": "*",
"stem": "*",
"yomi": "テスト",
"pronunciation": "テスト",
"accent_type": 1,
"accent_associative_rule": "*",
}
},
}


class TestUserDict(TestCase):
    """Tests for the user dictionary helpers.

    Every test works on files inside a fresh temporary directory, named after
    the test itself, instead of the default dictionary paths.
    """

    def setUp(self):
        self.tmp_dir = TemporaryDirectory()
        self.tmp_dir_path = Path(self.tmp_dir.name)

    def tearDown(self):
        # NOTE(review): presumably pyopenjtalk must drop the user dictionary
        # before its files are deleted -- confirm.
        unset_user_dict()
        self.tmp_dir.cleanup()

    def _dict_paths(self, name):
        """Return the (JSON, compiled) dictionary paths for the test *name*."""
        return (
            self.tmp_dir_path / f"{name}.json",
            self.tmp_dir_path / f"{name}.dic",
        )

    def _write_valid_dict(self, name):
        """Write ``valid_dict_dict`` as the JSON dictionary of test *name*.

        Returns the same (JSON, compiled) path pair as ``_dict_paths``.
        """
        user_dict_path, compiled_dict_path = self._dict_paths(name)
        user_dict_path.write_text(
            json.dumps(valid_dict_dict, ensure_ascii=False), encoding="utf-8"
        )
        return user_dict_path, compiled_dict_path

    def test_read_not_exist_json(self):
        self.assertEqual(
            read_dict(user_dict_path=(self.tmp_dir_path / "not_exist.json")),
            UserDictJson(**{"next_id": 0, "words": {}}),
        )

    def test_create_word(self):
        # Add more cases once further fields (e.g. part of speech) can be
        # specified by the caller.
        self.assertEqual(
            create_word(surface="test", pronunciation="テスト", accent_type=1),
            UserDictWord(
                surface="test",
                cost=8600,
                part_of_speech="名詞",
                part_of_speech_detail_1="固有名詞",
                part_of_speech_detail_2="一般",
                part_of_speech_detail_3="*",
                inflectional_type="*",
                inflectional_form="*",
                stem="*",
                yomi="テスト",
                pronunciation="テスト",
                accent_type=1,
                accent_associative_rule="*",
            ),
        )

    def test_apply_word_without_json(self):
        user_dict_path, compiled_dict_path = self._dict_paths(
            "test_apply_word_without_json"
        )
        apply_word(
            surface="test",
            pronunciation="テスト",
            accent_type=1,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res.words), 1)
        self.assertEqual(res.next_id, 1)
        self.assertEqual(
            (
                res.words[0].surface,
                res.words[0].pronunciation,
                res.words[0].accent_type,
            ),
            ("test", "テスト", 1),
        )

    def test_apply_word_with_json(self):
        user_dict_path, compiled_dict_path = self._write_valid_dict(
            "test_apply_word_with_json"
        )
        apply_word(
            surface="test2",
            pronunciation="テストツー",
            accent_type=3,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res.words), 2)
        self.assertEqual(res.next_id, 2)
        self.assertEqual(
            (
                res.words[1].surface,
                res.words[1].pronunciation,
                res.words[1].accent_type,
            ),
            ("test2", "テストツー", 3),
        )

    def test_rewrite_word_invalid_id(self):
        user_dict_path, compiled_dict_path = self._write_valid_dict(
            "test_rewrite_word_invalid_id"
        )
        before = user_dict_path.read_text(encoding="utf-8")
        with self.assertRaises(HTTPException) as ctx:
            rewrite_word(
                id=1,
                surface="test2",
                pronunciation="テストツー",
                accent_type=2,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )
        self.assertEqual(ctx.exception.status_code, 422)
        # A rejected request must leave the stored dictionary untouched.
        self.assertEqual(user_dict_path.read_text(encoding="utf-8"), before)

    def test_rewrite_word_valid_id(self):
        user_dict_path, compiled_dict_path = self._write_valid_dict(
            "test_rewrite_word_valid_id"
        )
        rewrite_word(
            id=0,
            surface="test2",
            pronunciation="テストツー",
            accent_type=2,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        words = read_dict(user_dict_path=user_dict_path).words
        # Rewriting replaces the entry in place; it must not add a new one.
        self.assertEqual(len(words), 1)
        res = words[0]
        self.assertEqual(
            (res.surface, res.pronunciation, res.accent_type),
            ("test2", "テストツー", 2),
        )

    def test_delete_word_invalid_id(self):
        user_dict_path, compiled_dict_path = self._write_valid_dict(
            "test_delete_word_invalid_id"
        )
        before = user_dict_path.read_text(encoding="utf-8")
        with self.assertRaises(HTTPException) as ctx:
            delete_word(
                id=1,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )
        self.assertEqual(ctx.exception.status_code, 422)
        # A rejected request must leave the stored dictionary untouched.
        self.assertEqual(user_dict_path.read_text(encoding="utf-8"), before)

    def test_delete_word_valid_id(self):
        user_dict_path, compiled_dict_path = self._write_valid_dict(
            "test_delete_word_valid_id"
        )
        delete_word(
            id=0,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(len(read_dict(user_dict_path=user_dict_path).words), 0)
70 changes: 54 additions & 16 deletions voicevox_engine/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pyopenjtalk
from appdirs import user_data_dir
from fastapi import HTTPException

from .model import UserDictJson, UserDictWord
from .utility import engine_root
Expand Down Expand Up @@ -35,7 +36,6 @@ def user_dict_startup_processing(

def update_dict(
default_dict_path: Path = default_dict_path,
user_dict_path: Path = user_dict_path,
compiled_dict_path: Path = compiled_dict_path,
):
with NamedTemporaryFile(encoding="utf-8", mode="w", delete=False) as f:
Expand Down Expand Up @@ -89,13 +89,9 @@ def read_dict(user_dict_path: Path = user_dict_path) -> UserDictJson:
return UserDictJson(**json.load(f))


def apply_word(**kwargs):
if "user_dict_path" in kwargs:
_user_dict_path = kwargs["user_dict_path"]
else:
_user_dict_path = user_dict_path
word = UserDictWord(
surface=kwargs["surface"],
def create_word(surface: str, pronunciation: str, accent_type: int) -> UserDictWord:
return UserDictWord(
surface=surface,
cost=8600,
part_of_speech="名詞",
part_of_speech_detail_1="固有名詞",
Expand All @@ -104,16 +100,58 @@ def apply_word(**kwargs):
inflectional_type="*",
inflectional_form="*",
stem="*",
yomi=kwargs["pronunciation"],
pronunciation=kwargs["pronunciation"],
accent_type=kwargs["accent_type"],
yomi=pronunciation,
pronunciation=pronunciation,
accent_type=accent_type,
accent_associative_rule="*",
)
user_dict = read_dict(user_dict_path=_user_dict_path)


def apply_word(
surface: str,
pronunciation: str,
accent_type: int,
user_dict_path: Path = user_dict_path,
compiled_dict_path: Path = compiled_dict_path,
):
word = create_word(
surface=surface, pronunciation=pronunciation, accent_type=accent_type
)
user_dict = read_dict(user_dict_path=user_dict_path)
id = user_dict.next_id
user_dict.next_id += 1
user_dict.words[id] = word
with _user_dict_path.open(encoding="utf-8", mode="w") as f:
json.dump(user_dict.dict(), f, ensure_ascii=False)
_user_dict_path.write_text(user_dict.json(ensure_ascii=False), encoding="utf-8")
update_dict(user_dict_path=_user_dict_path)
user_dict_path.write_text(user_dict.json(ensure_ascii=False), encoding="utf-8")
update_dict(compiled_dict_path=compiled_dict_path)


def rewrite_word(
    id: int,
    surface: str,
    pronunciation: str,
    accent_type: int,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Replace the user dictionary entry stored under ``id``.

    Parameters
    ----------
    id : int
        Key of the entry to replace.
    surface : str
        Surface form passed to ``create_word``.
    pronunciation : str
        Pronunciation passed to ``create_word``.
    accent_type : int
        Accent type passed to ``create_word``.
    user_dict_path : Path
        JSON user dictionary that is read and then overwritten.
    compiled_dict_path : Path
        Forwarded to ``update_dict`` after the JSON file has been written.

    Raises
    ------
    HTTPException
        Status 422 when ``id`` is not a key of the stored dictionary; nothing
        is written in that case.
    """
    # The replacement entry is built before the dictionary file is read.
    replacement = create_word(
        surface=surface, pronunciation=pronunciation, accent_type=accent_type
    )
    stored = read_dict(user_dict_path=user_dict_path)
    if id in stored.words:
        stored.words[id] = replacement
    else:
        raise HTTPException(status_code=422, detail="IDに該当するワードが見つかりませんでした")
    serialized = stored.json(ensure_ascii=False)
    user_dict_path.write_text(serialized, encoding="utf-8")
    update_dict(compiled_dict_path=compiled_dict_path)


def delete_word(
    id: int,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Remove the user dictionary entry stored under ``id``.

    Parameters
    ----------
    id : int
        Key of the entry to remove.
    user_dict_path : Path
        JSON user dictionary that is read and then overwritten.
    compiled_dict_path : Path
        Forwarded to ``update_dict`` after the JSON file has been written.

    Raises
    ------
    HTTPException
        Status 422 when ``id`` is not a key of the stored dictionary; nothing
        is written in that case.
    """
    stored = read_dict(user_dict_path=user_dict_path)
    try:
        del stored.words[id]
    except KeyError:
        # Unknown ID: report it as an HTTP error rather than a KeyError.
        raise HTTPException(
            status_code=422, detail="IDに該当するワードが見つかりませんでした"
        ) from None
    serialized = stored.json(ensure_ascii=False)
    user_dict_path.write_text(serialized, encoding="utf-8")
    update_dict(compiled_dict_path=compiled_dict_path)