Revert "whisper : remove extra backend instance (huh?)" (#2182)
This reverts commit 4caa64b.
ggerganov authored May 27, 2024
1 parent a7dc2aa commit 05042a7
Showing 1 changed file with 15 additions and 4 deletions.
whisper.cpp — 19 changes: 15 additions & 4 deletions
@@ -818,6 +818,8 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -2261,7 +2263,7 @@ static bool whisper_encode_internal(
         }
 
         if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
                 return false;
             }
         } else {
@@ -2284,7 +2286,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2300,7 +2302,7 @@ static bool whisper_encode_internal(
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -2801,7 +2803,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
             return false;
         }
     }
@@ -3248,6 +3250,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    state->backend = whisper_backend_init(ctx->params);
+    if (!state->backend) {
+        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
+        whisper_free_state(state);
+        return nullptr;
+    }
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
@@ -3684,6 +3693,8 @@ void whisper_free_state(struct whisper_state * state) {
     ggml_gallocr_free(state->alloc_cross.alloc);
     ggml_gallocr_free(state->alloc_decode.alloc);
 
+    ggml_backend_free(state->backend);
+
     // [EXPERIMENTAL] Token-level timestamps with DTW
     aheads_masks_free(state->aheads_masks);
 

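Note: the following is a hypothetical usage sketch, not part of the commit. It illustrates the effect of the revert from the caller's side: every whisper_state created through whisper_init_state() once again owns its own ggml backend, which the diff shows being set up via whisper_backend_init(ctx->params) and released in whisper_free_state() via ggml_backend_free(). The model path is a placeholder; all API calls are the public whisper.h functions.

#include "whisper.h"

int main() {
    // Placeholder model path -- point this at a real ggml model file.
    const char * model_path = "models/ggml-base.en.bin";

    struct whisper_context_params cparams = whisper_context_default_params();

    // Load the model without an internal state, since we manage the state explicitly.
    struct whisper_context * ctx = whisper_init_from_file_with_params_no_state(model_path, cparams);
    if (ctx == nullptr) {
        return 1;
    }

    // Per the diff, whisper_init_state() now also initializes a per-state backend;
    // if whisper_backend_init() fails, it logs an error and returns nullptr.
    struct whisper_state * state = whisper_init_state(ctx);
    if (state == nullptr) {
        whisper_free(ctx);
        return 1;
    }

    // ... run whisper_full_with_state(ctx, state, ...) on audio samples here ...

    whisper_free_state(state); // releases the per-state backend via ggml_backend_free()
    whisper_free(ctx);
    return 0;
}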