Skip to content

Commit 16bb406

Browse files
committed
improving code comments
1 parent a71f682 commit 16bb406

File tree

7 files changed

+391
-22
lines changed

7 files changed

+391
-22
lines changed

whisper_ros/include/silero_vad/silero_vad_node.hpp

+54-9
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
// SOFTWARE.
2222

23-
#ifndef SILERO_VAD_SILERO_VAD_NODE_HPP
24-
#define SILERO_VAD_SILERO_VAD_NODE_HPP
23+
#ifndef SILERO_VAD__SILERO_VAD_NODE_HPP
24+
#define SILERO_VAD__SILERO_VAD_NODE_HPP
2525

2626
#include <atomic>
2727
#include <memory>
@@ -36,50 +36,95 @@
3636

3737
namespace silero_vad {
3838

39+
/// @class SileroVadNode
40+
/// @brief A ROS 2 lifecycle node for performing voice activity detection.
3941
class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {
4042

4143
public:
44+
/// @brief Constructs a new SileroVadNode object.
4245
SileroVadNode();
4346

47+
/// @brief Callback for configuring the lifecycle node.
48+
/// @param state The current state of the node.
49+
/// @return Success or failure of configuration.
4450
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
45-
on_configure(const rclcpp_lifecycle::State &);
51+
on_configure(const rclcpp_lifecycle::State &state);
52+
53+
/// @brief Callback for activating the lifecycle node.
54+
/// @param state The current state of the node.
55+
/// @return Success or failure of activation.
4656
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
47-
on_activate(const rclcpp_lifecycle::State &);
57+
on_activate(const rclcpp_lifecycle::State &state);
58+
59+
/// @brief Callback for deactivating the lifecycle node.
60+
/// @param state The current state of the node.
61+
/// @return Success or failure of deactivation.
4862
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
49-
on_deactivate(const rclcpp_lifecycle::State &);
63+
on_deactivate(const rclcpp_lifecycle::State &state);
64+
65+
/// @brief Callback for cleaning up the lifecycle node.
66+
/// @param state The current state of the node.
67+
/// @return Success or failure of cleanup.
5068
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
51-
on_cleanup(const rclcpp_lifecycle::State &);
69+
on_cleanup(const rclcpp_lifecycle::State &state);
70+
71+
/// @brief Callback for shutting down the lifecycle node.
72+
/// @param state The current state of the node.
73+
/// @return Success or failure of shutdown.
5274
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
53-
on_shutdown(const rclcpp_lifecycle::State &);
75+
on_shutdown(const rclcpp_lifecycle::State &state);
5476

5577
protected:
78+
/// Indicates if VAD is enabled.
5679
std::atomic<bool> enabled;
80+
/// Indicates if VAD is in listening mode.
5781
std::atomic<bool> listening;
82+
/// Indicates if audio data should be published.
5883
std::atomic<bool> publish;
84+
/// Buffer for storing audio data.
5985
std::vector<float> data;
86+
/// Pointer to the VAD iterator.
6087
std::unique_ptr<VadIterator> vad_iterator;
6188

6289
private:
90+
/// Buffer for storing previous audio data.
6391
std::vector<float> prev_data;
92+
/// Path to the VAD model.
6493
std::string model_path_;
94+
/// Sampling rate of the audio data.
6595
int sample_rate_;
96+
/// Frame size in milliseconds.
6697
int frame_size_ms_;
98+
/// Threshold for VAD decision-making.
6799
float threshold_;
100+
/// Minimum silence duration in milliseconds.
68101
int min_silence_ms_;
102+
/// Padding duration for detected speech in milliseconds.
69103
int speech_pad_ms_;
70104

105+
/// Publisher for VAD output.
71106
rclcpp::Publisher<std_msgs::msg::Float32MultiArray>::SharedPtr publisher_;
107+
/// Subscription for audio input.
72108
rclcpp::Subscription<audio_common_msgs::msg::AudioStamped>::SharedPtr
73109
subscription_;
74-
110+
/// Service for enabling/disabling VAD.
75111
rclcpp::Service<std_srvs::srv::SetBool>::SharedPtr enable_srv_;
76112

113+
/// @brief Callback for handling incoming audio data.
114+
/// @param msg The audio message containing the audio data.
77115
void
78116
audio_callback(const audio_common_msgs::msg::AudioStamped::SharedPtr msg);
79117

118+
/// @brief Callback for enabling/disabling the VAD.
119+
/// @param request The service request containing the desired enable state.
120+
/// @param response The service response indicating success or failure.
80121
void enable_cb(const std::shared_ptr<std_srvs::srv::SetBool::Request> request,
81122
std::shared_ptr<std_srvs::srv::SetBool::Response> response);
82123

124+
/// @brief Converts audio data to a float vector normalized to [-1.0, 1.0].
125+
/// @tparam T The input audio data type.
126+
/// @param input The input audio data.
127+
/// @return A vector of normalized float audio data.
83128
template <typename T>
84129
std::vector<float> convert_to_float(const std::vector<T> &input) {
85130
static_assert(std::is_integral<T>::value,
@@ -123,4 +168,4 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {
123168

124169
} // namespace silero_vad
125170

126-
#endif
171+
#endif

whisper_ros/include/silero_vad/timestamp.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,38 @@
2727

2828
namespace silero_vad {
2929

30+
/// @class Timestamp
31+
/// @brief Represents a time interval with speech probability.
3032
class Timestamp {
3133
public:
34+
/// The start time of the interval, in milliseconds.
3235
int start;
36+
37+
/// The end time of the interval, in milliseconds.
3338
int end;
39+
40+
/// The probability of speech detected in the interval, ranging from 0 to 1.
3441
float speech_prob;
3542

43+
/// @brief Constructs a `Timestamp` object.
44+
/// @param start The start time of the interval (default: -1).
45+
/// @param end The end time of the interval (default: -1).
46+
/// @param speech_prob The speech probability (default: 0).
3647
Timestamp(int start = -1, int end = -1, float speech_prob = 0);
3748

49+
/// @brief Assigns the values of another `Timestamp` to this instance.
50+
/// @param other The `Timestamp` to copy from.
51+
/// @return A reference to this `Timestamp`.
3852
Timestamp &operator=(const Timestamp &other);
53+
54+
/// @brief Compares two `Timestamp` objects for equality.
55+
/// @param other The `Timestamp` to compare with.
56+
/// @return `true` if the start and end times are equal; `false` otherwise.
3957
bool operator==(const Timestamp &other) const;
4058

59+
/// @brief Converts the `Timestamp` object to a string representation.
60+
/// @return A string representing the `Timestamp` in the format
61+
/// `{start:...,end:...,prob:...}`.
4162
std::string to_string() const;
4263
};
4364

whisper_ros/include/silero_vad/vad_iterator.hpp

+68-4
Original file line numberDiff line numberDiff line change
@@ -33,58 +33,122 @@
3333

3434
namespace silero_vad {
3535

36+
/// @class VadIterator
37+
/// @brief Implements a Voice Activity Detection (VAD) iterator using an ONNX
38+
/// model.
39+
///
40+
/// This class provides methods to load a pre-trained ONNX VAD model, process
41+
/// audio data, and predict the presence of speech. It manages the model state
42+
/// and handles input/output tensors for inference.
3643
class VadIterator {
37-
3844
public:
45+
/// @brief Constructs a VadIterator object.
46+
///
47+
/// @param model_path Path to the ONNX model file.
48+
/// @param sample_rate The audio sample rate in Hz (default: 16000).
49+
/// @param frame_size_ms Size of the audio frame in milliseconds (default:
50+
/// 32).
51+
/// @param threshold The threshold for speech detection (default: 0.5).
52+
/// @param min_silence_ms Minimum silence duration in milliseconds to mark the
53+
/// end of speech (default: 100).
54+
/// @param speech_pad_ms Additional padding in milliseconds added to speech
55+
/// segments (default: 30).
3956
VadIterator(const std::string &model_path, int sample_rate = 16000,
4057
int frame_size_ms = 32, float threshold = 0.5f,
4158
int min_silence_ms = 100, int speech_pad_ms = 30);
4259

60+
/// @brief Resets the internal state of the model.
61+
///
62+
/// Clears the state, context, and resets internal flags related to speech
63+
/// detection.
4364
void reset_states();
65+
66+
/// @brief Processes audio data and predicts speech segments.
67+
///
68+
/// @param data A vector of audio samples (single-channel, float values).
69+
/// @return A Timestamp object containing start and end times of detected
70+
/// speech, or -1 for inactive values.
4471
Timestamp predict(const std::vector<float> &data);
4572

4673
private:
74+
/// ONNX Runtime environment.
4775
Ort::Env env;
76+
/// ONNX session options.
4877
Ort::SessionOptions session_options;
78+
/// ONNX session for running inference.
4979
std::shared_ptr<Ort::Session> session;
80+
/// Memory allocator for ONNX runtime.
5081
Ort::AllocatorWithDefaultOptions allocator;
82+
/// Memory info for tensor allocation.
5183
Ort::MemoryInfo memory_info =
5284
Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
5385

54-
// Model configuration
86+
/// Detection threshold for speech probability.
5587
float threshold;
88+
/// Audio sample rate in Hz.
5689
int sample_rate;
90+
/// Samples per millisecond.
5791
int sr_per_ms;
92+
/// Number of samples in a single frame.
5893
int64_t window_size_samples;
94+
/// Padding in samples added to speech segments.
5995
int speech_pad_samples;
96+
/// Minimum silence duration in samples to mark the end of speech.
6097
unsigned int min_silence_samples;
98+
/// Size of the context buffer.
6199
int context_size;
62100

63-
// Model state
101+
/// Indicates whether speech has been detected.
64102
bool triggered = false;
103+
/// Temporary end position during silence detection.
65104
unsigned int temp_end = 0;
105+
/// Current sample position in the input stream.
66106
unsigned int current_sample = 0;
107+
/// End sample of the last speech segment.
67108
int prev_end = 0;
109+
/// Start sample of the next speech segment.
68110
int next_start = 0;
69111

112+
/// ONNX model input tensors.
70113
std::vector<Ort::Value> ort_inputs;
114+
115+
/// Names of the input nodes in the ONNX model.
71116
std::vector<const char *> input_node_names = {"input", "state", "sr"};
72117

118+
/// Input buffer for audio data and context.
73119
std::vector<float> input;
120+
121+
/// Context buffer storing past audio samples.
74122
std::vector<float> context;
123+
124+
/// Internal state of the model.
75125
std::vector<float> state;
126+
127+
/// Sample rate tensor.
76128
std::vector<int64_t> sr;
77129

130+
/// Dimensions for the input tensor.
78131
int64_t input_node_dims[2] = {};
132+
133+
/// Dimensions for the state tensor.
79134
const int64_t state_node_dims[3] = {2, 1, 128};
135+
136+
/// Dimensions for the sample rate tensor.
80137
const int64_t sr_node_dims[1] = {1};
81138

139+
/// ONNX model output tensors.
82140
std::vector<Ort::Value> ort_outputs;
141+
142+
/// Names of the output nodes in the ONNX model.
83143
std::vector<const char *> output_node_names = {"output", "stateN"};
84144

145+
/// @brief Initializes the ONNX model session.
146+
///
147+
/// @param model_path Path to the ONNX model file.
148+
/// @throws std::runtime_error If the ONNX session initialization fails.
85149
void init_onnx_model(const std::string &model_path);
86150
};
87151

88152
} // namespace silero_vad
89153

90-
#endif
154+
#endif

whisper_ros/include/whisper_ros/whisper.hpp

+48-1
Original file line numberDiff line numberDiff line change
@@ -29,40 +29,87 @@
2929
#include "grammar-parser.h"
3030
#include "whisper.h"
3131

32+
/// Represents the result of a transcription operation.
3233
struct TranscriptionOutput {
34+
/// The transcribed text.
3335
std::string text;
36+
37+
/// The confidence probability of the transcription.
3438
float prob;
3539
};
3640

3741
namespace whisper_ros {
3842

43+
/// Class for performing speech-to-text transcription using the Whisper model.
3944
class Whisper {
4045

4146
public:
47+
/// Constructs a Whisper object with the specified model and parameters.
48+
/// @param model The path to the Whisper model file.
49+
/// @param openvino_encode_device The OpenVINO device used for encoding.
50+
/// @param n_processors Number of processors to use for parallel processing.
51+
/// @param cparams Whisper context parameters.
52+
/// @param wparams Whisper full parameters.
4253
Whisper(const std::string &model, const std::string &openvino_encode_device,
4354
int n_processors, const struct whisper_context_params &cparams,
4455
const struct whisper_full_params &wparams);
56+
57+
/// Destructor to clean up resources used by the Whisper object.
4558
~Whisper();
4659

60+
/// Transcribes the given audio data.
61+
/// @param pcmf32 A vector of audio samples in 32-bit float format.
62+
/// @return A TranscriptionOutput structure containing the transcription text
63+
/// and confidence probability.
4764
struct TranscriptionOutput transcribe(const std::vector<float> &pcmf32);
65+
66+
/// Trims leading and trailing whitespace from the input string.
67+
/// @param s The string to trim.
68+
/// @return The trimmed string.
4869
std::string trim(const std::string &s);
70+
71+
/// Converts a timestamp to a string format.
72+
/// @param t The timestamp in 10 ms units.
73+
/// @param comma If true, use a comma as the decimal separator; otherwise, use
74+
/// a period.
75+
/// @return The formatted timestamp as a string.
4976
std::string timestamp_to_str(int64_t t, bool comma = false);
5077

78+
/// Sets a grammar for transcription with a starting rule and penalty.
79+
/// @param grammar The grammar rules as a string.
80+
/// @param start_rule The starting rule for the grammar.
81+
/// @param grammar_penalty A penalty factor for grammar violations.
82+
/// @return True if the grammar is set successfully; false otherwise.
5183
bool set_grammar(const std::string grammar, const std::string start_rule,
5284
float grammar_penalty);
85+
86+
/// Resets the grammar to its default state.
5387
void reset_grammar();
88+
89+
/// Sets an initial prompt for transcription.
90+
/// @param prompt The initial prompt text.
5491
void set_init_prompt(const std::string prompt);
92+
93+
/// Resets the initial prompt to its default state.
5594
void reset_init_prompt();
5695

5796
protected:
97+
/// Number of processors used for parallel processing.
5898
int n_processors;
99+
100+
/// Parameters used for full transcription tasks.
59101
struct whisper_full_params wparams;
60102

103+
/// The Whisper context.
61104
struct whisper_context *ctx;
105+
106+
/// Parsed grammar state.
62107
grammar_parser::parse_state grammar_parsed;
108+
109+
/// Grammar rules derived from the parsed grammar.
63110
std::vector<const whisper_grammar_element *> grammar_rules;
64111
};
65112

66113
} // namespace whisper_ros
67114

68-
#endif
115+
#endif

0 commit comments

Comments
 (0)