Skip to content

Commit

Permalink
feat!: SupportedDevicesからデシアライズ機能を剥奪
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Jan 30, 2025
1 parent 78a92a9 commit 6a0544d
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 38 deletions.
5 changes: 3 additions & 2 deletions crates/voicevox_core/src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

use derive_more::BitAnd;
use serde::{Deserialize, Serialize};
use serde::Serialize;

pub(crate) fn test_gpus(
gpus: impl IntoIterator<Item = GpuSpec>,
Expand Down Expand Up @@ -65,7 +65,8 @@ fn test_gpu(
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, PartialEq, Eq, Debug, BitAnd, Serialize, Deserialize)]
// 互換性保証のため、`Deserialize`は実装するべきではない
#[derive(Clone, Copy, PartialEq, Eq, Debug, BitAnd, Serialize)]
#[non_exhaustive]
pub struct SupportedDevices {
/// CPUが利用可能。
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// エンジンを起動してyukarin_s・yukarin_sa・decodeの推論を行う

use std::collections::HashMap;
use std::sync::LazyLock;
use std::{cmp::min, ffi::CStr};

use assert_cmd::assert::AssertResult;
use libloading::Library;
use serde::{Deserialize, Serialize};
use voicevox_core::SupportedDevices;

use test_util::{c_api::CApi, EXAMPLE_DATA};

Expand All @@ -33,7 +33,9 @@ impl assert_cdylib::TestCase for TestCase {

{
let supported_devices = lib.supported_devices();
serde_json::from_str::<SupportedDevices>(CStr::from_ptr(supported_devices).to_str()?)?;
serde_json::from_str::<HashMap<String, bool>>(
CStr::from_ptr(supported_devices).to_str()?,
)?;
}

assert!(lib.initialize(false, 0, false));
Expand Down
5 changes: 3 additions & 2 deletions crates/voicevox_core_c_api/tests/e2e/testcases/global_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use libloading::Library;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use test_util::c_api::{self, CApi, VoicevoxLoadOnnxruntimeOptions, VoicevoxResultCode};
use voicevox_core::SupportedDevices;

use crate::{
assert_cdylib::{self, case, Utf8Output},
Expand Down Expand Up @@ -65,7 +64,9 @@ impl assert_cdylib::TestCase for TestCase {
supported_devices.as_mut_ptr(),
));
let supported_devices = supported_devices.assume_init();
serde_json::from_str::<SupportedDevices>(CStr::from_ptr(supported_devices).to_str()?)?;
serde_json::from_str::<HashMap<String, bool>>(
CStr::from_ptr(supported_devices).to_str()?,
)?;
lib.voicevox_json_free(supported_devices);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public static String getVersion() {
*
* <p>あくまでONNX Runtimeが対応しているデバイスの情報であることに注意。GPUが使える環境ではなかったとしても {@link #cuda} や {@link #dml} は
* {@code true} を示しうる。
*
* <p>{@code Gson#fromJson} によりJSONから変換することはできない。その試みは {@link UnsupportedOperationException} となる。
*/
public static class SupportedDevices {
/**
Expand Down Expand Up @@ -71,9 +73,14 @@ public static class SupportedDevices {
public final boolean dml;

private SupportedDevices() {
this.cpu = false;
this.cuda = false;
this.dml = false;
throw new UnsupportedOperationException("You cannot deserialize `SupportedDevices`");
}

/** accessed only via JNI */
private SupportedDevices(boolean cpu, boolean cuda, boolean dml) {
this.cpu = cpu;
this.cuda = cuda;
this.dml = dml;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static jp.hiroshiba.voicevoxcore.GlobalInfo.SupportedDevices;

import com.google.gson.Gson;
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import java.util.Optional;
Expand Down Expand Up @@ -122,16 +121,10 @@ private Onnxruntime(@Nullable String filename) {
* @return {@link SupportedDevices}。
*/
public SupportedDevices supportedDevices() {
Gson gson = new Gson();
String supportedDevicesJson = rsSupportedDevices();
SupportedDevices supportedDevices = gson.fromJson(supportedDevicesJson, SupportedDevices.class);
if (supportedDevices == null) {
throw new NullPointerException("supported_devices");
}
return supportedDevices;
return rsSupportedDevices();
}

private native void rsNew(@Nullable String filename);

private native String rsSupportedDevices();
private native SupportedDevices rsSupportedDevices();
}
18 changes: 14 additions & 4 deletions crates/voicevox_core_java_api/src/onnxruntime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use jni::{
JNIEnv,
};

use crate::common::throw_if_err;
use crate::{common::throw_if_err, object};

// SAFETY: voicevox_core_java_apiを構成するライブラリの中に、これと同名のシンボルは存在しない
#[duplicate_item(
Expand Down Expand Up @@ -54,8 +54,18 @@ unsafe extern "system" fn Java_jp_hiroshiba_voicevoxcore_blocking_Onnxruntime_rs
let this = *env.get_rust_field::<_, _, &'static voicevox_core::blocking::Onnxruntime>(
&this, "handle",
)?;
let json = this.supported_devices()?.to_json().to_string();
let json = env.new_string(json)?;
Ok(json.into_raw())
let devices = this.supported_devices()?;

assert!(match devices.to_json() {
serde_json::Value::Object(o) => o.len() == 3, // `cpu`, `cuda`, `dml`
_ => false,
});

let obj = env.new_object(
object!("GlobalInfo$SupportedDevices"),
"(ZZZ)V",
&[devices.cpu.into(), devices.cuda.into(), devices.dml.into()],
)?;
Ok(obj.into_raw())
})
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Literal, NewType, TypeAlias
from typing import Literal, NewType, NoReturn, TypeAlias
from uuid import UUID

import pydantic
from pydantic_core import ArgsKwargs

from .._rust import _to_zenkaku, _validate_pronunciation
from ._please_do_not_use import _Reserved
Expand Down Expand Up @@ -137,6 +138,9 @@ class SupportedDevices:
あくまでONNX Runtimeが対応しているデバイスの情報であることに注意。GPUが使える環境ではなかったとしても
``cuda`` や ``dml`` は ``True`` を示しうる。
JSONからの変換も含め、VOICEVOX CORE以外が作ることはできない。作ろうとした場合
``TypeError`` となる。
"""

cpu: bool
Expand All @@ -162,6 +166,13 @@ class SupportedDevices:
(``DmlExecutionProvider``)に対応する。必要な環境についてはそちらを参照。
"""

@pydantic.model_validator(mode="before")
@staticmethod
def _deny_unless_from_pyo3(data: ArgsKwargs) -> ArgsKwargs:
if "I AM FROM PYO3" not in data.args:
raise TypeError("You cannot deserialize `SupportedDevices`")
return ArgsKwargs((), kwargs=data.kwargs)


AccelerationMode: TypeAlias = Literal["AUTO", "CPU", "GPU"] | _Reserved
"""
Expand Down
21 changes: 20 additions & 1 deletion crates/voicevox_core_python_api/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use pyo3::{
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use uuid::Uuid;
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, VoiceModelMeta};
use voicevox_core::{AccelerationMode, AccentPhrase, StyleId, SupportedDevices, VoiceModelMeta};

use crate::{
AnalyzeTextError, GetSupportedDevicesError, GpuSupportError, InitInferenceRuntimeError,
Expand Down Expand Up @@ -255,6 +255,25 @@ pub(crate) impl<T> voicevox_core::Result<T> {
}
}

#[ext(SupportedDevicesExt)]
impl SupportedDevices {
pub(crate) fn to_py(self, py: Python<'_>) -> PyResult<&PyAny> {
let class = py
.import("voicevox_core")?
.getattr("SupportedDevices")?
.downcast()?;
assert!(match self.to_json() {
serde_json::Value::Object(o) => o.len() == 3, // `cpu`, `cuda`, `dml`
_ => false,
});
PyAny::call(
class,
("I AM FROM PYO3",),
Some([("cpu", self.cpu), ("cuda", self.cuda), ("dml", self.dml)].into_py_dict(py)),
)
}
}

#[ext]
impl<T> std::result::Result<T, uuid::Error> {
fn into_py_value_result(self) -> PyResult<T> {
Expand Down
22 changes: 8 additions & 14 deletions crates/voicevox_core_python_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ mod blocking {
use voicevox_core::{AccelerationMode, AudioQuery, StyleId, UserDictWord};

use crate::{
convert::VoicevoxCoreResultExt as _, Closable, SingleTasked, VoiceModelFilePyFields,
convert::{SupportedDevicesExt as _, VoicevoxCoreResultExt as _},
Closable, SingleTasked, VoiceModelFilePyFields,
};

#[pyclass]
Expand Down Expand Up @@ -415,12 +416,7 @@ mod blocking {
}

fn supported_devices<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let class = py
.import("voicevox_core")?
.getattr("SupportedDevices")?
.downcast()?;
let s = self.0.supported_devices().into_py_result(py)?;
crate::convert::to_pydantic_dataclass(s, class)
self.0.supported_devices().into_py_result(py)?.to_py(py)
}
}

Expand Down Expand Up @@ -888,7 +884,10 @@ mod asyncio {
use uuid::Uuid;
use voicevox_core::{AccelerationMode, AudioQuery, StyleId, UserDictWord};

use crate::{convert::VoicevoxCoreResultExt as _, Closable, Tokio, VoiceModelFilePyFields};
use crate::{
convert::{SupportedDevicesExt as _, VoicevoxCoreResultExt as _},
Closable, Tokio, VoiceModelFilePyFields,
};

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -1017,12 +1016,7 @@ mod asyncio {
}

fn supported_devices<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let class = py
.import("voicevox_core")?
.getattr("SupportedDevices")?
.downcast()?;
let s = self.0.supported_devices().into_py_result(py)?;
crate::convert::to_pydantic_dataclass(s, class)
self.0.supported_devices().into_py_result(py)?.to_py(py)
}
}

Expand Down

0 comments on commit 6a0544d

Please sign in to comment.