Skip to content

Commit

Permalink
encoder interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 20, 2022
1 parent 83286fd commit 69cd23c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
2 changes: 2 additions & 0 deletions nn/encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""
Code to create encoders (for hybrid, the encoder of encoder-decoder-attention, or also transducer).
"""

from .base import *
40 changes: 39 additions & 1 deletion nn/encoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,55 @@
you only care about some encoded vector of type :class:`Tensor`.
"""

from typing import Tuple
from abc import ABC
from ... import nn


class IEncoder(nn.Module):
class IEncoder(nn.Module, ABC):
"""
Generic encoder interface
The encoder is a function x -> y.
The input can potentially be sparse or dense.
The output is dense with feature dim `out_dim`.
"""

out_dim: nn.Dim

@nn.scoped
def __call__(self, source: nn.Tensor) -> nn.Tensor:
"""
Encode the input
"""
raise NotImplementedError


class ISeqFramewiseEncoder(nn.Module, ABC):
"""
This specializes IEncoder that it operates on a sequence.
The output sequence length here is the same as the input.
"""

out_dim: nn.Dim

@nn.scoped
def __call__(self, source: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
raise NotImplementedError


class ISeqDownsamplingEncoder(nn.Module, ABC):
"""
This is more specific than IEncoder in that it operates on a sequence.
The output sequence length here is shorter than the input.
This is a common scenario for speech recognition
where the input might be on 10ms/frame
and the output might cover 30ms/frame or 60ms/frame or so.
"""

out_dim: nn.Dim

@nn.scoped
def __call__(self, source: nn.Tensor, *, in_spatial_dim: nn.Dim) -> Tuple[nn.Tensor, nn.Dim]:
raise NotImplementedError

0 comments on commit 69cd23c

Please sign in to comment.