Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: VoiceModelVoiceModelFile #832

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ anstyle-query = "1.0.0"
anyhow = "1.0.65"
assert_cmd = "2.0.8"
async-fs = "2.1.2"
async-lock = "3.4.0"
async_zip = "=0.0.16"
bindgen = "0.69.4"
binstall-tar = "0.4.39"
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ link-onnxruntime = []

[dependencies]
anyhow.workspace = true
async-fs.workspace = true
async-fs.workspace = true # 今これを使っている箇所はどこにも無いが、`UserDict`にはこれを使った方がよいはず
async-lock.workspace = true
async_zip = { workspace = true, features = ["deflate"] }
blocking.workspace = true
camino.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/__internal/doctest_fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub async fn synthesizer_with_sample_voice_model(
},
)?;

let model = &crate::nonblocking::VoiceModel::from_path(voice_model_path).await?;
let model = &crate::nonblocking::VoiceModelFile::open(voice_model_path).await?;
syntesizer.load_voice_model(model).await?;

Ok(syntesizer)
Expand Down
2 changes: 2 additions & 0 deletions crates/voicevox_core/src/__internal/interop.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod raii;

pub use crate::{
metas::merge as merge_metas, synthesizer::blocking::PerformInference,
voice_model::blocking::IdRef,
Expand Down
43 changes: 43 additions & 0 deletions crates/voicevox_core/src/__internal/interop/raii.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use std::{marker::PhantomData, ops::Deref};

use ouroboros::self_referencing;

pub enum MaybeClosed<T> {
Open(T),
Closed,
}

// [`mapped_lock_guards`]のようなことをやるためのユーティリティ。
//
// [`mapped_lock_guards`]: https://github.com/rust-lang/rust/issues/117108
pub fn try_map_guard<'lock, G, F, T, E>(guard: G, f: F) -> Result<impl Deref<Target = T> + 'lock, E>
where
G: 'lock,
F: FnOnce(&G) -> Result<&T, E>,
T: 'lock,
{
return MappedLockTryBuilder {
guard,
target_builder: f,
marker: PhantomData,
}
.try_build();

#[self_referencing]
struct MappedLock<'lock, G: 'lock, T> {
guard: G,

#[borrows(guard)]
target: &'this T,

marker: PhantomData<&'lock T>,
}

impl<'lock, G: 'lock, T: 'lock> Deref for MappedLock<'lock, G, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
self.borrow_target()
}
}
}
152 changes: 126 additions & 26 deletions crates/voicevox_core/src/asyncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,29 @@

use std::{
io::{self, Read as _, Seek as _, SeekFrom},
ops::DerefMut,
path::Path,
pin::Pin,
task::{self, Poll},
};

use blocking::Unblock;
use futures_io::{AsyncRead, AsyncSeek};
use futures_util::ready;

pub(crate) trait Async: 'static {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin>;
type Mutex<T: Send + Sync + Unpin>: Mutex<T>;
type RoFile: AsyncRead + AsyncSeek + Send + Sync + Unpin;

/// ファイルを読み取り専用(RO)で開く。
///
/// `io::Error`は素(`i32`相当)のままにしておき、この関数を呼び出す側でfs-err風のメッセージを付
/// ける。
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile>;
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
}

pub(crate) trait Mutex<T>: From<T> + Send + Sync + Unpin {
async fn lock(&self) -> impl DerefMut<Target = T>;
}

/// エグゼキュータが非同期タスクの並行実行をしないことを仮定する、[`Async`]の実装。
Expand All @@ -39,30 +53,47 @@ pub(crate) trait Async: 'static {
pub(crate) enum SingleTasked {}

impl Async for SingleTasked {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin> {
return std::fs::File::open(path).map(BlockingFile);

struct BlockingFile(std::fs::File);

impl AsyncRead for BlockingFile {
fn poll_read(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.read(buf))
}
}
type Mutex<T: Send + Sync + Unpin> = StdMutex<T>;
type RoFile = StdFile;

impl AsyncSeek for BlockingFile {
fn poll_seek(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
Poll::Ready(self.0.seek(pos))
}
}
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
std::fs::File::open(path).map(StdFile)
}
}

pub(crate) struct StdMutex<T>(std::sync::Mutex<T>);

impl<T> From<T> for StdMutex<T> {
fn from(inner: T) -> Self {
Self(inner.into())
}
}

impl<T: Send + Sync + Unpin> Mutex<T> for StdMutex<T> {
async fn lock(&self) -> impl DerefMut<Target = T> {
self.0.lock().unwrap_or_else(|e| panic!("{e}"))
}
}

pub(crate) struct StdFile(std::fs::File);
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved

impl AsyncRead for StdFile {
fn poll_read(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.read(buf))
}
}

impl AsyncSeek for StdFile {
fn poll_seek(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
Poll::Ready(self.0.seek(pos))
}
}

Expand All @@ -74,7 +105,76 @@ impl Async for SingleTasked {
pub(crate) enum BlockingThreadPool {}

impl Async for BlockingThreadPool {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin> {
async_fs::File::open(path).await
type Mutex<T: Send + Sync + Unpin> = async_lock::Mutex<T>;
type RoFile = AsyncRoFile;

async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
AsyncRoFile::open(path).await
}
}

impl<T: Send + Sync + Unpin> Mutex<T> for async_lock::Mutex<T> {
async fn lock(&self) -> impl DerefMut<Target = T> {
self.lock().await
}
}

// TODO: `async_fs::File::into_std_file`みたいなのがあればこんなの↓は作らなくていいはず。PR出す?
pub(crate) struct AsyncRoFile {
Hiroshiba marked this conversation as resolved.
Show resolved Hide resolved
// `poll_read`と`poll_seek`しかしない
unblock: Unblock<std::fs::File>,

// async-fsの実装がやっているように「正しい」シーク位置を保持する。ただしファイルはパイプではな
// いことがわかっているため smol-rs/async-fs#4 は考えない
real_seek_pos: Option<u64>,
}

impl AsyncRoFile {
async fn open(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref().to_owned();
let unblock = Unblock::new(blocking::unblock(|| std::fs::File::open(path)).await?);
Ok(Self {
unblock,
real_seek_pos: None,
})
}

pub(crate) async fn close(self) {
let file = self.unblock.into_inner().await;
blocking::unblock(|| drop(file)).await;
}
}

impl AsyncRead for AsyncRoFile {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if self.real_seek_pos.is_none() {
self.real_seek_pos = Some(ready!(
Pin::new(&mut self.unblock).poll_seek(cx, SeekFrom::Current(0))
)?);
}
let n = ready!(Pin::new(&mut self.unblock).poll_read(cx, buf))?;
*self.real_seek_pos.as_mut().expect("should be present") += n as u64;
Poll::Ready(Ok(n))
}
}

impl AsyncSeek for AsyncRoFile {
fn poll_seek(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
// async-fsの実装がやっているような"reposition"を行う。
// https://github.com/smol-rs/async-fs/issues/2#issuecomment-675595170
if let Some(real_seek_pos) = self.real_seek_pos {
ready!(Pin::new(&mut self.unblock).poll_seek(cx, SeekFrom::Start(real_seek_pos)))?;
}
self.real_seek_pos = None;

Pin::new(&mut self.unblock).poll_seek(cx, pos)
}
}
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub use crate::{
engine::open_jtalk::blocking::OpenJtalk, infer::runtimes::onnxruntime::blocking::Onnxruntime,
synthesizer::blocking::Synthesizer, user_dict::dict::blocking::UserDict,
voice_model::blocking::VoiceModel,
voice_model::blocking::VoiceModelFile,
};

pub mod onnxruntime {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/engine/open_jtalk.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// TODO: `VoiceModel`のように、次のような設計にする。
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// TODO: `VoiceModel`のように、次のような設計にする。
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/nonblocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
pub use crate::{
engine::open_jtalk::nonblocking::OpenJtalk,
infer::runtimes::onnxruntime::nonblocking::Onnxruntime, synthesizer::nonblocking::Synthesizer,
user_dict::dict::nonblocking::UserDict, voice_model::nonblocking::VoiceModel,
user_dict::dict::nonblocking::UserDict, voice_model::nonblocking::VoiceModelFile,
};

pub mod onnxruntime {
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ mod tests {
talk: enum_map!(_ => InferenceSessionOptions::new(0, DeviceSpec::Cpu)),
},
);
let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
let model_contents = &model.read_inference_models().await.unwrap();
let result = status.insert_model(model.header(), model_contents);
assert_debug_fmt_eq!(Ok(()), result);
Expand All @@ -424,7 +424,7 @@ mod tests {
talk: enum_map!(_ => InferenceSessionOptions::new(0, DeviceSpec::Cpu)),
},
);
let vvm = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let vvm = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
let model_header = vvm.header();
let model_contents = &vvm.read_inference_models().await.unwrap();
assert!(
Expand Down
Loading
Loading