Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the adaptive delay #17

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion buildspec.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
}
},
"name": "obs-cleanstream",
"version": "0.0.6",
"version": "0.0.7",
"author": "Roy Shilkrot",
"website": "https://github.com/occ-ai/obs-cleanstream/",
"email": "roy.shil@gmail.com",
Expand Down
9 changes: 3 additions & 6 deletions src/cleanstream-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ struct cleanstream_data {
uint32_t sample_rate; // input sample rate
// How many input frames (in input sample rate) are needed for the next whisper frame
size_t frames;
// How many ms/frames are needed to overlap with the next whisper frame
size_t overlap_frames;
size_t overlap_ms;
// How many frames were processed in the last whisper frame (this is dynamic)
size_t last_num_frames;
int current_result;
uint64_t current_result_end_timestamp;
uint64_t current_result_start_timestamp;
uint32_t delay_ms;

/* Silero VAD */
std::unique_ptr<VadIterator> vad;
Expand Down Expand Up @@ -76,7 +74,6 @@ struct cleanstream_data {
size_t audioFilePointer = 0;

float filler_p_threshold;
bool do_silence;
bool vad_enabled;
int log_level;
const char *detect_regex;
Expand Down
113 changes: 69 additions & 44 deletions src/cleanstream-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
#define BUFFER_SIZE_MSEC 1010
// at 16 kHz, 1010 msec is 16160 frames
#define WHISPER_FRAME_SIZE 16160
// overlap in msec
#define OVERLAP_SIZE_MSEC 340
// initial delay in msec
#define INITIAL_DELAY_MSEC 500

#define VAD_THOLD 0.0001f
#define FREQ_THOLD 100.0f
Expand All @@ -58,44 +58,56 @@ struct obs_audio_data *cleanstream_filter_audio(void *data, struct obs_audio_dat
return audio;
}

std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
size_t input_buffer_size = 0;
{
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock

if (audio != nullptr && audio->frames > 0) {
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
audio->frames * sizeof(float));
if (audio != nullptr && audio->frames > 0) {
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
audio->frames * sizeof(float));
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct cleanstream_audio_info info = {0};
info.frames = audio->frames; // number of frames in this packet
info.timestamp = audio->timestamp; // timestamp of this packet
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct cleanstream_audio_info info = {0};
info.frames = audio->frames; // number of frames in this packet
info.timestamp = audio->timestamp; // timestamp of this packet
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
input_buffer_size = gf->input_buffers[0].size;
}

// check the size of the input buffer - if it's more than 1500ms worth of audio, start playback
if (gf->input_buffers[0].size > 1500 * gf->sample_rate * sizeof(float) / 1000) {
// check the size of the input buffer - if it's more than <delay>ms worth of audio, start playback
if (input_buffer_size > gf->delay_ms * gf->sample_rate * sizeof(float) / 1000) {
// find needed number of frames from the incoming audio
size_t num_frames_needed = audio->frames;

std::vector<float> temporary_buffers[MAX_AUDIO_CHANNELS];
uint64_t timestamp = 0;

while (temporary_buffers[0].size() < num_frames_needed) {
struct cleanstream_audio_info info_out = {0};
{
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
// pop from the info buffer to get audio packet info
circlebuf_pop_front(&gf->info_buffer, &info_out, sizeof(info_out));

// pop from input circlebuf to audio data
for (size_t i = 0; i < gf->channels; i++) {
// increase the size of the temporary buffer to hold the incoming audio in addition
// to the existing audio on the temporary buffer
temporary_buffers[i].resize(temporary_buffers[i].size() +
info_out.frames);
circlebuf_pop_front(&gf->input_buffers[i],
temporary_buffers[i].data() +
temporary_buffers[i].size() -
info_out.frames,
info_out.frames * sizeof(float));
while (temporary_buffers[0].size() < num_frames_needed) {
struct cleanstream_audio_info info_out = {0};
// pop from the info buffer to get audio packet info
circlebuf_pop_front(&gf->info_buffer, &info_out, sizeof(info_out));
if (timestamp == 0) {
timestamp = info_out.timestamp;
}

// pop from input circlebuf to audio data
for (size_t i = 0; i < gf->channels; i++) {
// increase the size of the temporary buffer to hold the incoming audio in addition
// to the existing audio on the temporary buffer
temporary_buffers[i].resize(temporary_buffers[i].size() +
info_out.frames);
circlebuf_pop_front(&gf->input_buffers[i],
temporary_buffers[i].data() +
temporary_buffers[i].size() -
info_out.frames,
info_out.frames * sizeof(float));
}
}
}
const size_t num_frames = temporary_buffers[0].size();
Expand All @@ -105,7 +117,18 @@ struct obs_audio_data *cleanstream_filter_audio(void *data, struct obs_audio_dat
da_resize(gf->output_data, frames_size_bytes * gf->channels);
memset(gf->output_data.array, 0, frames_size_bytes * gf->channels);

if (gf->current_result == DetectionResult::DETECTION_RESULT_BEEP) {
int inference_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
uint64_t inference_result_start_timestamp = 0;
uint64_t inference_result_end_timestamp = 0;
{
std::lock_guard<std::mutex> outbuf_lock(gf->whisper_outbuf_mutex);
inference_result = gf->current_result;
inference_result_start_timestamp = gf->current_result_start_timestamp;
inference_result_end_timestamp = gf->current_result_end_timestamp;
}

if (timestamp > inference_result_start_timestamp &&
timestamp < inference_result_end_timestamp) {
if (gf->replace_sound == REPLACE_SOUNDS_SILENCE) {
// set the audio to 0
for (size_t i = 0; i < gf->channels; i++) {
Expand Down Expand Up @@ -207,9 +230,12 @@ void cleanstream_update(void *data, obs_data_t *s)
gf->replace_sound = obs_data_get_int(s, "replace_sound");
gf->filler_p_threshold = (float)obs_data_get_double(s, "filler_p_threshold");
gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->do_silence = obs_data_get_bool(s, "do_silence");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->delay_ms = BUFFER_SIZE_MSEC + INITIAL_DELAY_MSEC;
gf->current_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
gf->current_result_start_timestamp = 0;
gf->current_result_end_timestamp = 0;

obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);
Expand Down Expand Up @@ -260,7 +286,10 @@ void *cleanstream_create(obs_data_t *settings, obs_source_t *filter)
gf->channels = audio_output_get_channels(obs_get_audio());
gf->sample_rate = audio_output_get_sample_rate(obs_get_audio());
gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)BUFFER_SIZE_MSEC));
gf->last_num_frames = 0;
gf->delay_ms = BUFFER_SIZE_MSEC + INITIAL_DELAY_MSEC;
gf->current_result = DetectionResult::DETECTION_RESULT_UNKNOWN;
gf->current_result_start_timestamp = 0;
gf->current_result_end_timestamp = 0;

for (size_t i = 0; i < MAX_AUDIO_CHANNELS; i++) {
circlebuf_init(&gf->input_buffers[i]);
Expand All @@ -283,10 +312,8 @@ void *cleanstream_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;

gf->overlap_ms = OVERLAP_SIZE_MSEC;
gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
obs_log(LOG_INFO, "CleanStream filter: channels %d, frames %d, sample_rate %d",
(int)gf->channels, (int)gf->frames, gf->sample_rate);
obs_log(LOG_INFO, "CleanStream filter: channels %d, sample_rate %d", (int)gf->channels,
gf->sample_rate);

struct resample_info src, dst;
src.samples_per_sec = gf->sample_rate;
Expand Down Expand Up @@ -356,7 +383,6 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_int(s, "replace_sound", REPLACE_SOUNDS_SILENCE);
obs_data_set_default_bool(s, "advanced_settings", false);
obs_data_set_default_double(s, "filler_p_threshold", 0.75);
obs_data_set_default_bool(s, "do_silence", true);
obs_data_set_default_bool(s, "vad_enabled", true);
obs_data_set_default_int(s, "log_level", LOG_DEBUG);
obs_data_set_default_bool(s, "log_words", false);
Expand All @@ -365,10 +391,10 @@ void cleanstream_defaults(obs_data_t *s)

// Whisper parameters
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
obs_data_set_default_string(s, "initial_prompt", "uhm, Uh, um, Uhh, um. um... uh. uh... ");
obs_data_set_default_string(s, "initial_prompt", "");
obs_data_set_default_int(s, "n_threads", 4);
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
obs_data_set_default_bool(s, "no_context", true);
obs_data_set_default_bool(s, "no_context", false);
obs_data_set_default_bool(s, "single_segment", true);
obs_data_set_default_bool(s, "print_special", false);
obs_data_set_default_bool(s, "print_progress", false);
Expand All @@ -379,7 +405,7 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_double(s, "thold_ptsum", 0.01);
obs_data_set_default_int(s, "max_len", 0);
obs_data_set_default_bool(s, "split_on_word", false);
obs_data_set_default_int(s, "max_tokens", 3);
obs_data_set_default_int(s, "max_tokens", 7);
obs_data_set_default_bool(s, "speed_up", false);
obs_data_set_default_bool(s, "suppress_blank", true);
obs_data_set_default_bool(s, "suppress_non_speech_tokens", true);
Expand Down Expand Up @@ -479,8 +505,8 @@ obs_properties_t *cleanstream_properties(void *data)
// If advanced settings is enabled, show the advanced settings group
const bool show_hide = obs_data_get_bool(settings, "advanced_settings");
for (const std::string &prop_name :
{"whisper_params_group", "log_words", "filler_p_threshold", "do_silence",
"vad_enabled", "log_level"}) {
{"whisper_params_group", "log_words", "filler_p_threshold", "vad_enabled",
"log_level"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
}
Expand All @@ -489,7 +515,6 @@ obs_properties_t *cleanstream_properties(void *data)

obs_properties_add_float_slider(ppts, "filler_p_threshold", MT_("filler_p_threshold"), 0.0f,
1.0f, 0.05f);
obs_properties_add_bool(ppts, "do_silence", MT_("do_silence"));
obs_properties_add_bool(ppts, "vad_enabled", MT_("vad_enabled"));
obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
Expand Down
38 changes: 28 additions & 10 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,31 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data,
long long process_audio_from_buffer(struct cleanstream_data *gf)
{
uint64_t start_timestamp = 0;
uint64_t end_timestamp = 0;

{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);

// copy gf->frames from the end of the input buffer to the copy_buffers
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_peek_front(&gf->input_buffers[c], gf->copy_buffers[c],
gf->frames * sizeof(float));
circlebuf_peek_back(&gf->input_buffers[c], gf->copy_buffers[c],
gf->frames * sizeof(float));
}

// peek at the info_buffer to get the timestamp of the first info
// peek at the info_buffer to get the timestamp of the last info
struct cleanstream_audio_info info_from_buf = {0};
circlebuf_peek_front(&gf->info_buffer, &info_from_buf,
sizeof(struct cleanstream_audio_info));
start_timestamp = info_from_buf.timestamp;
circlebuf_peek_back(&gf->info_buffer, &info_from_buf,
sizeof(struct cleanstream_audio_info));
end_timestamp = info_from_buf.timestamp;
start_timestamp =
end_timestamp - (int)(gf->frames * 1000 / gf->sample_rate) * 1000000;
}

obs_log(gf->log_level, "processing %lu frames (%d ms), start timestamp %llu ", gf->frames,
(int)(gf->frames * 1000 / gf->sample_rate), start_timestamp);
obs_log(gf->log_level,
"processing %lu frames (%d ms), start timestamp %llu, end timestamp %llu ",
gf->frames, (int)(gf->frames * 1000 / gf->sample_rate), start_timestamp,
end_timestamp);

// time the audio processing
auto start = std::chrono::high_resolution_clock::now();
Expand All @@ -349,8 +354,7 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)

std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
obs_log(gf->log_level, "VAD detected no speech in %d frames",
whisper_buffer_16khz);
obs_log(gf->log_level, "VAD detected no speech");
skipped_inference = true;
}
}
Expand All @@ -362,8 +366,13 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)
{
std::lock_guard<std::mutex> lock(gf->whisper_outbuf_mutex);
gf->current_result = inference_result;
if (gf->current_result == DETECTION_RESULT_BEEP) {
gf->current_result_start_timestamp = start_timestamp;
gf->current_result_end_timestamp = end_timestamp;
}
}
} else {
gf->current_result = DETECTION_RESULT_SILENCE;
if (gf->log_words) {
obs_log(LOG_INFO, "skipping inference");
}
Expand All @@ -377,6 +386,15 @@ long long process_audio_from_buffer(struct cleanstream_data *gf)
obs_log(gf->log_level, "audio processing of %u ms new data took %d ms", audio_processed_ms,
(int)duration);

if (duration > (gf->delay_ms - audio_processed_ms)) {
obs_log(gf->log_level,
"audio processing (%d ms) longer than delay (%lu ms), increase delay",
(int)duration, gf->delay_ms);
gf->delay_ms += 100;
} else {
gf->delay_ms -= 100;
}

return duration;
}

Expand Down
Loading