Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Support exporting pretrained models to mace/ncnn and others via pnnx #527

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
from icefall.utils import make_pad_mask, subsequent_chunk_mask


class MakePadMask(nn.Module):
    """``nn.Module`` wrapper that delegates to :func:`make_pad_mask`."""

    def forward(self, lengths: Tensor) -> Tensor:
        """Return ``make_pad_mask(lengths)``; see :func:`make_pad_mask`."""
        mask = make_pad_mask(lengths)
        return mask


class Conformer(EncoderInterface):
"""
Args:
Expand Down Expand Up @@ -86,6 +92,7 @@ def __init__(
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
for_pnnx: bool = False,
) -> None:
super(Conformer, self).__init__()

Expand All @@ -99,7 +106,11 @@ def __init__(
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
for_pnnx=for_pnnx,
)

self.encoder_layers = num_encoder_layers
self.d_model = d_model
Expand All @@ -111,6 +122,7 @@ def __init__(
self.num_left_chunks = num_left_chunks

self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.make_pad_mask = MakePadMask()

encoder_layer = ConformerEncoderLayer(
d_model,
Expand All @@ -124,6 +136,8 @@ def __init__(
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self._init_state: List[torch.Tensor] = [torch.empty(0)]

self.for_pnnx = for_pnnx

def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -153,12 +167,17 @@ def forward(
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issues a warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not self.for_pnnx:
lengths = (((x_lens - 1) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 1) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)

if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()

src_key_padding_mask = make_pad_mask(lengths)
src_key_padding_mask = self.make_pad_mask(lengths)

if self.dynamic_chunk_training:
assert (
Expand Down Expand Up @@ -798,7 +817,7 @@ def __init__(

self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.register_buffer("pe", None)
self.extend_pe(torch.tensor(0.0).expand(1, max_len))

def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
Expand Down Expand Up @@ -1538,6 +1557,7 @@ def __init__(
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
for_pnnx: bool = False,
) -> None:
"""
Args:
Expand Down Expand Up @@ -1592,6 +1612,10 @@ def __init__(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)

# ncnn support only batch size == 1
self.for_pnnx = for_pnnx
self.conv_out_dim = self.out.weight.shape[1]

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.

Expand All @@ -1606,8 +1630,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if torch.jit.is_tracing() and self.for_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
Expand Down
10 changes: 5 additions & 5 deletions egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,16 @@ class BasicNorm(torch.nn.Module):
doesn't have to do this trick. We make the "eps" learnable.

Args:
num_channels: the number of channels, e.g. 512.
num_channels: the number of channels, e.g. 512.
channel_dim: the axis/dimension corresponding to the channel,
interprted as an offset from the input's ndim if negative.
shis is NOT the num_channels; it should typically be one of
interpreted as an offset from the input's ndim if negative.
This is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}.
eps: the initial "epsilon" that we add as ballast in:
eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,36 @@

import torch
import torch.nn as nn
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
from scaling import (
BasicNorm,
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
)


class NonScaledNorm(nn.Module):
    """Normalization with a fixed (non-learnable) epsilon.

    See BasicNorm for the full documentation. The input is multiplied by
    ``(mean(x ** 2) + eps_exp) ** -0.5``, with the mean taken over
    ``channel_dim`` (kept as a size-1 dimension so the result broadcasts).

    Args:
      num_channels: expected size of the input along ``channel_dim``.
      eps_exp: constant added to the mean of squares before the power.
      channel_dim: dimension over which the mean is computed.
    """

    def __init__(
        self,
        num_channels: int,
        eps_exp: float,
        channel_dim: int = -1,  # CAUTION: see documentation.
    ):
        super().__init__()
        self.eps_exp = eps_exp
        self.channel_dim = channel_dim
        self.num_channels = num_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` rescaled by the inverse root of its mean square."""
        tracing = torch.jit.is_tracing()
        if not tracing:
            # The shape check is skipped while tracing.
            assert x.shape[self.channel_dim] == self.num_channels

        mean_square = (x * x).mean(dim=self.channel_dim, keepdim=True)
        scales = (mean_square + self.eps_exp).pow(-0.5)
        return x * scales


def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
Expand Down Expand Up @@ -164,6 +193,16 @@ def scaled_embedding_to_embedding(
return embedding


def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
    """Build a ``NonScaledNorm`` from a ``BasicNorm``.

    ``basic_norm.eps`` is exponentiated once and stored on the new module
    as a plain Python float (``eps_exp``); ``num_channels`` and
    ``channel_dim`` are copied unchanged.

    Args:
      basic_norm: the module to convert.

    Returns:
      The corresponding ``NonScaledNorm``.
    """
    # Report the type of the offending argument. The original message used
    # ``type(BasicNorm)``, which is always ``type`` and therefore useless.
    assert isinstance(basic_norm, BasicNorm), type(basic_norm)
    return NonScaledNorm(
        num_channels=basic_norm.num_channels,
        eps_exp=basic_norm.eps.data.exp().item(),
        channel_dim=basic_norm.channel_dim,
    )


def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
in the given model to their unscaled version `nn.Linear`, `nn.Conv1d`,
Expand Down Expand Up @@ -196,6 +235,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)
elif isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)

for k, v in d.items():
if "." in k:
Expand Down
49 changes: 49 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/t2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

import math

import ncnn
import numpy as np
import torch

LOG_EPS = math.log(1e-10)


@torch.no_grad()
def main():
    """Compare the TorchScript ``encoder_pos`` model against its ncnn export.

    The same random input is fed to ``foo/encoder_pos.pt`` and to the ncnn
    model (``foo/encoder_pos.ncnn.param`` / ``.bin``); both outputs of the
    two models must agree within ``torch.allclose``'s default tolerances.

    Raises:
      AssertionError: if an ncnn extraction fails or the outputs differ.
    """
    x = torch.rand(10, 3)
    f = torch.jit.load("foo/encoder_pos.pt")

    param = "foo/encoder_pos.ncnn.param"
    model = "foo/encoder_pos.ncnn.bin"

    with ncnn.Net() as net:
        net.load_param(param)
        net.load_model(model)
        with net.create_extractor() as ex:
            ex.input("in0", ncnn.Mat(x.numpy()).clone())

            ret, ncnn_out0 = ex.extract("out0")
            assert ret == 0, ret
            ncnn_out0 = np.array(ncnn_out0)

            ret, ncnn_out1 = ex.extract("out1")
            assert ret == 0, ret
            ncnn_out1 = np.array(ncnn_out1)

    # The TorchScript model gets an extra leading dimension on its input;
    # the added dimensions are squeezed away again before comparing.
    torch_out0, torch_out1 = f(x.unsqueeze(0))
    torch_out0 = torch_out0.squeeze(0)
    torch_out1 = torch_out1.squeeze(1)

    ncnn_out0 = torch.from_numpy(ncnn_out0)
    ncnn_out1 = torch.from_numpy(ncnn_out1)

    # These comparisons were previously bare ``allclose(...), diff`` tuples
    # whose value was discarded, so a mismatch could never be detected.
    assert torch.allclose(torch_out0, ncnn_out0), (
        torch_out0 - ncnn_out0
    ).abs().max()
    assert torch.allclose(torch_out1, ncnn_out1), (
        torch_out1 - ncnn_out1
    ).abs().max()


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3


import torch
import torch.nn as nn
from conformer import Conv2dSubsampling, MakePadMask, RelPositionalEncoding
from scaling_converter import convert_scaled_to_non_scaled


class Foo(nn.Module):
    """Small test module: ``Conv2dSubsampling`` then ``RelPositionalEncoding``.

    It also holds a ``MakePadMask`` instance, which ``forward`` does not use.
    """

    def __init__(self):
        super().__init__()
        d_model = 512
        self.num_features = 80

        self.encoder_embed = Conv2dSubsampling(self.num_features, d_model)
        self.encoder_pos = RelPositionalEncoding(d_model, 0.1)
        self.make_pad_mask = MakePadMask()

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        """
        Args:
          x:
            (N,T,C)
          x_lens:
            (N,)
        Returns:
          A tuple ``(x, lengths, pos_emb)``: the embedded input permuted to
          (T, N, C), the lengths after two rounds of ``floor((len - 1) / 2)``
          (cast back to the dtype/device of ``x_lens``), and the positional
          embedding returned by ``encoder_pos``.
        """
        embedded = self.encoder_embed(x)
        embedded, pos_emb = self.encoder_pos(embedded)
        time_major = embedded.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        after_first = torch.floor((x_lens - 1) / 2)
        lengths = torch.floor((after_first - 1) / 2).to(x_lens)

        return time_major, lengths, pos_emb


def generate_pt():
    """Trace ``Foo`` on a dummy batch and save it to ``foo/conformer.pt``.

    The model is put in eval mode, converted with
    ``convert_scaled_to_non_scaled`` and has ``encoder_embed.for_pnnx``
    switched on before tracing. The shapes of the eager outputs are printed
    for inspection.
    """
    from pathlib import Path

    f = Foo()
    f.eval()
    f = convert_scaled_to_non_scaled(f)
    f.encoder_embed.for_pnnx = True

    x = torch.rand(1, 30, 80, dtype=torch.float32)  # (N, T, C)
    x_lens = torch.tensor([30])
    y, lengths, pos_emb = f(x, x_lens)
    print("y.shape", y.shape)
    print("lengths", lengths)
    print("pos_emb.shape", pos_emb.shape)

    m = torch.jit.trace(f, (x, x_lens))
    # Create the output directory first, so the save cannot fail on a
    # missing ``foo/`` after all the work above has been done.
    Path("foo").mkdir(parents=True, exist_ok=True)
    m.save("foo/conformer.pt")


def main():
    """Script entry point: generate the traced model file."""
    generate_pt()


if __name__ == "__main__":
    # Fixed seed so the randomly initialised model and input are repeatable.
    torch.manual_seed(20220809)
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3


import torch
import torch.nn as nn
from conformer import Conv2dSubsampling
from scaling_converter import convert_scaled_to_non_scaled


class Foo(nn.Module):
    """Test module wrapping one ``Conv2dSubsampling`` built with ``for_pnnx=True``."""

    def __init__(self):
        super().__init__()

        d_model = 512
        self.num_features = 80
        self.subsampling_factor = 4

        self.encoder_embed = Conv2dSubsampling(
            self.num_features,
            d_model,
            for_pnnx=True,
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
          x:
            (N, T, C)
        Returns:
          The output of ``encoder_embed`` applied to ``x``.
        """
        return self.encoder_embed(x)


def generate_pt():
    """Trace ``Foo`` on a dummy input and save it to ``foo/encoder_embed.pt``.

    The model is put in eval mode and converted with
    ``convert_scaled_to_non_scaled`` before tracing. The shape of the eager
    output is printed for inspection.
    """
    from pathlib import Path

    f = Foo()
    f.eval()
    f = convert_scaled_to_non_scaled(f)

    x = torch.rand(1, 30, 80)
    y = f(x)
    print("y.shape", y.shape)

    m = torch.jit.trace(f, x)
    # Create the output directory first, so the save cannot fail on a
    # missing ``foo/`` after tracing has already succeeded.
    Path("foo").mkdir(parents=True, exist_ok=True)
    m.save("foo/encoder_embed.pt")


def main():
    """Script entry point: generate the traced model file."""
    generate_pt()


if __name__ == "__main__":
    # Fixed seed so the randomly initialised model and input are repeatable.
    torch.manual_seed(20220809)
    main()
Loading