Add Conformer RNN-T model prototype #2322

Closed
wants to merge 9 commits into from

@hwangjeff hwangjeff commented Apr 7, 2022

Adds a Conformer RNN-T model as a prototype feature via the factory functions conformer_rnnt_model and conformer_rnnt_base; the latter instantiates a baseline version of the model. Also includes the following:

  • Modifies Conformer to accept arguments use_group_norm and convolution_first to pass to each of its ConformerLayer instances.
  • Makes _Predictor an abstract class and introduces _EmformerEncoder and _ConformerEncoder.
  • Introduces tests for conformer_rnnt_model.
  • Adds docs.
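The first bullet above (forwarding use_group_norm and convolution_first from the container to each layer) can be sketched roughly as follows. This is a toy illustration only: ToyConformer and ToyConformerLayer are hypothetical stand-ins, not the torchaudio classes, and they model only the flag plumbing and module ordering, with no real computation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyConformerLayer:
    """Hypothetical stand-in for one Conformer layer (no real computation)."""

    input_dim: int
    use_group_norm: bool = False
    convolution_first: bool = False

    def module_order(self) -> List[str]:
        # Assumption: convolution_first places the convolution module
        # ahead of self-attention; otherwise attention comes first.
        middle = ["conv", "self_attn"] if self.convolution_first else ["self_attn", "conv"]
        return ["ffn1", *middle, "ffn2"]

    def norm_type(self) -> str:
        # Assumption: the flag selects the normalization used in the conv module.
        return "group_norm" if self.use_group_norm else "batch_norm"


class ToyConformer:
    """Hypothetical container that forwards both flags to every layer."""

    def __init__(
        self,
        input_dim: int,
        num_layers: int,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        self.layers = [
            ToyConformerLayer(
                input_dim,
                use_group_norm=use_group_norm,
                convolution_first=convolution_first,
            )
            for _ in range(num_layers)
        ]


model = ToyConformer(input_dim=80, num_layers=3, use_group_norm=True, convolution_first=True)
print(model.layers[0].module_order(), model.layers[0].norm_type())
```

Keeping both flags on the container with defaults of False means existing callers that never pass them see no change in behavior.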

@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch 2 times, most recently from eb2c046 to c255d6b Compare April 7, 2022 05:03
@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch from c255d6b to 4e31ff2 Compare April 7, 2022 06:07
@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch from 5619f61 to 490a981 Compare April 8, 2022 19:58
@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch 2 times, most recently from ef91d94 to 0514f97 Compare April 9, 2022 15:21
@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch from 0514f97 to 84e6107 Compare April 9, 2022 19:34
@hwangjeff hwangjeff marked this pull request as ready for review April 11, 2022 14:29
@xiaohui-zhang
Contributor

otherwise the PR looks good

from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber


class _ConformerTranscriber(torch.nn.Module, _Transcriber):
Contributor

encoder might be a more standard name than transcriber?

Contributor Author

i can rename the class and associated variables in another pr, as there may be bc-breaking implications

Contributor Author

actually, renamed the concrete classes here — will handle the abstract class separately
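For readers following the naming thread, here is a minimal sketch of how a concrete class can be renamed while the abstract interface stays untouched. Everything in it is hypothetical (ToyTranscriber, ToyConformerEncoder, and the alias are illustrative names), and the backward-compatible alias is one common way to soften a rename, not necessarily what this PR does.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple


class ToyTranscriber(ABC):
    """Hypothetical abstract interface; its name is left unchanged."""

    @abstractmethod
    def forward(self, frames: List[float], length: int) -> Tuple[List[float], int]:
        """Return encoded frames and the number of valid output frames."""


class ToyConformerEncoder(ToyTranscriber):
    """Concrete class under an encoder-style name."""

    def forward(self, frames: List[float], length: int) -> Tuple[List[float], int]:
        # Toy "time reduction": keep every second valid frame.
        reduced = frames[:length:2]
        return reduced, len(reduced)


# Assumption for illustration: keep the old name importable as an alias
# so that downstream code referring to it does not break.
ToyConformerTranscriber = ToyConformerEncoder

encoder = ToyConformerTranscriber()
print(encoder.forward([1.0, 2.0, 3.0, 4.0, 5.0], 4))
```

Because callers that type-check against the abstract base never see the concrete name, renaming only the concrete classes has a smaller blast radius than renaming the base.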

@hwangjeff hwangjeff force-pushed the conformer_rnnt_prototype branch from 2bdbe42 to 1c44c67 Compare April 11, 2022 19:29
@facebook-github-bot
Contributor

@hwangjeff has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@github-actions

Hey @hwangjeff.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

    lstm_dropout: int,
    joiner_activation: str,
) -> RNNT:
    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.


class ConformerRNNTTestImpl(TestBaseMixin):
    def _get_input_config(self):
Collaborator

@mthrok mthrok Apr 12, 2022

nit: I'd move these helper functions to module level, so that the code around the test class is more focused on test logic.

xiaohui-zhang pushed a commit to xiaohui-zhang/audio that referenced this pull request May 4, 2022
Summary:
Adds a Conformer RNN-T model as a prototype feature via the factory functions `conformer_rnnt_model` and `conformer_rnnt_base`; the latter instantiates a baseline version of the model. Also includes the following:
- Modifies `Conformer` to accept arguments `use_group_norm` and `convolution_first` to pass to each of its `ConformerLayer` instances.
- Makes `_Predictor` an abstract class and introduces `_EmformerEncoder` and `_ConformerEncoder`.
- Introduces tests for `conformer_rnnt_model`.
- Adds docs.

Pull Request resolved: pytorch#2322

Reviewed By: xiaohui-zhang

Differential Revision: D35565987

Pulled By: hwangjeff

fbshipit-source-id: cb37bb0477ae3d5fcf0b7124f334f4cbb89b5789