Hybrid Autoregressive Transducer (HAT) #6260
Conversation
Needs a bit of refactoring
nemo/collections/asr/modules/hat.py (outdated)
The filename should be written out in full: "hybrid_autoregressive_transducer.py".
Done
nemo/collections/asr/modules/hat.py (outdated)
from nemo.utils import logging


class HATJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin):
This class duplicates a lot of code from RNNTJoint. Would it make sense to subclass it?
Great comment. I took RNNTJoint as the parent class and kept only a few HAT-specific modifications in the new HATJoint class. Check it pls.
@@ -460,7 +466,12 @@ def greedy_search(

    # TODO: Figure out how to remove this hard coding afterwords
    while not_blank and (symbols_added < 5):
        ytu = torch.log_softmax(self.joint.joint(hi, y) / self.softmax_temperature, dim=-1)  # [1, 1, 1, V + 1]
        if isinstance(self.joint, HATJoint):
            ytu, _ = self.joint.joint(hi, y)
This kinda logic is problematic in the long run. Why not take a bool in the HAT module that determines what self.joint returns: by default it's set and returns both items; otherwise, return things in the RNNT format so that this code doesn't need to change.
I thought the JIT compiler does not like a variable number of outputs. I have now made the default mode return only logprobs (like the standard RNNT joint), and return both logprobs and internal_lm_logprobs when a special boolean flag is set. This allows more of the RNNT decoding code to stay unchanged.
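A minimal sketch of the interface this converged on, with illustrative names (the PR itself later exposes the flag as a return_hat_ilm property and wraps the two tensors in a dataclass, as discussed further down; HAT's blank/label decomposition is omitted here, see the probability-domain discussion near the end of the thread):

import torch
import torch.nn.functional as F

class FlaggedJoint(torch.nn.Module):
    def __init__(self, hidden: int, vocab_size: int):
        super().__init__()
        self.joint_net = torch.nn.Linear(hidden, vocab_size + 1)  # labels + blank
        self.return_ilm = False  # default: behave exactly like the standard RNNT joint

    def joint(self, f: torch.Tensor, g: torch.Tensor):
        # f: [B, T, H] encoder output, g: [B, U, H] prediction-net output
        x = f.unsqueeze(2) + g.unsqueeze(1)                  # [B, T, U, H]
        logprobs = F.log_softmax(self.joint_net(x), dim=-1)  # [B, T, U, V+1]
        if self.return_ilm:
            # Internal LM estimate: label distribution from the prediction net alone.
            ilm_logprobs = F.log_softmax(self.joint_net(g)[..., :-1], dim=-1)
            return logprobs, ilm_logprobs
        return logprobs  # single tensor, so unmodified RNNT decoding code keeps working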
@@ -34,6 +34,7 @@
from omegaconf import DictConfig

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.modules.hat import HATJoint
No modules should be imported inside the Greedy or Beam decoding libraries, because it will eventually cause a circular dependency.
This line is no longer needed due to the new default hat.joint.joint logic (the same as rnnt).
CodeQL found more than 10 potential problems in the proposed changes. Check the Files changed tab for more details.
Currently the code is too circular for the HAT import. Another thing is that it requires too many modifications to an already very complicated function (mAES).
The first thing we can make more generic with dataclass and property tricks. Those changes are relatively simple but require some refactoring.
The second one I dunno how to make more generic. Perhaps an abstract method inside AbstractRNNTJoint that describes how to do a special forward of the joint? That's a heavy refactor, so ignore it for now.
@@ -34,6 +34,7 @@
import torch
from tqdm import tqdm

from nemo.collections.asr.modules import hybrid_autoregressive_transducer as hat
Hmm, so this import doesn't actually fix the circular import. Think of it like this:
RNNTModel needs EncDecJoint, Loss, Decoding, Metric.
Decoding depends on Decoder + Joint.
Metric depends on Decoding.
Joint depends on Loss and Metric.
But now decoding itself imports the joint module. That's fine for now, but it can become more circular and crash in the future. I'll discuss an alternative below.
        res = torch.cat((label_logprob_scaled, blank_logprob), dim=-1).contiguous()  # [B, T, U, V+1]

        if return_ilm:
In this case, it seems incorrect to return a tuple here. Let's do this instead:
In rnnt_utils.py, create a dataclass called HATJointOutput. It has just two values: a tensor for the logprobs and a tensor for the ILM. Both are None by default.
If the return_ilm property of this class is set, you build an object of this dataclass, put in the two values, and return that.
More details below.
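A sketch of the suggested dataclass (this matches what the PR later adds as HATJointOutput; the exact field names here are assumptions):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class HATJointOutput:
    """Two-tensor container for HAT joint outputs; both fields default to None."""
    hat_logprobs: Optional[torch.Tensor] = None  # [B, T, U, V+1] joint log-probs
    ilm_logprobs: Optional[torch.Tensor] = None  # internal LM log-probs (no blank)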
    def joint(
        self, f: torch.Tensor, g: torch.Tensor, return_ilm: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Remove return_ilm from here; use the property instead.
        beam_logp, beam_idx = torch.log_softmax(
            self.joint.joint(beam_enc_out, beam_dec_out) / self.softmax_temperature, dim=-1,
        ).topk(self.max_candidates, dim=-1)
        if isinstance(self.joint, hat.HATJoint) and self.hat_subtract_ilm:
Here, and everywhere else, simply call self.joint.joint() with the ordinary RNNT arguments. The output can now be either a torch.Tensor (RNNT joint, or HAT without the ILM subtraction) or the HATJointOutput dataclass.
Import rnnt_utils and then check torch.is_tensor(output) here; this is the original RNNT path. Elif self.hat_subtract_ilm and isinstance(output, HATJointOutput):
then do the required code path. On the else path, raise an error saying the output could not be resolved.
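A minimal sketch of that dispatch (the author later adds a resolve_joint_output helper along these lines; the exact signature and the dataclass field names are assumptions):

import torch

from nemo.collections.asr.parts.utils.rnnt_utils import HATJointOutput

def resolve_joint_output(output, hat_subtract_ilm: bool):
    # Original RNNT joint, or HAT joint with ILM return disabled: plain tensor.
    if torch.is_tensor(output):
        return output, None
    # HAT joint with ILM subtraction requested: unpack the dataclass.
    elif hat_subtract_ilm and isinstance(output, HATJointOutput):
        return output.hat_logprobs, output.ilm_logprobs
    raise TypeError(f"Could not resolve joint output of type {type(output)}")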
@@ -1196,7 +1206,12 @@ def modified_adaptive_expansion_search(
                lm_score, new_hyp.ngram_lm_state = self.compute_ngram_score(
                    hyp.ngram_lm_state, int(k)
                )
                new_hyp.score += self.ngram_lm_alpha * lm_score
                if isinstance(self.joint, hat.HATJoint) and self.hat_subtract_ilm:
Same for everywhere else below.
Hi @titu1994! Thank you for the detailed review. I tried to modify the HAT-related code according to your suggestions. For convenience I also added a resolve_joint_output function. Check it pls.
from nemo.collections.asr.modules import rnnt
from nemo.collections.asr.parts.utils.rnnt_utils import HATJointOutput

from nemo.utils import logging
Check notice (Code scanning / CodeQL): Unused import
        self.pred, self.enc, self.joint_net, self.blank_pred = self._joint_hat_net_modules(
            num_classes=self._vocab_size,  # non blank symbol
            pred_n_hidden=self.pred_hidden,
            enc_n_hidden=self.encoder_hidden,
            joint_n_hidden=self.joint_hidden,
            activation=self.activation,
            dropout=jointnet.get('dropout', 0.0),
        )
Check warning (Code scanning / CodeQL): Overwriting attribute in super-class or sub-class
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
Check notice (Code scanning / CodeQL): Unused import
It looks very good now. Could you add some tests that assert that the normal forward returns a tensor, and that the HAT forward, with and without the flag set, returns either a tensor or HATJointOutput?
Another thing: we support only mAES and normal beam. Can you look into the complexity of the other beam algos to support HAT? If it's difficult, we can leave it to another PR in the future.
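A minimal sketch of such a test, assuming the RNNTJoint-style constructor quoted above; the jointnet keys, the return_hat_ilm setter, and the HATJointOutput field names should all be treated as assumptions:

import torch

from nemo.collections.asr.modules import hybrid_autoregressive_transducer as hat
from nemo.collections.asr.parts.utils.rnnt_utils import HATJointOutput

def test_hat_joint_output_types():
    jointnet = {'encoder_hidden': 32, 'pred_hidden': 32, 'joint_hidden': 32, 'activation': 'relu'}
    joint = hat.HATJoint(jointnet=jointnet, num_classes=10)

    f = torch.randn(2, 5, 32)  # encoder output [B, T, D_enc]
    g = torch.randn(2, 3, 32)  # prediction net output [B, U, D_pred]

    # Default mode behaves like RNNTJoint and returns a plain tensor.
    out = joint.joint(f, g)
    assert torch.is_tensor(out)

    # With the flag set, the joint returns the HATJointOutput dataclass.
    joint.return_hat_ilm = True
    out = joint.joint(f, g)
    assert isinstance(out, HATJointOutput)
    assert out.hat_logprobs is not None and out.ilm_logprobs is not None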
I see we support only basic beam and maes. Can you look into supporting HAT with other algos? If it's simple, it can be done in this PR; if not, in another PR.
Do you mean n-gram LM fusion (with RNNT and HAT) for the other decoding algorithms? Right now only the maes algorithm supports LM fusion. I did not do it for the default beam search because it works too slowly; I don't think anyone wants to use it because of the speed.
BTW, all the decoding algorithms can now work with a HAT model without LM fusion, because HATJoint has the same default output type as RNNTJoint.
Oh ok sounds good then
Hi @andrusenkoau, the HATJoint returns log-softmaxed …
@kobenaxie that's a template implementation of the loss using pure PyTorch; it is not used during actual training since it is super slow. Instead we use a Numba-based CUDA-compiled loss. Also, HAT during training does not return the dataclass (which the loss would not accept anyway), so it is fine.
Looks great!
Final things to do are to add a HAT-decoder-based Conformer config to a conf dir called conf/hat_transducer/conformer/ (conformer_hat_bpe.yaml / char.yaml).
That can be done when the release branch is cut.
Hi @kobenaxie, the HAT logic demands working in the probability domain in order to calculate the blank probability and then scale the label probabilities. For implementation simplicity we can use the rule …
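For reference, the standard HAT decomposition in log space looks like the sketch below; the identity log(1 - sigmoid(z)) = logsigmoid(-z) is presumably the rule being referred to. This is the textbook formulation matching the label_logprob_scaled / blank_logprob fragment quoted earlier in the thread, not the literal PR code:

import torch
import torch.nn.functional as F

# With blank logit z and label logits s:
#   p(blank)   = sigmoid(z)
#   p(label k) = (1 - sigmoid(z)) * softmax(s)_k
# and log(1 - sigmoid(z)) == F.logsigmoid(-z) keeps everything in the log domain.
z = torch.randn(2, 4, 3, 1)  # [B, T, U, 1] blank logits
s = torch.randn(2, 4, 3, 9)  # [B, T, U, V] label logits

blank_logprob = F.logsigmoid(z)                                   # log p(blank)
label_logprob_scaled = F.log_softmax(s, dim=-1) + F.logsigmoid(-z)
res = torch.cat((label_logprob_scaled, blank_logprob), dim=-1)    # [B, T, U, V+1]

# Sanity check: the distribution over labels + blank sums to one.
assert torch.allclose(res.exp().sum(dim=-1), torch.ones(2, 4, 3), atol=1e-5)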
@titu1994 thank you so much for the great review and the help with code modification!
* add hat joint network
* add HATJoint module
* add hat script
* add hat decoding option
* add hat related parameters to maes decoding
* minor fixes
* add hat decoding option
* minor fixes
* add hat related parameters
* minor fixes
* add hat to all rnnt decoding types
* add test for hatjoint
* combine hatjoint with all rnntjoint tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fixes
* rename hat file
* fix hat double output
* fix hat double output
* fix hat double output
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
* minor fixes
* add return_hat_ilm property
* add HATJointOutput dataclass
* add resolve_joint_output function
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add local return_hat_ilm_default variable
* minor fixes

Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
What does this PR do?
Adds the HAT model as a new joint network type (HATJoint) for the RNNT model. The difference appears only at decoding time: HATJoint's joint method can return two outputs, hat_logprobs and internal_lm_logprobs (for internal LM subtraction in the case of Shallow Fusion with an external n-gram LM).
Collection: [ASR]
Usage
For HAT model training, replace

_target_: nemo.collections.asr.modules.RNNTJoint

with

_target_: nemo.collections.asr.modules.HATJoint

in the joint section of a standard transducer config.

For Shallow Fusion with an external n-gram LM, use the RNNT maes decoding algorithm, which is able to work with the HATJoint model.
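A rough usage sketch for decoding with LM shallow fusion and ILM subtraction. The checkpoint name, LM path, and the exact config keys (notably hat_subtract_ilm under beam) are illustrative assumptions rather than the confirmed merged API:

from omegaconf import open_dict

from nemo.collections.asr.models import EncDecRNNTBPEModel

# Hypothetical checkpoint trained with HATJoint as the joint network.
model = EncDecRNNTBPEModel.restore_from("conformer_hat_bpe.nemo")

decoding_cfg = model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "maes"
    decoding_cfg.beam.ngram_lm_model = "lm.kenlm"  # external n-gram LM (assumed path)
    decoding_cfg.beam.ngram_lm_alpha = 0.3
    decoding_cfg.beam.hat_subtract_ilm = True      # enable internal LM subtraction
model.change_decoding_strategy(decoding_cfg)

hyps = model.transcribe(["audio.wav"])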
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.