Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix chunk_length_s instead of _ms. #15029

Merged
merged 4 commits into from
Jan 4, 2022

Conversation

Narsil
Copy link
Contributor

@Narsil Narsil commented Jan 4, 2022

And fixes issue that the filled in token was the padded_token which could lead to incorrect decoding.
Using the same token for CTC is better (prevent extra repetition)

What does this PR do?

Fixes # (issue)

@patrickvonplaten

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

The input length for in each chunk. If `0` then chunking is disabled (default). Only available for CTC
models.
stride_length_ms (`int`, *optional*, defaults to `chunk_length_ms / 6`):
The length of stride on the left and right of each chunk. Used only with `chunk_length_ms > 0`. This
stride_length_s (`int`, *optional*, defaults to `chunk_length_s / 6`):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type should be float IMO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True !

inputs_len = len(inputs)
chunk_len = chunk_length_ms * self.feature_extractor.sampling_rate // 1000
stride_len = stride_length_ms * self.feature_extractor.sampling_rate // 1000
chunk_len = chunk_length_s * self.feature_extractor.sampling_rate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
chunk_len = chunk_length_s * self.feature_extractor.sampling_rate
chunk_len = int(chunk_length_s * self.feature_extractor.sampling_rate)

is safer no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, need to round too actually !

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@patrickvonplaten
Copy link
Contributor

Feel free to merge as soon as the tests are passing :-)

@Narsil Narsil merged commit 19d37c2 into huggingface:master Jan 4, 2022
@Narsil Narsil deleted the hotfix_chunking branch January 4, 2022 13:07
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Jan 6, 2022
* Hotfix `chunk_length_s` instead of `_ms`.

* Adding fix of `pad_token` which should be last/previous token for CTC

proper decoding

* Fixing ChunkPipeline unwrapping.

* Adding a PackIterator specific test.

def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
Copy link

@systemdevart systemdevart Mar 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does stride mean in this context? Because step doesn't equal chunk_size + stride_left, it is kind of hard to interpret what exactly stride_left or stride_right represent.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you calculate the step in this way, the actual overlap with the left chunk would be 2 * (stride_left + stride_right).
Consider, for example, chunk_len = 15, stride_left = 3, stride_right = 3, and step = 15-3-3 = 9.
0 chunk: start=0, end=15,
1 chunk: start=9, end=24,
2 chunk: start=18, end=33

For 1 chunk, the overlap with the 0 chunk is 6, and the overlap with the 2 chunk is 6, totaling 12 of overlap, but the intended overlap was 3 with the left chunk and 3 with the right chunk, totaling 5 of overlap.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants