Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,14 @@ extern "C" {
int n_samples,
int n_processors);

WHISPER_API int whisper_full_batch_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * const * batches,
const int * size_per_batch,
int n_batches,
int n_processors);

// Number of generated text segments
// A segment can be a few words, a sentence, or even a paragraph.
WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx);
Expand Down
123 changes: 123 additions & 0 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7893,6 +7893,129 @@ int whisper_full_parallel(
return ret;
}


int whisper_full_batch_parallel(
struct whisper_context *ctx,
struct whisper_full_params params,
const float *const *batches,
const int *size_per_batch,
int n_batches,
int n_processors)
{
int ret = 0;
n_processors = std::min(n_processors, n_batches);
if (n_batches > n_processors)
{
throw std::runtime_error("batch size must be equal to number of processors");
}
// prepare separate states for each thread
std::vector<whisper_state *> states;
std::vector<std::vector<float>> batches_vector;
batches_vector.reserve(n_batches);
for (int i = 0; i < n_batches; ++i)
{
int batch_size = size_per_batch[i];
batches_vector.emplace_back(batches[i], batches[i] + batch_size);
}

// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
const int n_parallel_processes = n_processors - 1;
std::vector<std::thread> workers(n_parallel_processes);
for (int i = 0; i < n_parallel_processes; ++i)
{
if (i + 1 > n_batches - 1)
{
// break when batch not exist for parallel process
break;
}
const float *samples = batches_vector[i + 1].data();
const int n_samples = batches_vector[i + 1].size();
// create a new state for each thread
states.push_back(whisper_init_state(ctx));

auto params_cur = params;

params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;

params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;

params_cur.progress_callback = nullptr;
params_cur.progress_callback_user_data = nullptr;

workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples, n_samples);
}

{
auto params_cur = params;

// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
params_cur.print_realtime = false;

const float *samples = batches_vector[0].data();
const int n_samples = batches_vector[0].size();

// Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, n_samples);
}

for (int i = 0; i < n_parallel_processes; ++i)
{
workers[i].join();
}

// combine results into result_state->result_all from all other states
for (int i = 0; i < n_processors - 1; ++i)
{
auto &results_i = states[i]->result_all;

for (auto &result : results_i)
{

// make sure that segments are not overlapping
if (!ctx->state->result_all.empty())
{
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
}

ctx->state->result_all.push_back(std::move(result));

// call the new_segment_callback for each segment
if (params.new_segment_callback)
{
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
}
}

ctx->state->t_mel_us += states[i]->t_mel_us;

ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us;
ctx->state->t_batchd_us += states[i]->t_batchd_us;
ctx->state->t_prompt_us += states[i]->t_prompt_us;

ctx->state->n_sample += states[i]->n_sample;
ctx->state->n_encode += states[i]->n_encode;
ctx->state->n_decode += states[i]->n_decode;
ctx->state->n_batchd += states[i]->n_batchd;
ctx->state->n_prompt += states[i]->n_prompt;

whisper_free_state(states[i]);
}

// average the timings
ctx->state->t_mel_us /= n_processors;
ctx->state->t_sample_us /= n_processors;
ctx->state->t_encode_us /= n_processors;
ctx->state->t_decode_us /= n_processors;

return ret;
}

int whisper_full_n_segments_from_state(struct whisper_state * state) {
return state->result_all.size();
}
Expand Down