
Simplified memory bank for Emformer #440

Merged (30 commits, Jul 12, 2022)

Changes shown from 1 commit.

Commits
9c39d8b
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei Apr 29, 2022
70634d5
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 6, 2022
ecfb3e9
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 7, 2022
bcef517
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 12, 2022
c9d84ae
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 15, 2022
fbbc24f
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 26, 2022
5453166
Merge remote-tracking branch 'origin/master'
yaozengwei May 26, 2022
bb7ea31
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei May 31, 2022
2a5a70e
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei Jun 13, 2022
ec8646d
Merge remote-tracking branch 'k2-fsa/master'
yaozengwei Jun 13, 2022
1c067e7
init files
yaozengwei Jun 13, 2022
193b44e
use average value as memory vector for each chunk
yaozengwei Jun 13, 2022
5d877ef
change tail padding length from right_context_length to chunk_length
yaozengwei Jun 17, 2022
c27bb1c
correct the files, ln -> cp
yaozengwei Jun 17, 2022
208bbb6
fix bug in conv_emformer_transducer_stateless2/emformer.py
yaozengwei Jun 17, 2022
5b19011
fix doc in conv_emformer_transducer_stateless/emformer.py
yaozengwei Jun 21, 2022
42e3e88
refactor init states for stream
yaozengwei Jun 21, 2022
9c37c16
modify .flake8
yaozengwei Jun 22, 2022
10662c5
fix bug about memory mask when memory_size==0
yaozengwei Jul 4, 2022
dbea9a9
Merge remote-tracking branch 'k2-fsa/master' into emformer_conv_simpl…
yaozengwei Jul 5, 2022
1f6c822
add @torch.jit.export for init_states function
yaozengwei Jul 6, 2022
61794d8
update RESULTS.md
yaozengwei Jul 6, 2022
69a3ef3
minor change
yaozengwei Jul 6, 2022
f9c6014
update README.md
yaozengwei Jul 7, 2022
12c176c
modify doc
yaozengwei Jul 7, 2022
5cfdbd3
replace torch.div() with <<
yaozengwei Jul 8, 2022
2057124
fix bug, >> -> <<
yaozengwei Jul 8, 2022
e3e8b19
use i&i-1 to judge if it is a power of 2
yaozengwei Jul 8, 2022
ad68987
minor fix
yaozengwei Jul 8, 2022
1a44724
fix error in RESULTS.md
yaozengwei Jul 12, 2022
@@ -1252,6 +1252,10 @@ def __init__(
     ):
         super().__init__()

+        assert int(math.log(chunk_length, 2)) == math.log(
+            chunk_length, 2
+        ), "chunk_length should be a power of 2."
+
         self.use_memory = memory_size > 0
         self.init_memory_op = nn.AvgPool1d(
             kernel_size=chunk_length,

Review comment (Collaborator), on the added assert:

    assert (chunk_length - 1) & chunk_length == 0
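
The suggested check avoids floating point entirely: a positive integer is a power of two exactly when it has a single bit set, i.e. when n & (n - 1) == 0 (commit e3e8b19 later adopts this idea). A small sketch contrasting the two checks; the function names are illustrative:

    import math

    def is_power_of_two_via_log(n: int) -> bool:
        # Check as written in this commit: compare log2(n) with its integer part.
        # Works for typical chunk lengths but relies on floating-point rounding.
        return int(math.log(n, 2)) == math.log(n, 2)

    def is_power_of_two_bitwise(n: int) -> bool:
        # Reviewer's suggestion: clearing the lowest set bit must leave 0.
        return n > 0 and (n & (n - 1)) == 0

    for n in (1, 2, 31, 32, 33, 1024):
        assert is_power_of_two_bitwise(n) == (n in (1, 2, 32, 1024))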
@@ -1580,10 +1584,8 @@ def infer(
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    num_processed_frames << int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size

Review comment (Collaborator), on the shift expression:

    Please cache the result of int(math.log(self.chunk_length, 2)) in the constructor.
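
The review asks for the shift amount to be computed once rather than on every call to infer(). A minimal sketch of caching it in the constructor; the class and attribute names are hypothetical, and the right shift shown here is simply the usual way to floor-divide a non-negative integer by a power of two (matching the torch.div(..., rounding_mode="floor") being replaced), not a claim about the PR's final code:

    import math
    import torch

    class ChunkIndexer:
        """Hypothetical helper: caches log2(chunk_length) instead of recomputing it."""

        def __init__(self, chunk_length: int):
            assert chunk_length > 0 and chunk_length & (chunk_length - 1) == 0
            self.chunk_length = chunk_length
            # Computed once here, as the review suggests, not inside infer().
            self.log2_chunk_length = int(math.log2(chunk_length))

        def num_full_chunks(self, num_processed_frames: torch.Tensor) -> torch.Tensor:
            # Right shift == floor division by a power of two for non-negative ints.
            return num_processed_frames >> self.log2_chunk_length

    indexer = ChunkIndexer(chunk_length=32)
    frames = torch.tensor([0, 31, 32, 100], dtype=torch.int64)
    print(indexer.num_full_chunks(frames))  # tensor([0, 0, 1, 3])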
Second file (same changes):

@@ -1188,6 +1188,10 @@ def __init__(
     ):
         super().__init__()

+        assert int(math.log(chunk_length, 2)) == math.log(
+            chunk_length, 2
+        ), "chunk_length should be a power of 2."
+
         self.use_memory = memory_size > 0

         self.emformer_layers = nn.ModuleList(

@@ -1488,10 +1492,8 @@ def infer(
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    num_processed_frames << int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size