-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make chuking smartly (long files) work on asr ctc_with_lm. #15219
Changes from 3 commits
7de47f9
7875652
a81f758
33dfac1
9cf3615
d96abe9
c37f396
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,14 +66,27 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array: | |
return audio | ||
|
||
|
||
def apply_stride(tokens, stride): | ||
max_token_n = tokens.shape[-1] | ||
def audio_to_logits(tokens_or_logits, stride): | ||
# Shape is [B, SEQ] for tokens | ||
# [B, SEQ, V] for logits | ||
max_token_n = tokens_or_logits.shape[1] | ||
max_input_n = max(input_n for input_n, _, _ in stride) | ||
ratio = max_token_n / max_input_n | ||
for i, (input_n, left, right) in enumerate(stride): | ||
new_strides = [] | ||
for input_n, left, right in stride: | ||
token_n = int(round(input_n * ratio)) | ||
left_token = int(round(left / input_n * token_n)) | ||
right_token = int(round((input_n - right) / input_n * token_n)) | ||
left = int(round(left / input_n * token_n)) | ||
right = int(round(right / input_n * token_n)) | ||
new_stride = (token_n, left, right) | ||
new_strides.append(new_stride) | ||
return new_strides | ||
|
||
|
||
def apply_stride(tokens, stride): | ||
new_stride = audio_to_logits(tokens, stride) | ||
for i, (input_n, left, right) in enumerate(new_stride): | ||
left_token = left | ||
right_token = input_n - right | ||
# This is CTC to preseve decoding, we need to duplicate | ||
# next letter, and last letter | ||
|
||
|
@@ -215,7 +228,7 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): | |
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate)) | ||
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate)) | ||
|
||
if self.type != "ctc": | ||
if self.type not in {"ctc", "ctc_with_lm"}: | ||
raise ValueError( | ||
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models" | ||
) | ||
|
@@ -244,9 +257,15 @@ def _forward(self, model_inputs): | |
) | ||
out = {"tokens": tokens} | ||
elif self.type == "ctc_with_lm": | ||
stride = model_inputs.pop("stride", None) | ||
outputs = self.model(**model_inputs) | ||
out = {"logits": outputs.logits} | ||
|
||
logits = outputs.logits | ||
out = {"logits": logits} | ||
if stride is not None: | ||
# Send stride to `postprocess`. | ||
# it needs to be handled there where | ||
# the pieces are to be concatenated. | ||
out["stride"] = audio_to_logits(logits, stride) | ||
elif self.type == "ctc": | ||
stride = model_inputs.pop("stride", None) | ||
outputs = self.model(**model_inputs) | ||
|
@@ -266,7 +285,20 @@ def _forward(self, model_inputs): | |
|
||
def postprocess(self, model_outputs): | ||
if self.type == "ctc_with_lm": | ||
logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1) | ||
final_logits = [] | ||
for outputs in model_outputs: | ||
logits = outputs["logits"].numpy() | ||
stride = outputs.get("stride", None) | ||
if stride is not None: | ||
total_n, left, right = stride | ||
# Total_n might be < logits.shape[1] | ||
# because of padding, that's why | ||
# we need to reconstruct this information | ||
# This won't work with left padding (which doesn't exist right now) | ||
right_n = total_n - right | ||
logits = logits[:, left:right_n] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, yes I think that's the only approach that'll work right now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Padding is not possible here sadly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Padding works, it's handled by the batching mecanism. All I mentionned here is why we need so much information ( We can make it work, relatively trivially for left padding. if self.feature_extractor.padding_side == "left":
left = logits.shape[1] - total_n + left_n
right = logits.shaoe[1] - total_n + right_n
else:
left = left_n
right_n = total_n - right_n Just thought this was overly complex since left padding doesn't seem likely here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see - ok yeah |
||
final_logits.append(logits) | ||
logits = np.concatenate(final_logits, axis=1) | ||
logits = logits.squeeze(0) | ||
text = self.decoder.decode_beams(logits)[0][0] | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly does the function do? Could we add some docstring? Also I don't understand the name
audio_to_logits
reallyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's just change the stride number to go from
audio space
(10s at 16_000 means (160_000, 8_000, 8_000)) for instance. tologits_space
(2333, 160, 160) for instance.Do you think of a better name ? Doctstring could help a little.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see. Maybe
get_output_stride_from_input_stride(input_shape, stride)
and directly pass the shape? Yeah think a little docstring can help here