// MIT License

-// Copyright (c) 2023 Miguel Ángel González Santamarta
+// Copyright (c) 2024 Miguel Ángel González Santamarta

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal

// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

+#include <algorithm>
#include <limits>
#include <memory>
#include <string>
@@ -31,32 +32,39 @@ using namespace silero_vad;

VadIterator::VadIterator(const std::string &model_path, int sample_rate,
                         int frame_size_ms, float threshold, int min_silence_ms,
-                         int speech_pad_ms, int min_speech_ms,
-                         float max_speech_s)
+                         int speech_pad_ms)
    : env(ORT_LOGGING_LEVEL_WARNING, "VadIterator"), threshold(threshold),
      sample_rate(sample_rate), sr_per_ms(sample_rate / 1000),
      window_size_samples(frame_size_ms * sr_per_ms),
-      min_speech_samples(sr_per_ms * min_speech_ms),
      speech_pad_samples(sr_per_ms * speech_pad_ms),
-      max_speech_samples(sample_rate * max_speech_s - window_size_samples -
-                         2 * speech_pad_samples),
      min_silence_samples(sr_per_ms * min_silence_ms),
-      min_silence_samples_at_max_speech(sr_per_ms * 98),
-      state(2 * 1 * 128, 0.0f), sr(1, sample_rate), context(64, 0.0f) {
+      context_size(sample_rate == 16000 ? 64 : 32), context(context_size, 0.0f),
+      state(2 * 1 * 128, 0.0f), sr(1, sample_rate) {

-  // this->input.resize(window_size_samples);
  this->input_node_dims[0] = 1;
  this->input_node_dims[1] = window_size_samples;
-  this->init_onnx_model(model_path);
+
+  try {
+    this->init_onnx_model(model_path);
+  } catch (const std::exception &e) {
+    throw std::runtime_error("Failed to initialize ONNX model: " +
+                             std::string(e.what()));
+  }
}

void VadIterator::init_onnx_model(const std::string &model_path) {
  this->session_options.SetIntraOpNumThreads(1);
  this->session_options.SetInterOpNumThreads(1);
  this->session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_ALL);
-  this->session = std::make_shared<Ort::Session>(this->env, model_path.c_str(),
-                                                 this->session_options);
+
+  try {
+    this->session = std::make_shared<Ort::Session>(
+        this->env, model_path.c_str(), this->session_options);
+  } catch (const std::exception &e) {
+    throw std::runtime_error("Failed to create ONNX session: " +
+                             std::string(e.what()));
+  }
}

void VadIterator::reset_states() {
@@ -68,16 +76,13 @@ void VadIterator::reset_states() {
}

Timestamp VadIterator::predict(const std::vector<float> &data) {
-  // Create input tensors
+  // Pre-fill input with context
  this->input.clear();
-  for (auto ele : this->context) {
-    this->input.push_back(ele);
-  }
-
-  for (auto ele : data) {
-    this->input.push_back(ele);
-  }
+  this->input.reserve(context.size() + data.size());
+  this->input.insert(input.end(), context.begin(), context.end());
+  this->input.insert(input.end(), data.begin(), data.end());

+  // Create input tensors
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
      this->memory_info, this->input.data(), this->input.size(),
      this->input_node_dims, 2);
@@ -95,20 +100,23 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
  this->ort_inputs.emplace_back(std::move(sr_tensor));

  // Run inference
-  this->ort_outputs = this->session->Run(
-      Ort::RunOptions{nullptr}, this->input_node_names.data(),
-      this->ort_inputs.data(), this->ort_inputs.size(),
-      this->output_node_names.data(), this->output_node_names.size());
+  try {
+    this->ort_outputs = session->Run(
+        Ort::RunOptions{nullptr}, this->input_node_names.data(),
+        this->ort_inputs.data(), this->ort_inputs.size(),
+        this->output_node_names.data(), this->output_node_names.size());
+  } catch (const std::exception &e) {
+    throw std::runtime_error("ONNX inference failed: " + std::string(e.what()));
+  }

  // Process output
  float speech_prob = this->ort_outputs[0].GetTensorMutableData<float>()[0];
  float *updated_state = this->ort_outputs[1].GetTensorMutableData<float>();
  std::copy(updated_state, updated_state + this->state.size(),
            this->state.begin());

-  for (int i = 64; i > 0; i--) {
-    this->context.push_back(data.at(data.size() - i));
-  }
+  // Update context with the last context_size samples of data
+  this->context.assign(data.end() - context_size, data.end());

  // Handle result
  this->current_sample += this->window_size_samples;
@@ -119,10 +127,10 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
    }

    if (!this->triggered) {
+      int start_timestamp = this->current_sample - this->speech_pad_samples -
+                            this->window_size_samples;
      this->triggered = true;
-      return Timestamp(this->current_sample - this->speech_pad_samples -
-                           this->window_size_samples,
-                       -1, speech_prob);
+      return Timestamp(start_timestamp, -1, speech_prob);
    }

  } else if (speech_prob < this->threshold - 0.15 && this->triggered) {
@@ -131,12 +139,11 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
    }

    if (this->current_sample - this->temp_end >= this->min_silence_samples) {
-      this->temp_end = 0;
+      int end_timestamp =
+          this->temp_end + this->speech_pad_samples - this->window_size_samples;
      this->triggered = false;
-      return Timestamp(-1,
-                       this->temp_end + this->speech_pad_samples -
-                           this->window_size_samples,
-                       speech_prob);
+      this->temp_end = 0;
+      return Timestamp(-1, end_timestamp, speech_prob);
    }
  }
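
A minimal usage sketch of the API after this change, assuming a hypothetical header path, parameter values, and Timestamp members named start/end; only the trimmed constructor signature (sample_rate, frame_size_ms, threshold, min_silence_ms, speech_pad_ms) and predict() come from the diff itself:

#include <vector>
#include "silero_vad/vad_iterator.hpp"  // assumed header path

int main() {
  // Hypothetical settings: 16 kHz audio, 32 ms frames, 0.5 threshold,
  // 100 ms min silence, 30 ms speech padding.
  silero_vad::VadIterator vad("silero_vad.onnx", 16000, 32, 0.5f, 100, 30);

  // One 32 ms frame (512 samples at 16 kHz) of normalized PCM audio.
  std::vector<float> frame(512, 0.0f);
  auto ts = vad.predict(frame);  // one Timestamp per frame

  // Assumed Timestamp members: start/end hold sample offsets, -1 means "no event".
  if (ts.start != -1) { /* speech started in this frame */ }
  if (ts.end != -1) { /* speech ended in this frame */ }
  return 0;
}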