[DRAFT] Token level timestamps with DTW (#375) #1485

Merged

merged 20 commits into ggerganov:master from denersc:feature-token-timestamps-dtw on Mar 20, 2024

Conversation

denersc
Contributor

@denersc denersc commented Nov 13, 2023

Tries to solve #375

Attempt to implement DTW-based token level timestamps, as seen in OpenAI Whisper.

This first commit implements the DTW algorithm in whisper.cpp and provides tests that compare the output with the output of OpenAI's implementation. Tests are done by calling whisper.cpp from Python and comparing the DTW output with OpenAI's dtw_cpu function.

An outline of the remaining work is commented in whisper_exp_compute_token_level_timestamps_dtw in whisper.cpp. Help/insights are very much appreciated, especially concerning how to cache/retrieve the output of the MHA layers that is used as input for DTW.

In OpenAI's implementation, token-level timestamps are combined with further heuristics to determine a supposed start/end time for words. In this PR, my intention is to implement token-level timestamps only, as a first step that can be used to implement word timestamps in the future.
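
For reference, the core DTW pass itself is conceptually small. Below is a minimal, self-contained sketch of the kind of cost-matrix/backtrace routine involved (illustrative only; the function name, container choices and boundary handling are mine and may differ from what ends up in this PR). The input is the negated, preprocessed token-by-frame attention matrix, and the output is a monotonic (token index, frame index) alignment.

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

// x: n_tokens x n_frames cost matrix (the negated attention weights)
static std::vector<std::pair<int, int>> dtw_align(const std::vector<std::vector<float>> & x) {
    const int N = (int) x.size();
    const int M = (int) x[0].size();
    const float INF = std::numeric_limits<float>::infinity();

    // cost[i][j]: minimal accumulated cost of aligning the first i tokens with the first j frames
    std::vector<std::vector<float>> cost(N + 1, std::vector<float>(M + 1, INF));
    std::vector<std::vector<int>>   trace(N + 1, std::vector<int>(M + 1, -1));
    cost[0][0] = 0.0f;

    for (int i = 1; i <= N; ++i) {
        for (int j = 1; j <= M; ++j) {
            const float c0 = cost[i - 1][j - 1]; // advance token and frame
            const float c1 = cost[i - 1][j    ]; // advance token only
            const float c2 = cost[i    ][j - 1]; // advance frame only
            float c; int t;
            if (c0 <= c1 && c0 <= c2) { c = c0; t = 0; }
            else if (c1 <= c2)        { c = c1; t = 1; }
            else                      { c = c2; t = 2; }
            cost[i][j]  = x[i - 1][j - 1] + c;
            trace[i][j] = t;
        }
    }

    // backtrace from the bottom-right corner, collecting (token, frame) pairs
    std::vector<std::pair<int, int>> alignment;
    for (int i = N, j = M; i > 0 && j > 0; ) {
        alignment.emplace_back(i - 1, j - 1);
        const int t = trace[i][j];
        if (t == 0)      { --i; --j; }
        else if (t == 1) { --i; }
        else             { --j; }
    }
    std::reverse(alignment.begin(), alignment.end());
    return alignment;
}

A token's timestamp is then derived from the frame index at which the alignment first reaches that token, multiplied by the frame duration (0.02 s in the 50 frames-per-second alignment grid used by OpenAI's timing code).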

TODO

  • Implement/test DTW algorithm
  • Import index of alignment heads from original whisper.
  • Find a way to cache the QK outputs of the cross-attention layers for the alignment heads (perhaps in whisper_state?) and retrieve them in whisper_exp_compute_token_level_timestamps_dtw
  • Use GGML to pre-process QKs that are passed to DTW. Includes stacking cached QKs, scaling, clipping, normalizing. See whisper_exp_compute_token_level_timestamps_dtw
  • Plug whisper_exp_compute_token_level_timestamps_dtw into whisper_full and use results to place timestamps on each inferred token
  • Plumbing in general, enable/disable as whisper_full_param, etc
  • Implement N_TOP_MOST alignment heads
  • Implement CUSTOM alignment heads, decide comfortable API for setting custom alignment heads
  • Find a way to make whisper_build_graph_decoder save QK copies only when requested, so there is no additional overhead when running the decoder for reasons other than timestamps.
  • Avoid memory allocations on whisper_exp_compute_token_level_timestamps_dtw (probably allocate a buffer for the used tensors on init)
  • (Maybe) check if some operations on whisper_exp_compute_token_level_timestamps_dtw that are currently done with manual for loops can benefit from ggml functions.

@bmurray
Contributor

bmurray commented Nov 16, 2023

This is awesome. I think the big question is where the alignment heads actually reside in GGML.

@denersc
Contributor Author

denersc commented Nov 17, 2023

This is awesome. I think the big question is where the alignment heads actually reside in GGML.

Yes, that is the most critical point at the moment. I suspect that we would need to save this tensor:

// Inside whisper_build_graph_decoder

// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);   // THIS ONE

//struct ggml_tensor * KQ_scaled =
//    ggml_scale(ctx0,
//            KQ,
//            ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
//            );

// no masking for cross-attention
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); // OR THIS ONE

In OpenAI implementation, they do a second pass through the model passing the tokens generated on the first pass to be able to retrieve the correct weights. I think if we are able to that second pass just like in openAI impl, it should just be a matter of saving those KQs and selecting the ones useful for timestamping.

Apparently lintoai whisper-timestamped seems to be able to do in one pass, only when there is no temperature fallback, greedy bestof = 1 and no beam search. If any of those conditions are not met, then it resorts to doing a second pass like on OpenAI impl.

In what concerns of selecting heads that are useful for timestamping, from what i have seen, a accepted default would be the heads of the top most half of decoder layers (the ones closer to model output). A more optimal selection, which is provided on OpenAIs whisper individually by model, was apparently obtained by manual inspection . Note that alignment heads differ even for models with same dimensions (e.g. medium and medium.en have different alignment heads specified)

We could still use these indexes for the usual pre-trained models but offer the option to use the n top-most decoder layers OR some custom index i guess.
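
For the "n top-most decoder layers" fallback, the enumeration itself is trivial. Something along these lines (an illustrative sketch of mine, not code from this branch):

#include <utility>
#include <vector>

// All attention heads of the n_top top-most decoder (text) layers,
// returned as (layer, head) index pairs. Assumes 0 < n_top <= n_text_layer.
static std::vector<std::pair<int, int>> aheads_n_top_most(int n_text_layer, int n_text_head, int n_top) {
    std::vector<std::pair<int, int>> heads;
    for (int il = n_text_layer - n_top; il < n_text_layer; ++il) {
        for (int ih = 0; ih < n_text_head; ++ih) {
            heads.emplace_back(il, ih);
        }
    }
    return heads;
}

The per-model presets imported from OpenAI would then simply be hard-coded lists of such (layer, head) pairs.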

@linmi

linmi commented Nov 28, 2023

Looking forward to the progress 🎉

@denersc
Contributor Author

denersc commented Nov 28, 2023

Haven't been able to work on this for the past week, but I'm making some progress now!

Trying to get a very poorly implemented end-to-end POC for the base.en model. Having some trouble figuring out if I'm correctly generating/retrieving the needed cross-attention QKs.

Validated most of the pre-processing operations done on the QKs before passing them to DTW to get timestamps. So, taking the OpenAI implementation in timing.py (with comments added by me):

    # Implemented on whisper.cpp, seems ok
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # I believe that i am making a mistake somewhere in this block
    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]
    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
        token_probs = sampled_logits.softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
        text_token_probs = text_token_probs.tolist()

    for hook in hooks:
        hook.remove()

    # Implemented poorly in whisper.cpp (only using base.en alignment heads for now)
    # heads * tokens * frames
    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)

    # Implemented and validated on whisper.cpp 
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)
    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    text_indices, time_indices = dtw(-matrix)

Now, to get the whisper.cpp QKs, I temporarily added a decoder_save_cross_QKs bool to whisper_state. If enabled, it will save QKs during decoding. So:

// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); // THIS TENSOR HERE!
if (wstate.decoder_save_cross_QKs) {
    wstate.cross_QKs.push_back(KQ_soft_max);
}


// After inference, when computing timestamps 
state->decoder_save_cross_QKs = true;
if (whisper_decode_with_state(ctx, state, tokens.data(), tokens.size(), 0, n_threads) != 0) { 
    WHISPER_LOG_INFO("DECODER FAILED\n");
}    
state->decoder_save_cross_QKs = false;

Doing this, I get a set of QKs to work with inside my timestamping function, and the number of tensors retrieved and their dimensions are in line with what is retrieved in the OpenAI impl.

Now, since the output I got so far for timestamps is complete garbage, and most of the operations outside the retrieval of QKs seem correct, I imagine something is wrong when I call whisper_decode_with_state. I am unsure if I need to do additional setup before calling whisper_decode_with_state, or if the ggml_tensor pointers I saved are invalid or do not contain what I expect.

Any insight into how to correctly retrieve these QKs from the decoder cross-attention layers is very welcome! I'm currently working on validating all operations other than QK retrieval, to be absolutely sure this is where my mistake is.

@denersc
Contributor Author

denersc commented Nov 30, 2023

Currently stuck on retrieving the attention weights from the decoder. In all my attempts, I either get tensors with floats > 1 (which indicates that they are not the output of the softmax layer I'm trying to retrieve), null tensors, or tensors with a null "data" pointer.

What I have tried so far (inside whisper_build_graph_decoder):

  • Directly pushing KQ_soft_max into a vector on the state yields a tensor that does not seem like the output of softmax (floats > 1)
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
wstate.cross_QKs.push_back(KQ_soft_max);


// Later on, e.g. on timestamping function
// Many values are > 1, indicating that they are not the output of softmax
float v = ggml_get_f32_nd(state.cross_QKs[0], 0, 0, 0, 0);
  • Using ggml_cpy on KQ_soft_max hoping that a copy of it would be later retrievable. This yields a tensor whose data field points to NULL
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
struct ggml_tensor * KQ_copy = ggml_cpy(ctx0, KQ_soft_max, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, KQ_soft_max->ne[0], KQ_soft_max->ne[1], KQ_soft_max->ne[2]));
wstate.cross_QKs.push_back(KQ_copy);

// Later on, e.g. on timestamping function
if (state.cross_QKs[0]->data == NULL)
  // this is true
  • Permutations of the above, but instead of pushing these tensors to a vector, setting a name on them and trying to retrieve them by name right after ggml_graph_compute_helper inside whisper_decode_internal. This yields either a tensor that is clearly not softmax output (values > 1) if done directly on KQ_soft_max, or a null pointer if done on the copy:
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

struct ggml_tensor * KQ_copy = ggml_cpy(ctx0, KQ_soft_max, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, KQ_soft_max->ne[0], KQ_soft_max->ne[1], KQ_soft_max->ne[2]));
char name[256];
snprintf(name, 256, "cross_QK_%d", il);
ggml_set_name(KQ_soft_max, name);
snprintf(name, 256, "cross_QK_copy_%d", il);
ggml_set_name(KQ_copy, name);


// Later on, right after ggml_graph_compute_helper inside whisper_decode_internal
ggml_graph_compute_helper(wstate.backend, gf, n_threads);

struct ggml_tensor * cross_QK = ggml_graph_get_tensor(gf, "cross_QK_0");
struct ggml_tensor * cross_copy = ggml_graph_get_tensor(gf, "cross_QK_copy_0");
float v = ggml_get_f32_nd(cross_QK, 0, 0, 0, 0); // v > 1, not softmax output
if (cross_copy == NULL)
  // this is true

@ggerganov would you be able to nudge me in the right direction on how I would go about saving the values of the KQ_soft_max cross-attention tensor inside the whisper_build_graph_decoder function, so I can re-use them for timestamping? I've been reading the ggml source code but so far haven't been able to figure it out. Thanks!

@denersc
Contributor Author

denersc commented Dec 1, 2023

So apparently I was missing the fact that I had to call ggml_build_forward_expand on my copy of the QKs, since they would not otherwise be computed, as the decoder output does not depend on those copies.
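
Roughly, the missing piece looks like this (a sketch only; the exact variable names in whisper_build_graph_decoder may differ):

// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
if (wstate.decoder_save_cross_QKs) {
    struct ggml_tensor * KQ_copy = ggml_cpy(ctx0, KQ_soft_max,
            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32,
                    KQ_soft_max->ne[0], KQ_soft_max->ne[1], KQ_soft_max->ne[2]));

    // Without this, the copy is never scheduled: no graph output depends on it,
    // so it is silently dropped when the graph is computed.
    ggml_build_forward_expand(gf, KQ_copy);

    wstate.cross_QKs.push_back(KQ_copy);
}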

With that done, I finally got correct output for timestamps. Running over the jfk.wav sample with the base.en model, I got these timestamps:

|[_NOT_]|(0.00) | And|(0.52) | so|(0.78) | my|(1.16) | fellow|(1.60) | Americans|(2.08) |,|(3.48) | ask|(3.84) | not|(4.30) | what|(5.64) | your|(5.92) | country|(6.40) | can|(6.80) | do|(7.04) | for|(7.32) | you|(7.62) |,|(7.64) | ask|(8.66) | what|(9.08) | you|(9.36) | can|(9.68) | do|(9.94) | for|(10.18) | your|(10.36) | country|(10.82) |.|(11.22) 

Which is near identical to the timestamps retrieved with the OpenAI impl before their heuristics to determine the start/end of each token:

|<|notimestamps|>|(0.0) | And|(0.52) | so|(0.78) | my|(1.16) | fellow|(1.6) | Americans|(2.08) |,|(3.46) | ask|(3.84) | not|(4.3) | what|(5.64) | your|(5.92) | country|(6.4) | can|(6.8) | do|(7.04) | for|(7.32) | you|(7.62) |,|(7.640000000000001) | ask|(8.66) | what|(9.08) | you|(9.36) | can|(9.68) | do|(9.94) | for|(10.18) | your|(10.36) | country|(10.86) |.|(11.200000000000001) 

I'll be working on correctly plumbing everything together, since I poorly stitched everything together just to see end-to-end execution.

@bmurray
Contributor

bmurray commented Dec 2, 2023

Is this committed to your fork? I would love to try it out. I have some audio that I consider to be a challenge and that the OpenAI Python implementation handles; I would love to compare and test it.

@RRUK01

RRUK01 commented Dec 3, 2023

Happy to help with this if you're able to commit your latest updates.

@denersc
Contributor Author

denersc commented Dec 4, 2023

I'll try to commit the changes between today and tomorrow so you guys can give it a spin!

@denersc denersc force-pushed the feature-token-timestamps-dtw branch from fbd390d to ab5cd86 Compare December 4, 2023 19:50
@denersc
Contributor Author

denersc commented Dec 4, 2023

I've committed these recent results to the fork. Currently, the timestamps are only being placed if you run with params.single_segment=true, as there is some more work needed to get everything aligned in the case where a single inference run produces more than one segment.

To run with DTW timestamps, you need to enable them in the params and select a collection of alignment heads. I've imported the ones from the OpenAI impl for each model, but have yet to implement setting custom alignment heads or using the N top-most heads for alignment.

So, e.g. for main.cpp, you need to set:

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 9699802..c101434 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -942,6 +942,10 @@ int main(int argc, char ** argv) {
 
             wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
 
+            wparams.dtw_token_timestamps = true;
+            wparams.dtw_ah_preset = WHISPER_AHEADS_BASE_EN; // Match correctly with the model you are using.
+            wparams.single_segment = true;
+
             wparams.print_realtime   = false;
             wparams.print_progress   = params.print_progress;
             wparams.print_timestamps = !params.no_timestamps;

and timestamps will be placed in the token_data struct for each token, in the t_dtw field.

auto token_data =  whisper_full_get_token_data(ctx, i_segment, i_token);
float timestamp_seconds = (float)token_data.t_dtw/100;
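
For example, dumping the DTW timestamp of every token after whisper_full could look roughly like this (a sketch; it assumes tokens without a DTW timestamp, such as special tokens, carry a negative t_dtw):

for (int i_segment = 0; i_segment < whisper_full_n_segments(ctx); ++i_segment) {
    for (int i_token = 0; i_token < whisper_full_n_tokens(ctx, i_segment); ++i_token) {
        const auto   token_data = whisper_full_get_token_data(ctx, i_segment, i_token);
        const char * token_text = whisper_full_get_token_text(ctx, i_segment, i_token);
        if (token_data.t_dtw >= 0) {
            printf("%s -> %.2f s\n", token_text, (float) token_data.t_dtw / 100.0f); // t_dtw is in centiseconds
        }
    }
}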

I've left several FIXMEs in whisper.cpp and whisper.h outlining some of the work needed to conclude this implementation. I'll be updating the TODO list on this PR soon so it reflects the current issues.

I've only tested so far on jfk.wav and with the base.en model, so it is very possible you will run into some problems during execution.

@denersc
Contributor Author

denersc commented Dec 7, 2023

So, I've updated the TODO list and pushed some fixes:

  • Setting single_segment is no longer needed. Timestamps should work correctly if an inference run has multiple segments.
  • Fixed a compile error related to whisper_alignment_heads_preset on whisper_full_params that occurred with some compilers.
  • Fixed an incorrect assertion that would cause the DTW timestamp function to break if you didn't set params.audio_ctx to any value.
  • Fixed a detail in the final timestamping loop that would cause gradual misalignment of timestamps on longer segments.

Any help with the current TODO list is very welcome, especially the last 3 items.

I've made some tests with the a13.wav audio and will continue validating with other cases. Any testing is highly appreciated!

@denersc denersc force-pushed the feature-token-timestamps-dtw branch 2 times, most recently from 936be88 to f8d1b1f Compare December 14, 2023 17:27
@denersc
Contributor Author

denersc commented Dec 19, 2023

Just an update: I haven't been able to work on the last few issues lately and probably won't be completing this in December. Nevertheless, I've been doing some tests with what's implemented so far and the timestamps seem to be working as expected.

About the last couple of issues, what I'm having the most trouble figuring out is "Avoid memory allocations on whisper_exp_compute_token_level_timestamps_dtw".

I'm not very confident about how to determine the total memory needed in each case, since some operations are done "manually" over ggml tensors instead of using a ggml graph.

Also, the memory needed can vary greatly (from <16 MB to >2 GB) depending on factors such as how many tokens the model yielded in this run, the audio_ctx size, and the number of alignment heads (which is bounded by the total number of attention heads). This can all be taken into account to allocate only the minimum necessary for the worst case.
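
As a back-of-the-envelope illustration of that worst case (my own sketch, not code from this PR; the helper name and the exact set of tensors are assumptions), the dominant term is the stacked head x token x frame attention matrix:

#include <cstddef>

// Rough upper bound on the scratch memory needed by the DTW step for the
// worst case a given state can produce: every alignment head, the maximum
// number of text tokens, and the full audio context.
static size_t dtw_worst_case_bytes(size_t n_aheads, size_t n_tokens_max, size_t n_frames_max) {
    const size_t qk_bytes  = n_aheads * n_tokens_max * n_frames_max * sizeof(float);  // stacked QKs
    const size_t dtw_bytes = (n_tokens_max + 1) * (n_frames_max + 1) *
                             (sizeof(float) + sizeof(int));                           // cost + trace matrices
    return qk_bytes + dtw_bytes;
}

Plugging in large-model dimensions (hundreds of alignment heads, hundreds of tokens, up to 1500 frames) puts this in the gigabyte range, while few heads and a short audio_ctx stay in the tens of megabytes, which is consistent with the variation described above.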

Any help on these final issues would be very appreciated!

@bmurray
Contributor

bmurray commented Dec 21, 2023

I've been running some tests on this today. It's not perfect, but it's definitely way, way better. I have a test video that I've been using that has a lot of easy-to-find issues, which lets me compare the released version vs this branch (and against the Python implementation). It starts with a 12-second silent section, and has a few silent sections peppered throughout (link below). It should be identifying the first word as starting at about 12.4 seconds in, but it's identifying it as 7.6 seconds in. That's a lot better than the original, which was placing it at 0 seconds in. The second and third tokens seem to be correctly identified though.

Overall, this is incredible work! I'm going to submit a PR here shortly that should expose this a little more easily in main.cpp and the Go bindings. Is there a model auto-detect already? I'm not seeing it here, but I've only just barely started looking at the actual code.

Queen 'We will meet again' speech.
https://www.youtube.com/watch?v=2klmuggOElE

@denersc
Contributor Author

denersc commented Dec 21, 2023

Hey @bmurray, thanks for testing! I'll try to run this audio here as well. May I ask which model size and alignment head preset you were using?

About that first token with the wrong timestamp, I will try to test it on the OpenAI impl as well to see their DTW output. It might be the case that they get the same output and fix it in the subsequent heuristics after running DTW. If you check their implementation, what we have implemented stops about here. After returning from that function, they still do some work that seems to be especially aimed at improving timestamps at the boundaries.

About model auto-detect, I'm not sure how we can implement that. Since alignment heads are different for models with the same size (e.g. large-v1 and large-v2 have different alignment heads), we might need user input to guarantee we are using the correct alignment heads. Not sure if I'm mistaken though.

Just a PS, I'll be out after this week and will probably return to this only in early/mid January, but I'll happily address any comments when I return.

@bobqianic bobqianic added the help wanted, high priority and research🔬 labels on Jan 9, 2024
Comment on lines +6711 to +6877
// dtw
// supposedly can be optmized by computing diagonals in parallel ?
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
Collaborator
How long does this step take?

Contributor Author

@denersc denersc Jan 17, 2024

So, I've just made some measurements and it is rather insignificant compared to the rest of the function. On an Apple M2 with Metal enabled, it takes less than 1% of the whole timestamping process. (Moved the longer analysis to the PR comments.)

@denersc denersc force-pushed the feature-token-timestamps-dtw branch from a3cbe38 to add2db7 Compare March 12, 2024 12:55
@denersc denersc force-pushed the feature-token-timestamps-dtw branch from add2db7 to 10b0304 Compare March 12, 2024 12:58
@denersc denersc marked this pull request as ready for review March 12, 2024 17:49
@ggerganov
Owner

The DTW timestamps can now be generated with main using the -ojf CLI argument:

./main -m models/ggml-small.bin -f samples/gb0.wav -dtw small -ojf
{
			"timestamps": {
				"from": "00:02:02,960",
				"to": "00:02:05,520"
			},
			"offsets": {
				"from": 122960,
				"to": 125520
			},
			"text": " by the will of the people.",
			"tokens": [
				{
					"text": " by",
					"timestamps": {
						"from": "00:02:02,960",
						"to": "00:02:03,180"
					},
					"offsets": {
						"from": 122960,
						"to": 123180
					},
					"id": 538,
					"p": 0.999897,
					"t_dtw": 12312
				},
				{
					"text": " the",
					"timestamps": {
						"from": "00:02:03,180",
						"to": "00:02:03,510"
					},
					"offsets": {
						"from": 123180,
						"to": 123510
					},
					"id": 264,
					"p": 0.999729,
					"t_dtw": 12328
				},
				{
					"text": " will",
					"timestamps": {
						"from": "00:02:03,510",
						"to": "00:02:03,590"
					},
					"offsets": {
						"from": 123510,
						"to": 123590
					},
					"id": 486,
					"p": 0.997792,
					"t_dtw": 12354
				},
				{
					"text": " of",
					"timestamps": {
						"from": "00:02:04,140",
						"to": "00:02:04,170"
					},
					"offsets": {
						"from": 124140,
						"to": 124170
					},
					"id": 295,
					"p": 0.999649,
					"t_dtw": 12430
				},
				{
					"text": " the",
					"timestamps": {
						"from": "00:02:04,170",
						"to": "00:02:04,500"
					},
					"offsets": {
						"from": 124170,
						"to": 124500
					},
					"id": 264,
					"p": 0.999611,
					"t_dtw": 12440
				},
				{
					"text": " people",
					"timestamps": {
						"from": "00:02:04,500",
						"to": "00:02:05,090"
					},
					"offsets": {
						"from": 124500,
						"to": 125090
					},
					"id": 561,
					"p": 0.999641,
					"t_dtw": 12482
				},
				{
					"text": ".",
					"timestamps": {
						"from": "00:02:05,200",
						"to": "00:02:05,440"
					},
					"offsets": {
						"from": 125200,
						"to": 125440
					},
					"id": 13,
					"p": 0.998121,
					"t_dtw": 12512
				},
				{
					"text": "[_TT_416]",
					"timestamps": {
						"from": "00:02:05,520",
						"to": "00:02:05,520"
					},
					"offsets": {
						"from": 125520,
						"to": 125520
					},
					"id": 50780,
					"p": 0.280601,
					"t_dtw": -1
				}
			]
		},

@ggerganov ggerganov merged commit 741abb1 into ggerganov:master Mar 20, 2024
45 of 50 checks passed
jiahansu pushed a commit to WiseSync/whisper.cpp that referenced this pull request Apr 17, 2024
* whisper.cpp: impl dtw algo

* WIP: producing and placing DTW timestamps on tokens

* Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false.

* Fix mistake causing incorrect alignment of dtw timestamps

* implement N_TOP_MOST and CUSTOM alignment heads setting

* whisper: fix typo on alignment heads enum

* Fix issues related to changes in whisper.cpp

* Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function

* decoder: save cross QKs only if requested

* Calling median filter with ggml_map_custom1

* Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads

* Copying cross QKs from decoder backend correctly

* dtw: cleanup

* Fix incorrect n_frames passed to dtw when near end of audio

* Fix aheads_masks_init for backend != CPU

* whisper : minor style

* main : add dtw (wip)

* whisper: fix invalid memory access in aheads_masks_init

* main : add dtw (cont)

* whisper : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@hlevring

hlevring commented May 4, 2024

I am looking to derive word-level timestamps from the DTW timestamps. Some tokens need to be joined to form a complete word, so as a start I figured I would just concatenate tokens until a space or punctuation mark is encountered, and from there use the tokens' timestamps to determine the start and end times for each word.
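
A minimal sketch of that grouping (my own illustration, not something in this PR; it assumes t_dtw is in centiseconds and treats a leading space in the token text as a word boundary):

#include <string>
#include <vector>

#include "whisper.h"

struct word_timestamp {
    std::string text;
    float t0; // seconds
    float t1; // seconds
};

static std::vector<word_timestamp> words_from_dtw(struct whisper_context * ctx, int i_segment) {
    std::vector<word_timestamp> words;
    const int n_tokens = whisper_full_n_tokens(ctx, i_segment);
    for (int i = 0; i < n_tokens; ++i) {
        const char * text = whisper_full_get_token_text(ctx, i_segment, i);
        const auto   data = whisper_full_get_token_data(ctx, i_segment, i);
        if (data.t_dtw < 0) {
            continue; // special tokens carry no DTW timestamp
        }
        const float t = (float) data.t_dtw / 100.0f;
        if (words.empty() || text[0] == ' ') {
            words.push_back({ text, t, t });  // a new word starts here
        } else {
            words.back().text += text;        // continuation of the previous word
            words.back().t1    = t;           // extend the word's end time
        }
    }
    return words;
}

Since a DTW timestamp is closer to where a token ends, a better word start is probably the end time of the previous word rather than the first token's own timestamp, and punctuation tokens likely need special handling (see the comments below).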

I am pretty sure there must be way more to consider. @denersc it would be great if you could comment on this - and hey, super thanks for your work on the DTW timestamps.

@eschmidbauer
Contributor

Not sure if anyone has noticed this too, but the DTW timestamps seem to be completely inaccurate in some segments of the transcript, while in other segments they are very precise.

@hlevring

hlevring commented May 5, 2024

I was experimenting yesterday and saw forward shifts in token timestamps prior to silence periods (say, pauses in speech of 1-2 seconds). In my tests I did have a good bit of background noise that might have been a factor. I will try to redo some experiments in the next few days.

@denersc
Contributor Author

denersc commented May 6, 2024

Hey @hlevring and @eschmidbauer, thanks for trying it out!

So, I'll try to address what you guys said, but unfortunately I don't think I can provide perfect answers.

First, when thinking about DTW timestamps, I crudely rationalize them as "an estimate of the moment the model decided to output a certain token".

So, in a common speech flow, like between evenly paced words in a sentence, it is very likely that the DTW timestamp of the last token in a word will be very close to the actual time of the end of the word.

Nevertheless, it is not unusual for the model to output some token long after it actually occurred in the audio. In that case, the DTW timestamp will likely be incorrect. The most common example of this is the period (.) token. It can be output by the model after a stretch of silence that follows a sequence of words, so the DTW timestamp for the period will be long after the end of the sentence, and should probably be ignored.
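
One simple way to act on that (an illustrative heuristic of mine, not something this PR does) is to clamp a punctuation token's DTW timestamp when it jumps far past the previous token's, instead of trusting it:

#include <cstdint>

// t values in centiseconds, as in whisper_token_data::t_dtw
static int64_t clamp_punct_timestamp(int64_t t_prev, int64_t t_cur, bool is_punct, int64_t max_gap = 50) {
    if (is_punct && t_cur - t_prev > max_gap) {
        return t_prev; // assume the punctuation ends with the preceding word
    }
    return t_cur;
}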

Although the period, and punctuation in general, is the most common occurrence of this, I don't doubt this kind of thing can happen with words, e.g. the model outputting 3 words almost simultaneously because only once it understood the third word could it actually infer all three. In that case, the first 2 words will have very imprecise timestamps. Although I think this may be possible, it does not seem to be very likely, at least in the sense that I haven't observed it directly.

All of this is to say that DTW timestamps are an imperfect source of information, and should be used with some caution and combined with other data to provide good word timestamp estimates. OpenAI tries to address some of these issues.

Of course, it may also be the case that my implementation is incorrect at some point, maybe in the step of selecting and saving alignment heads, or when performing some matrix operations. I think a good starting point to check that would be to compare the DTW timestamps given by the OpenAI impl with the ones from my implementation for a variety of audios and see if there are any large discrepancies. Some very small variance is bound to happen, probably because of different matrix operation implementations.

Finally, I'm more on the developer side than on the ML research side, and my math understanding is reasonably shallow. So take all I said with a grain of salt 😬

@eschmidbauer
Contributor

I think a good starting point to check that would be to compare the DTW timestamps given by openAI impl with the ones in my implementation for a variety of audios and see if there are any large discrepancies.

This is a great idea! I'll set some time aside to compare the two & post my findings

@denersc
Contributor Author

denersc commented May 6, 2024

Also, I forgot to say: make sure the whisper.cpp version you guys are using includes PR #2012. That bug likely caused incorrect alignment head selection, beyond the observed memory error.

Cool @eschmidbauer. In whisper.cpp, you can uncomment these lines if you want to print DTW timestamps.

You might need to change the code in the OpenAI package to get the actual raw DTW timestamps, since they don't provide them to the end user. They do a lot of additional processing before returning them, so those will be different for sure. You'll probably have to add some code to retrieve them and do some sort of loop equivalent to what I did here to get timestamps for each token. Those will be the raw timestamps, which are comparable to the ones I made available in whisper.cpp.

viktor-silakov pushed a commit to viktor-silakov/whisper_node_mic.cpp that referenced this pull request May 11, 2024
iThalay pushed a commit to iThalay/whisper.cpp that referenced this pull request Sep 23, 2024
Successfully merging this pull request may close these issues:

  • Inaccurate token time
  • Timestamp accuracy needs to be improved