Commit

Cleanup
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
borisfom committed Jul 6, 2023
1 parent 25ff2bd commit 5c3fb5f
Showing 3 changed files with 12 additions and 32 deletions.
31 changes: 8 additions & 23 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -557,14 +557,14 @@ def forward_internal(
         if att_mask is not None:
             att_mask = att_mask[:, cache_len:]
         # Convert caches from the tensor to list
-        cache_last_time_next = [None] * self.streaming_cfg.last_time_num
-        cache_last_channel_next = [None] * self.streaming_cfg.last_channel_num
+        cache_last_time_next = []
+        cache_last_channel_next = []

         for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
             original_signal = audio_signal
             if cache_last_channel is not None:
-                cache_last_channel_cur = cache_last_channel[layer.self_attn._cache_id]
-                cache_last_time_cur = cache_last_time[layer.conv.depthwise_conv._cache_id]
+                cache_last_channel_cur = cache_last_channel[lth]
+                cache_last_time_cur = cache_last_time[lth]
             else:
                 cache_last_channel_cur = None
                 cache_last_time_cur = None
@@ -579,8 +579,8 @@ def forward_internal(

             if cache_last_channel_cur is not None:
                 (audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
-                cache_last_channel_next[layer.self_attn._cache_id] = cache_last_channel_cur
-                cache_last_time_next[layer.conv.depthwise_conv._cache_id] = cache_last_time_cur
+                cache_last_channel_next.append(cache_last_channel_cur)
+                cache_last_time_next.append(cache_last_time_cur)

             # applying stochastic depth logic from https://arxiv.org/abs/2102.03216
             if self.training and drop_prob > 0.0:
@@ -870,20 +870,12 @@ def setup_streaming_params(
         else:
             streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

-        # counting the number of the layers need caching
-        streaming_cfg.last_channel_num = 0
-        streaming_cfg.last_time_num = 0
         for m in self.layers.modules():
             if hasattr(m, "_max_cache_len"):
                 if isinstance(m, MultiHeadAttention):
-                    m._cache_id = streaming_cfg.last_channel_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_channel_num += 1
-
                 if isinstance(m, CausalConv1D):
-                    m._cache_id = streaming_cfg.last_time_num
                     m.cache_drop_size = streaming_cfg.cache_drop_size
-                    streaming_cfg.last_time_num += 1

         self.streaming_cfg = streaming_cfg
@@ -896,19 +888,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
             create_tensor = torch.zeros
         last_time_cache_size = self.conv_context_size[0]
         cache_last_channel = create_tensor(
-            (
-                self.streaming_cfg.last_channel_num,
-                batch_size,
-                self.streaming_cfg.last_channel_cache_size,
-                self.d_model,
-            ),
+            (len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
             device=device,
             dtype=dtype,
         )
         cache_last_time = create_tensor(
-            (self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
-            device=device,
-            dtype=dtype,
+            (len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
         )
         if max_dim > 0:
             cache_last_channel_len = torch.randint(
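
Note: the net effect of the changes in this file is to drop the per-module _cache_id bookkeeping in favor of positional indexing: both streaming caches become tensors whose leading dimension is simply len(self.layers), each layer reads its slice with the loop index lth, and the next-step caches are collected by appending to a list. A minimal sketch of the pattern, with illustrative sizes and a stand-in for the per-layer update (not NeMo's actual classes):

import torch

# Illustrative sizes; stand-ins for len(self.layers), the batch size,
# streaming_cfg.last_channel_cache_size and d_model.
num_layers, batch_size, cache_size, d_model = 4, 2, 8, 16

# After this commit the leading cache dim is len(self.layers), so no
# separately counted streaming_cfg.last_channel_num is needed.
cache_last_channel = torch.zeros(num_layers, batch_size, cache_size, d_model)

cache_last_channel_next = []
for lth in range(num_layers):
    cache_cur = cache_last_channel[lth]  # positional indexing, no _cache_id
    cache_cur = cache_cur + 1.0          # stand-in for the layer's cache update
    cache_last_channel_next.append(cache_cur)

# The per-layer results can be stacked back into the same layout.
cache_last_channel = torch.stack(cache_last_channel_next, dim=0)
assert cache_last_channel.shape == (num_layers, batch_size, cache_size, d_model)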
12 changes: 4 additions & 8 deletions nemo/collections/asr/parts/submodules/causal_convs.py
@@ -45,7 +45,6 @@ def __init__(
             raise ValueError("Argument padding should be set to None for CausalConv2D.")
         self._left_padding = kernel_size - 1
         self._right_padding = stride - 1
-        self._cache_id = None

         padding = 0
         super(CausalConv2D, self).__init__(
@@ -113,7 +112,6 @@ def __init__(
             raise ValueError(f"Invalid padding param: {padding}!")

         self._max_cache_len = self._left_padding
-        self._cache_id = None

         super(CausalConv1D, self).__init__(
             in_channels=in_channels,
@@ -134,18 +132,16 @@ def update_cache(self, x, cache=None):
             new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
         else:
             new_x = F.pad(x, pad=(0, self._right_padding))
-
-            # todo: we should know input_x.size(-1) at config time
-            # cache_keep_size = x.size(-1) - self.cache_drop_size, dtype=torch.int64, device=x.device)
-            # print("cache:", cache.size(), "x:", x.size(), "new_x:", new_x.size())
             new_x = torch.cat([cache, new_x], dim=-1)
             if self.cache_drop_size > 0:
                 x = x[:, :, : -self.cache_drop_size]
             cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1)
-            # print("cache size: ", cache.size())
         return new_x, cache

     def forward(self, x, cache=None):
         x, cache = self.update_cache(x, cache=cache)
         x = super().forward(x)
-        return x, cache
+        if cache is None:
+            return x
+        else:
+            return x, cache
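
Note: with this change CausalConv1D.forward only returns an (output, cache) pair when a cache is actually passed in; offline callers get a single tensor back, like a plain Conv1d. A self-contained toy illustrating that contract (simplified padding and cache handling, not the real NeMo module):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalConv1D(nn.Conv1d):
    """Toy stand-in for a causal 1-D conv: pads only on the left."""

    def __init__(self, channels, kernel_size):
        super().__init__(channels, channels, kernel_size, padding=0)
        self._left_padding = kernel_size - 1

    def update_cache(self, x, cache=None):
        if cache is None:
            # Offline: left-pad so the conv stays causal.
            new_x = F.pad(x, pad=(self._left_padding, 0))
        else:
            # Streaming: the cache supplies the left context.
            new_x = torch.cat([cache, x], dim=-1)
            cache = new_x[:, :, -self._left_padding :]
        return new_x, cache

    def forward(self, x, cache=None):
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        # The cleanup: no cache in -> no cache out.
        if cache is None:
            return x
        return x, cache

conv = TinyCausalConv1D(channels=4, kernel_size=3)
x = torch.randn(1, 4, 10)
y = conv(x)                          # offline: a single tensor
cache0 = torch.zeros(1, 4, 2)        # streaming: kernel_size - 1 frames of context
y_s, cache1 = conv(x, cache=cache0)  # streaming: (output, cache) pair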
1 change: 0 additions & 1 deletion nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -73,7 +73,6 @@ def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
         self.dropout = nn.Dropout(p=dropout_rate)

         self._max_cache_len = max_cache_len
-        self._cache_id = None

     def forward_qkv(self, query, key, value):
         """Transforms query, key and value.
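
Note: the only change here is deleting the now-unused _cache_id attribute; forward_qkv, shown as context above, is untouched. For orientation, a sketch of what a forward_qkv-style projection typically computes in this kind of multi-head attention (illustrative dimensions, not NeMo's exact implementation):

import torch
import torch.nn as nn

n_head, d_k = 4, 16
n_feat = n_head * d_k  # model width = heads * per-head dim
linear_q = nn.Linear(n_feat, n_feat)
linear_k = nn.Linear(n_feat, n_feat)
linear_v = nn.Linear(n_feat, n_feat)

def forward_qkv(query, key, value):
    # Project, then split the feature dim into (head, d_k) and move the
    # head axis forward: result is (batch, head, time, d_k).
    n_batch = query.size(0)
    q = linear_q(query).view(n_batch, -1, n_head, d_k).transpose(1, 2)
    k = linear_k(key).view(n_batch, -1, n_head, d_k).transpose(1, 2)
    v = linear_v(value).view(n_batch, -1, n_head, d_k).transpose(1, 2)
    return q, k, v

x = torch.randn(2, 7, n_feat)
q, k, v = forward_qkv(x, x, x)
assert q.shape == (2, n_head, 7, d_k)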
