Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support exporting to ncnn format via PNNX #571

Merged
merged 4 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ log
*.bak
*-bak
*bak.py
*.param
*.bin
41 changes: 35 additions & 6 deletions egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
"""

def __init__(
Expand All @@ -129,6 +131,7 @@ def __init__(
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
is_pnnx: bool = False,
) -> None:
super(RNN, self).__init__()

Expand All @@ -142,7 +145,13 @@ def __init__(
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)

self.is_pnnx = is_pnnx

self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
Expand Down Expand Up @@ -209,7 +218,13 @@ def forward(
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So `//` (integer division) is not supported in ncnn either?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, ncnn's binary op only supports floating-point operations.
Also, the right-shift operator `>>` is not supported in ncnn either.
(Note: we could extend ncnn to support `>>` if there is a need.)

else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)

if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()

Expand Down Expand Up @@ -359,7 +374,7 @@ def forward(
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
src = self.dropout(src_lstm) + src

# feed forward module
src = src + self.dropout(self.feed_forward(src))
Expand Down Expand Up @@ -505,6 +520,7 @@ def __init__(
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None:
"""
Args:
Expand All @@ -517,6 +533,9 @@ def __init__(
Number of channels in layer1
layer1_channels:
Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
"""
assert in_channels >= 9
super().__init__()
Expand Down Expand Up @@ -559,6 +578,10 @@ def __init__(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)

# ncnn support only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

Expand All @@ -572,9 +595,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))

if torch.jit.is_tracing() and self.is_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))

# Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch
import torch.nn as nn
from scaling import (
BasicNorm,
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
Expand All @@ -38,6 +39,29 @@
)


class NonScaledNorm(nn.Module):
    """A plain (parameter-free) variant of BasicNorm.

    Normalizes the input by the root-mean-square over ``channel_dim``,
    with a fixed epsilon baked in as ``eps_exp`` instead of a learnable
    log-epsilon parameter. See BasicNorm for the full documentation.
    """

    def __init__(
        self,
        num_channels: int,
        eps_exp: float,
        channel_dim: int = -1,  # CAUTION: see documentation.
    ):
        super().__init__()
        # Expected size of the channel dimension (checked in forward).
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # Pre-exponentiated epsilon (i.e. exp of BasicNorm's eps parameter).
        self.eps_exp = eps_exp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip the shape assertion while tracing (e.g. for PNNX export).
        if not torch.jit.is_tracing():
            assert x.shape[self.channel_dim] == self.num_channels
        mean_sq = torch.mean(x * x, dim=self.channel_dim, keepdim=True)
        scales = (mean_sq + self.eps_exp).pow(-0.5)
        return x * scales


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""Convert an instance of ScaledLinear to nn.Linear.

Expand Down Expand Up @@ -174,6 +198,16 @@ def scaled_embedding_to_embedding(
return embedding


def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
    """Convert an instance of BasicNorm to NonScaledNorm.

    BasicNorm stores its epsilon in log space as ``eps`` (note the
    ``exp()`` below); here we bake the exponentiated value into a plain
    module with no parameters, which is friendlier for export (e.g. PNNX).

    Args:
      basic_norm:
        The BasicNorm instance to convert.
    Returns:
      A NonScaledNorm with the same num_channels, channel_dim, and the
      exponentiated epsilon of ``basic_norm``.
    """
    # Bug fix: the assert message previously evaluated type(BasicNorm),
    # which is always <class 'type'> — report the offending object's type.
    assert isinstance(basic_norm, BasicNorm), type(basic_norm)
    norm = NonScaledNorm(
        num_channels=basic_norm.num_channels,
        eps_exp=basic_norm.eps.data.exp().item(),
        channel_dim=basic_norm.channel_dim,
    )
    return norm


def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
"""Convert an instance of ScaledLSTM to nn.LSTM.

Expand Down Expand Up @@ -256,6 +290,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)
elif isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, ScaledLSTM):
d[name] = scaled_lstm_to_lstm(m)

Expand Down