Skip to content

Commit

Permalink
Cleanup #2
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Jul 6, 2023
1 parent 5c3fb5f commit fd913fd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
1 change: 0 additions & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ def forward_internal(
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
# cache_last_channel_next = torch.zeros_like(cache_last_channel)
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
Expand Down
8 changes: 3 additions & 5 deletions nemo/collections/asr/parts/submodules/conformer_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ def __init__(
self.dropout = nn.Dropout(dropout)
self.norm_out = LayerNorm(d_model)

def forward(
self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None,
):
def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_channel=None, cache_last_time=None):
"""
Args:
x (torch.Tensor): input signals (B, T, d_model)
Expand All @@ -161,9 +159,9 @@ def forward(

x = self.norm_self_att(residual)
if self.self_attention_model == 'rel_pos':
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel,)
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'rel_pos_local_attn':
x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel,)
x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'abs_pos':
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel)
else:
Expand Down

0 comments on commit fd913fd

Please sign in to comment.