-
Notifications
You must be signed in to change notification settings - Fork 684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Conformer RNN-T model prototype #2322
Conversation
2245f8c
to
444a87d
Compare
eb2c046
to
c255d6b
Compare
…d _ConformerTranscriber
c255d6b
to
4e31ff2
Compare
5619f61
to
490a981
Compare
ef91d94
to
0514f97
Compare
0514f97
to
84e6107
Compare
otherwise the PR looks good |
torchaudio/prototype/models/rnnt.py
Outdated
from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber | ||
|
||
|
||
class _ConformerTranscriber(torch.nn.Module, _Transcriber): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
encoder might be a more standard name than transcriber?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i can rename the class and associated variables in another pr, as there may be bc-breaking implications
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, renamed the concrete classes here — will handle the abstract class separately
2bdbe42
to
1c44c67
Compare
@hwangjeff has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
1 similar comment
@hwangjeff has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Hey @hwangjeff. |
lstm_dropout: int, | ||
joiner_activation: str, | ||
) -> RNNT: | ||
r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to overwrite the return signature otherwise the documentation shows wrong module path.
|
||
|
||
class ConformerRNNTTestImpl(TestBaseMixin): | ||
def _get_input_config(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I'd move these helper functions to module-level, so that the the code around the test class is more focused on test logic.
Summary: Adds Conformer RNN-T model as prototype feature, by way of factory functions `conformer_rnnt_model` and `conformer_rnnt_base`, which instantiates a baseline version of the model. Also includes the following: - Modifies `Conformer` to accept arguments `use_group_norm` and `convolution_first` to pass to each of its `ConformerLayer` instances. - Makes `_Predictor` an abstract class and introduces `_EmformerEncoder` and `_ConformerEncoder`. - Introduces tests for `conformer_rnnt_model`. - Adds docs. Pull Request resolved: pytorch#2322 Reviewed By: xiaohui-zhang Differential Revision: D35565987 Pulled By: hwangjeff fbshipit-source-id: cb37bb0477ae3d5fcf0b7124f334f4cbb89b5789
Adds Conformer RNN-T model as prototype feature, by way of factory functions
conformer_rnnt_model
andconformer_rnnt_base
, which instantiates a baseline version of the model. Also includes the following:Conformer
to accept argumentsuse_group_norm
andconvolution_first
to pass to each of itsConformerLayer
instances._Predictor
an abstract class and introduces_EmformerEncoder
and_ConformerEncoder
.conformer_rnnt_model
.