Skip to content

Commit 870c655

Browse files
committed
removing unused params form vad_iterator
1 parent dd32b5e commit 870c655

File tree

6 files changed

+50
-63
lines changed

6 files changed

+50
-63
lines changed

whisper_bringup/launch/silero-vad.launch.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,9 @@ def run_silero_vad(context: LaunchContext, repo, file, model_path):
5454
"frame_size_ms": LaunchConfiguration("frame_size_ms", default=32),
5555
"threshold": LaunchConfiguration("threshold", default=0.5),
5656
"min_silence_ms": LaunchConfiguration(
57-
"min_silence_ms", default=0
58-
),
59-
"speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=32),
60-
"min_speech_ms": LaunchConfiguration("min_speech_ms", default=32),
61-
"max_speech_s": LaunchConfiguration(
62-
"max_speech_s", default=float("inf")
57+
"min_silence_ms", default=100
6358
),
59+
"speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=30),
6460
}
6561
],
6662
remappings=[("audio", "/audio/in")],

whisper_ros/include/silero_vad/silero_vad_node.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {
6565
float threshold_;
6666
int min_silence_ms_;
6767
int speech_pad_ms_;
68-
int min_speech_ms_;
69-
float max_speech_s_;
7068

7169
rclcpp::Publisher<std_msgs::msg::Float32MultiArray>::SharedPtr publisher_;
7270
rclcpp::Subscription<audio_common_msgs::msg::AudioStamped>::SharedPtr

whisper_ros/include/silero_vad/vad_iterator.hpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ class VadIterator {
3838
public:
3939
VadIterator(const std::string &model_path, int sample_rate = 16000,
4040
int frame_size_ms = 32, float threshold = 0.5f,
41-
int min_silence_ms = 0, int speech_pad_ms = 32,
42-
int min_speech_ms = 32,
43-
float max_speech_s = std::numeric_limits<float>::infinity());
41+
int min_silence_ms = 100, int speech_pad_ms = 30);
4442

4543
void reset_states();
4644
Timestamp predict(const std::vector<float> &data);
@@ -58,11 +56,9 @@ class VadIterator {
5856
int sample_rate;
5957
int sr_per_ms;
6058
int64_t window_size_samples;
61-
int min_speech_samples;
6259
int speech_pad_samples;
63-
float max_speech_samples;
6460
unsigned int min_silence_samples;
65-
unsigned int min_silence_samples_at_max_speech;
61+
int context_size;
6662

6763
// Model state
6864
bool triggered = false;
@@ -75,9 +71,9 @@ class VadIterator {
7571
std::vector<const char *> input_node_names = {"input", "state", "sr"};
7672

7773
std::vector<float> input;
74+
std::vector<float> context;
7875
std::vector<float> state;
7976
std::vector<int64_t> sr;
80-
std::vector<float> context;
8177

8278
int64_t input_node_dims[2] = {};
8379
const int64_t state_node_dims[3] = {2, 1, 128};

whisper_ros/src/silero_vad/silero_vad_node.cpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ SileroVadNode::SileroVadNode()
4444
this->declare_parameter<float>("threshold", 0.5f);
4545
this->declare_parameter<int>("min_silence_ms", 100);
4646
this->declare_parameter<int>("speech_pad_ms", 30);
47-
this->declare_parameter<int>("min_speech_ms", 32);
48-
this->declare_parameter<float>("max_speech_s",
49-
std::numeric_limits<float>::infinity());
5047
}
5148

5249
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
@@ -62,8 +59,6 @@ SileroVadNode::on_configure(const rclcpp_lifecycle::State &) {
6259
this->get_parameter("threshold", this->threshold_);
6360
this->get_parameter("min_silence_ms", this->min_silence_ms_);
6461
this->get_parameter("speech_pad_ms", this->speech_pad_ms_);
65-
this->get_parameter("min_speech_ms", this->min_speech_ms_);
66-
this->get_parameter("max_speech_s", this->max_speech_s_);
6762

6863
RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name());
6964

@@ -79,8 +74,7 @@ SileroVadNode::on_activate(const rclcpp_lifecycle::State &) {
7974
// create silero-vad
8075
this->vad_iterator = std::make_unique<VadIterator>(
8176
this->model_path_, this->sample_rate_, this->frame_size_ms_,
82-
this->threshold_, this->min_silence_ms_, this->speech_pad_ms_,
83-
this->min_speech_ms_, this->max_speech_s_);
77+
this->threshold_, this->min_silence_ms_, this->speech_pad_ms_);
8478

8579
this->publisher_ =
8680
this->create_publisher<std_msgs::msg::Float32MultiArray>("vad", 10);
@@ -185,8 +179,6 @@ void SileroVadNode::audio_callback(
185179

186180
// Predict if speech starts or ends
187181
auto timestamp = this->vad_iterator->predict(data);
188-
// RCLCPP_INFO(this->get_logger(), "Timestampt: %s",
189-
// timestamp.to_string().c_str());
190182

191183
// Check if speech starts
192184
if (timestamp.start != -1 && timestamp.end == -1 && !this->listening) {
@@ -209,9 +201,7 @@ void SileroVadNode::audio_callback(
209201
if (this->data.size() / msg->audio.info.rate < 1.0) {
210202
int pad_size =
211203
msg->audio.info.chunk + msg->audio.info.rate - this->data.size();
212-
for (int i = 0; i < pad_size; i++) {
213-
this->data.push_back(0.0);
214-
}
204+
this->data.insert(this->data.end(), pad_size, 0.0f);
215205
}
216206

217207
this->listening.store(false);

whisper_ros/src/silero_vad/timestamp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// MIT License
22

3-
// Copyright (c) 2023 Miguel Ángel González Santamarta
3+
// Copyright (c) 2024 Miguel Ángel González Santamarta
44

55
// Permission is hereby granted, free of charge, to any person obtaining a copy
66
// of this software and associated documentation files (the "Software"), to deal

whisper_ros/src/silero_vad/vad_iterator.cpp

+42-35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// MIT License
22

3-
// Copyright (c) 2023 Miguel Ángel González Santamarta
3+
// Copyright (c) 2024 Miguel Ángel González Santamarta
44

55
// Permission is hereby granted, free of charge, to any person obtaining a copy
66
// of this software and associated documentation files (the "Software"), to deal
@@ -20,6 +20,7 @@
2020
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
// SOFTWARE.
2222

23+
#include <algorithm>
2324
#include <limits>
2425
#include <memory>
2526
#include <string>
@@ -31,32 +32,39 @@ using namespace silero_vad;
3132

3233
VadIterator::VadIterator(const std::string &model_path, int sample_rate,
3334
int frame_size_ms, float threshold, int min_silence_ms,
34-
int speech_pad_ms, int min_speech_ms,
35-
float max_speech_s)
35+
int speech_pad_ms)
3636
: env(ORT_LOGGING_LEVEL_WARNING, "VadIterator"), threshold(threshold),
3737
sample_rate(sample_rate), sr_per_ms(sample_rate / 1000),
3838
window_size_samples(frame_size_ms * sr_per_ms),
39-
min_speech_samples(sr_per_ms * min_speech_ms),
4039
speech_pad_samples(sr_per_ms * speech_pad_ms),
41-
max_speech_samples(sample_rate * max_speech_s - window_size_samples -
42-
2 * speech_pad_samples),
4340
min_silence_samples(sr_per_ms * min_silence_ms),
44-
min_silence_samples_at_max_speech(sr_per_ms * 98),
45-
state(2 * 1 * 128, 0.0f), sr(1, sample_rate), context(64, 0.0f) {
41+
context_size(sample_rate == 16000 ? 64 : 32), context(context_size, 0.0f),
42+
state(2 * 1 * 128, 0.0f), sr(1, sample_rate) {
4643

47-
// this->input.resize(window_size_samples);
4844
this->input_node_dims[0] = 1;
4945
this->input_node_dims[1] = window_size_samples;
50-
this->init_onnx_model(model_path);
46+
47+
try {
48+
this->init_onnx_model(model_path);
49+
} catch (const std::exception &e) {
50+
throw std::runtime_error("Failed to initialize ONNX model: " +
51+
std::string(e.what()));
52+
}
5153
}
5254

5355
void VadIterator::init_onnx_model(const std::string &model_path) {
5456
this->session_options.SetIntraOpNumThreads(1);
5557
this->session_options.SetInterOpNumThreads(1);
5658
this->session_options.SetGraphOptimizationLevel(
5759
GraphOptimizationLevel::ORT_ENABLE_ALL);
58-
this->session = std::make_shared<Ort::Session>(this->env, model_path.c_str(),
59-
this->session_options);
60+
61+
try {
62+
this->session = std::make_shared<Ort::Session>(
63+
this->env, model_path.c_str(), this->session_options);
64+
} catch (const std::exception &e) {
65+
throw std::runtime_error("Failed to create ONNX session: " +
66+
std::string(e.what()));
67+
}
6068
}
6169

6270
void VadIterator::reset_states() {
@@ -68,16 +76,13 @@ void VadIterator::reset_states() {
6876
}
6977

7078
Timestamp VadIterator::predict(const std::vector<float> &data) {
71-
// Create input tensors
79+
// Pre-fill input with context
7280
this->input.clear();
73-
for (auto ele : this->context) {
74-
this->input.push_back(ele);
75-
}
76-
77-
for (auto ele : data) {
78-
this->input.push_back(ele);
79-
}
81+
this->input.reserve(context.size() + data.size());
82+
this->input.insert(input.end(), context.begin(), context.end());
83+
this->input.insert(input.end(), data.begin(), data.end());
8084

85+
// Create input tensors
8186
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
8287
this->memory_info, this->input.data(), this->input.size(),
8388
this->input_node_dims, 2);
@@ -95,20 +100,23 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
95100
this->ort_inputs.emplace_back(std::move(sr_tensor));
96101

97102
// Run inference
98-
this->ort_outputs = this->session->Run(
99-
Ort::RunOptions{nullptr}, this->input_node_names.data(),
100-
this->ort_inputs.data(), this->ort_inputs.size(),
101-
this->output_node_names.data(), this->output_node_names.size());
103+
try {
104+
this->ort_outputs = session->Run(
105+
Ort::RunOptions{nullptr}, this->input_node_names.data(),
106+
this->ort_inputs.data(), this->ort_inputs.size(),
107+
this->output_node_names.data(), this->output_node_names.size());
108+
} catch (const std::exception &e) {
109+
throw std::runtime_error("ONNX inference failed: " + std::string(e.what()));
110+
}
102111

103112
// Process output
104113
float speech_prob = this->ort_outputs[0].GetTensorMutableData<float>()[0];
105114
float *updated_state = this->ort_outputs[1].GetTensorMutableData<float>();
106115
std::copy(updated_state, updated_state + this->state.size(),
107116
this->state.begin());
108117

109-
for (int i = 64; i > 0; i--) {
110-
this->context.push_back(data.at(data.size() - i));
111-
}
118+
// Update context with the last 64 samples of data
119+
this->context.assign(data.end() - context_size, data.end());
112120

113121
// Handle result
114122
this->current_sample += this->window_size_samples;
@@ -119,10 +127,10 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
119127
}
120128

121129
if (!this->triggered) {
130+
int start_timestwamp = this->current_sample - this->speech_pad_samples -
131+
this->window_size_samples;
122132
this->triggered = true;
123-
return Timestamp(this->current_sample - this->speech_pad_samples -
124-
this->window_size_samples,
125-
-1, speech_prob);
133+
return Timestamp(start_timestwamp, -1, speech_prob);
126134
}
127135

128136
} else if (speech_prob < this->threshold - 0.15 && this->triggered) {
@@ -131,12 +139,11 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
131139
}
132140

133141
if (this->current_sample - this->temp_end >= this->min_silence_samples) {
134-
this->temp_end = 0;
142+
int end_timestamp =
143+
this->temp_end + this->speech_pad_samples - this->window_size_samples;
135144
this->triggered = false;
136-
return Timestamp(-1,
137-
this->temp_end + this->speech_pad_samples -
138-
this->window_size_samples,
139-
speech_prob);
145+
this->temp_end = 0;
146+
return Timestamp(-1, end_timestamp, speech_prob);
140147
}
141148
}
142149

0 commit comments

Comments
 (0)