Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.1 Rust #367

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions binding/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pv_cheetah"
version = "2.0.3"
version = "2.1.0"
edition = "2018"
description = "The Rust bindings for Picovoice's Cheetah library"
license = "Apache-2.0"
Expand All @@ -27,10 +27,11 @@ crate_type = ["lib"]

[dependencies]
libc = "0.2"
libloading = "0.7"
libloading = "0.8"

[dev-dependencies]
distance = "0.4.0"
itertools = "0.10"
rodio = "0.15"
itertools = "0.11"
rodio = "0.17"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
17 changes: 17 additions & 0 deletions binding/rust/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ Replace `${ACCESS_KEY}` with yours obtained from [Picovoice Console](https://con

The model file contains the parameters for the Cheetah engine. You may create bespoke language models using [Picovoice Console](https://console.picovoice.ai/) and then pass in the relevant file.

### Language Model

The Cheetah Rust SDK comes preloaded with a default English language model (`.pv` file).
Default models for other supported languages can be found in [lib/common](../../lib/common).

Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train
language models with custom vocabulary and boost words in the existing vocabulary.

Pass in the `.pv` file via the `.model_path()` Builder argument:
```rust
let leopard: Cheetah = CheetahBuilder::new()
.access_key("${ACCESS_KEY}")
.model_path("${MODEL_FILE_PATH}")
.init()
.expect("Unable to create Cheetah");
```

## Demos

The [Cheetah Rust demo project](https://github.com/Picovoice/cheetah/tree/master/demo/rust) is a Rust console app that allows for processing real-time audio (i.e. microphone) and files using Cheetah.
105 changes: 70 additions & 35 deletions binding/rust/tests/cheetah_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022-2023 Picovoice Inc.
Copyright 2022-2024 Picovoice Inc.

You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE"
file accompanying this source.
Expand All @@ -14,22 +14,59 @@ mod tests {
use distance::*;
use itertools::Itertools;
use rodio::{source::Source, Decoder};
use serde_json::{json, Value};
use serde::Deserialize;
use std::env;
use std::fs::File;
use std::fs::{read_to_string, File};
use std::io::BufReader;

use cheetah::CheetahBuilder;

fn load_test_data() -> Value {
let test_json: Value = json!([{
"language": "en",
"transcript": "Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"transcript_with_punctuation": "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
"error_rate": 0.025,
"audio_file": "test.wav"
}]);
test_json
#[derive(Debug, Deserialize)]
struct LanguageTestJson {
language: String,
audio_file: String,
transcript: String,
punctuations: Vec<String>,
error_rate: f32,
}

#[derive(Debug, Deserialize)]
struct TestsJson {
language_tests: Vec<LanguageTestJson>,
}

#[derive(Debug, Deserialize)]
struct RootJson {
tests: TestsJson,
}

fn load_test_data() -> TestsJson {
let test_json_path = format!(
"{}{}",
env!("CARGO_MANIFEST_DIR"),
"/../../resources/.test/test_data.json"
);
let contents: String =
read_to_string(test_json_path).expect("Unable to read test_data.json");
let root: RootJson = serde_json::from_str(&contents).expect("Failed to parse JSON");
root.tests
}

fn append_lang(path: &str, language: &str) -> String {
if language == "en" {
String::from(path)
} else {
format!("{}_{}", path, language)
}
}

fn model_path_by_language(language: &str) -> String {
format!(
"{}{}{}",
env!("CARGO_MANIFEST_DIR"),
append_lang("/../../lib/common/cheetah_params", language),
".pv"
)
}

fn character_error_rate(transcript: &str, expected_transcript: &str) -> f32 {
Expand All @@ -38,7 +75,7 @@ mod tests {
}

fn run_test_process(
_: &str,
language: &str,
transcript: &str,
test_punctuation: bool,
error_rate: f32,
Expand All @@ -47,6 +84,8 @@ mod tests {
let access_key = env::var("PV_ACCESS_KEY")
.expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable");

let model_path = model_path_by_language(language);

let audio_path = format!(
"{}{}{}",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -59,6 +98,7 @@ mod tests {

let cheetah = CheetahBuilder::new()
.access_key(access_key)
.model_path(model_path)
.enable_automatic_punctuation(test_punctuation)
.init()
.expect("Unable to create Cheetah");
Expand All @@ -82,42 +122,37 @@ mod tests {

#[test]
fn test_process() {
let test_json: Value = load_test_data();

for t in test_json.as_array().unwrap() {
let language = t["language"].as_str().unwrap();
let transcript = t["transcript"].as_str().unwrap();
let error_rate = t["error_rate"].as_f64().unwrap() as f32;
let test_json: TestsJson = load_test_data();

let test_audio = t["audio_file"].as_str().unwrap();
for t in test_json.language_tests {
let mut transcript = t.transcript;
for p in t.punctuations {
transcript = transcript.replace(&p, "")
}

run_test_process(
language,
transcript,
&t.language,
&transcript,
false,
error_rate,
&test_audio,
t.error_rate,
&t.audio_file,
);
}
}

#[test]
fn test_process_punctuation() {
let test_json: Value = load_test_data();

for t in test_json.as_array().unwrap() {
let language = t["language"].as_str().unwrap();
let transcript_with_punctuation = t["transcript_with_punctuation"].as_str().unwrap();
let error_rate = t["error_rate"].as_f64().unwrap() as f32;
let test_json: TestsJson = load_test_data();

let test_audio = t["audio_file"].as_str().unwrap();
for t in test_json.language_tests {
let transcript = t.transcript;

run_test_process(
language,
transcript_with_punctuation,
&t.language,
&transcript,
true,
error_rate,
&test_audio,
t.error_rate,
&t.audio_file,
);
}
}
Expand Down
Loading