Apply minor fixes to Emformer implementation (#2252)
Summary:
Noticed some items to clean up in `Emformer`.
- Make `segment_length` a required argument in `_EmformerLayer`.
- Remove unused variables from `_unpack_state` and `_gen_attention_mask`.

These don't affect `Emformer`'s functionality or public API.
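
For orientation, here is a minimal usage sketch of the public module, adapted from the torchaudio documentation for `Emformer` (the constructor values and tensor shapes are illustrative assumptions, not taken from this diff); this calling pattern is unaffected by the changes:

```python
# Sketch of the unchanged public interface (values assumed for illustration).
import torch
from torchaudio.models import Emformer

# input_dim=512, num_heads=8, ffn_dim=2048, num_layers=20, segment_length=4
emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)

frames = torch.rand(128, 400, 512)       # (batch, num_frames, feature_dim)
lengths = torch.randint(1, 200, (128,))  # valid frames per batch element
output, output_lengths = emformer(frames, lengths)
```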

Pull Request resolved: #2252

Reviewed By: carolineechen, mthrok

Differential Revision: D34321430

Pulled By: hwangjeff

fbshipit-source-id: 38a5046f633a3e625352c476ef71c78380ccc597
hwangjeff authored and facebook-github-bot committed Feb 18, 2022
1 parent 3184aeb commit cbf1b83
Showing 1 changed file with 6 additions and 8 deletions.
torchaudio/models/emformer.py (14 changes: 6 additions & 8 deletions)
@@ -321,11 +321,11 @@ class _EmformerLayer(torch.nn.Module):
         input_dim (int): input dimension.
         num_heads (int): number of attention heads.
         ffn_dim: (int): hidden layer dimension of feedforward network.
+        segment_length (int): length of each input segment.
         dropout (float, optional): dropout probability. (Default: 0.0)
         activation (str, optional): activation function to use in feedforward network.
             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
         left_context_length (int, optional): length of left context. (Default: 0)
-        segment_length (int, optional): length of each input segment. (Default: 128)
         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
         weight_init_gain (float or None, optional): scale factor to apply when initializing
             attention module parameters. (Default: ``None``)
@@ -338,10 +338,10 @@ def __init__(
         input_dim: int,
         num_heads: int,
         ffn_dim: int,
+        segment_length: int,
         dropout: float = 0.0,
         activation: str = "relu",
         left_context_length: int = 0,
-        segment_length: int = 128,
         max_memory_size: int = 0,
         weight_init_gain: Optional[float] = None,
         tanh_on_mem: bool = False,
@@ -386,9 +386,7 @@ def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
         return [empty_memory, left_context_key, left_context_val, past_length]
 
-    def _unpack_state(
-        self, utterance: torch.Tensor, mems: torch.Tensor, state: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         past_length = state[3][0][0].item()
         past_left_context_length = min(self.left_context_length, past_length)
         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
@@ -474,7 +472,7 @@ def _apply_attention_infer(
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_mems, lc_key, lc_val = self._unpack_state(utterance, mems, state)
+        pre_mems, lc_key, lc_val = self._unpack_state(state)
         if self.use_mem:
             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
             summary = summary[:1]
@@ -652,10 +650,10 @@ def __init__(
                     input_dim,
                     num_heads,
                     ffn_dim,
+                    segment_length,
                     dropout=dropout,
                     activation=activation,
                     left_context_length=left_context_length,
-                    segment_length=segment_length,
                     max_memory_size=max_memory_size,
                     weight_init_gain=weight_init_gains[layer_idx],
                     tanh_on_mem=tanh_on_mem,
@@ -718,7 +716,7 @@ def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
         return col_widths
 
     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
-        utterance_length, batch_size, _ = input.shape
+        utterance_length = input.size(0)
         num_segs = math.ceil(utterance_length / self.segment_length)
 
         rc_mask = []
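
The `_unpack_state` change above only affects internal state handling on the streaming path. For orientation, a hedged sketch of how a caller exercises that path through the public `infer` method, which threads the per-layer state that `_unpack_state` consumes between calls (argument names and shapes assumed from the torchaudio documentation, not part of this diff):

```python
# Streaming inference sketch; each segment carries segment_length +
# right_context_length frames, and `state` is passed back in unchanged form.
import torch
from torchaudio.models import Emformer

emformer = Emformer(512, 8, 2048, 20, segment_length=4, right_context_length=1)

state = None  # per-layer state, threaded across successive infer() calls
for _ in range(3):  # three streaming segments
    chunk = torch.rand(1, 5, 512)        # segment_length + right_context_length frames
    chunk_lengths = torch.full((1,), 5)  # valid frames in this chunk
    output, output_lengths, state = emformer.infer(chunk, chunk_lengths, state)
```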
