[Help is wanted] Support mace/onnx/tflite #51

csukuangfj opened this issue Jul 6, 2022 · 12 comments

@csukuangfj
Collaborator

csukuangfj commented Jul 6, 2022

libtorch is easy to use. However, the size of its shared libraries is large.

It would be nice to support the frameworks in the title (mace, onnx, tflite), which are more lightweight than libtorch.
(More frameworks may be added later.)

@EmreOzkose

Hi, I am working on exporting an ONNX model from the RNN-T. Do you have any plan or roadmap for this issue? I exported the encoder (Conformer) with these steps:

  1. Change the ActivationBalancer() forward:

def forward(self, x: Tensor) -> Tensor:
    return x
    """if torch.jit.is_scripting():
        return x
    else:
        return ActivationBalancerFunction.apply(
            x,
            self.channel_dim,
            self.min_positive,
            self.max_positive,
            self.max_factor,
            self.min_abs,
            self.max_abs,
        )"""
  2. Change the DoubleSwish() forward:

def forward(self, x: Tensor) -> Tensor:
    """Return double-swish activation function which is an approximation to Swish(Swish(x)),
    that we approximate closely with x * sigmoid(x-1).
    """
    return x * torch.sigmoid(x - torch.ones((1), dtype=torch.float32))
    """if torch.jit.is_scripting():
        return x * torch.sigmoid(x - 1.0)
    else:
        return DoubleSwishFunction.apply(x)"""
  3. Constant values raise an error during onnxruntime load. I opened an issue on PyTorch. For now, I worked around it: I printed the graph and saw that some values are double type, so I multiplied these terms by torch.ones((1), dtype=torch.float32) so that the ONNX model converts them to float (see the graph-inspection sketch below).

    scaling.py: BasicNorm() forward:

    var_upper = torch.ones((1),  dtype=torch.float32) * -0.5
    scales = (
        torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
        + self.eps.exp()
    ) ** var_upper
    

    conformer.py.1125:

    scaling = torch.ones((1),  dtype=torch.float32) * float(head_dim) ** -0.5
    

With these changes, I am able to convert the encoder to ONNX.
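For reference, a minimal sketch of the kind of graph inspection I mean in step 3 (an illustrative toy module, not the actual Conformer code): trace a small module and print its graph to see where Python float literals end up as constants.

import torch

class PowDemo(torch.nn.Module):
    def forward(self, x):
        # 0.5 and -0.5 are Python float literals; they show up in the traced
        # graph as constants, which is where double-typed values can be spotted.
        return (x ** 2 + 0.5) ** -0.5

traced = torch.jit.trace(PowDemo(), torch.randn(2, 3))
print(traced.graph)  # inspect the prim::Constant nodes and their types

The full export script: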

import os
import sys

os.environ["PYTHONPATH"] = "/path/to/k2_sherpa/icefall:" + os.environ["PYTHONPATH"]
sys.path.append("/path/to/k2_sherpa/icefall")

import torch
import torch.nn as nn

from joiner import Joiner
from decoder import Decoder
from model import Transducer
from conformer import Conformer

from decode import get_parser
from icefall.utils import AttributeDict
from asr_datamodule import AsrDataModule
from icefall.checkpoint import average_checkpoints

nn_model_filename = "/path/to/k2_sherpa/models/enUS_75-25.pt"

def get_encoder_model(params: AttributeDict) -> nn.Module:
    # TODO: We can add an option to switch between Conformer and Transformer
    encoder = Conformer(
        num_features=params.feature_dim,
        subsampling_factor=params.subsampling_factor,
        d_model=params.encoder_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        dynamic_chunk_training=params.dynamic_chunk_training,
        short_chunk_size=params.short_chunk_size,
        num_left_chunks=params.num_left_chunks,
        causal=params.causal_convolution,
    )
    return encoder


def get_decoder_model(params: AttributeDict) -> nn.Module:
    decoder = Decoder(
        vocab_size=params.vocab_size,
        decoder_dim=params.decoder_dim,
        blank_id=params.blank_id,
        context_size=params.context_size,
    )
    return decoder


def get_joiner_model(params: AttributeDict) -> nn.Module:
    joiner = Joiner(
        encoder_dim=params.encoder_dim,
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
    return joiner


def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        encoder_dim=params.encoder_dim,
        decoder_dim=params.decoder_dim,
        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
    return model

params = AttributeDict(
        {
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 50,
            "reset_interval": 200,
            "valid_interval": 3000,  # For the 100h subset, use 800
            # parameters for conformer
            "feature_dim": 80,
            "subsampling_factor": 4,
            "encoder_dim": 512,
            "nhead": 8,
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            # parameters for decoder
            "decoder_dim": 512,
            # parameters for joiner
            "joiner_dim": 512,
            # parameters for Noam
            "model_warm_step": 3000,  # arg given to model, not for lrate
        }
    )

params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2

params.dynamic_chunk_training = False
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = False

params.ngram_lm_scale = 0.01
params.nbest_scale = 0.5

parser = get_parser()
AsrDataModule.add_arguments(parser)

device = "cpu"
model = get_transducer_model(params)
model.load_state_dict(average_checkpoints([nn_model_filename], device=device))

model_encoder = model.encoder
model_encoder.eval()

features = torch.rand([1, 50, 80], dtype=torch.float).to(device)

feature_lens = torch.tensor(
    [f.size(0) for f in features],
    dtype=torch.int64,
)

encoder_out, encoder_out_lens = model_encoder(
    x=features, x_lens=feature_lens
)

torch.onnx.export(model_encoder,
                  args=(features, feature_lens),
                  f="encoder_onnx_converted.onnx",
                  input_names=["x", "x_lens"],
                  output_names=["output1"],
                  opset_version = 13,
                  dynamic_axes={
                      'x': {0 : 'batch_size'},
                      'x_lens': {0 : 'batch_size'},
                      'output1': {0 : 'batch_size'},
                  },
                  do_constant_folding=True,
                  verbose=True)

import onnx
import numpy as np
import onnxruntime as ort


feature_np, feature_lens_np = features.numpy(), feature_lens.numpy()
print(feature_np.shape, feature_lens_np.shape)

model_onnx = onnx.load("encoder_onnx_converted.onnx")
onnx.checker.check_model(model_onnx)

ort_session = ort.InferenceSession("encoder_onnx_converted.onnx")
outputs = ort_session.run(
    None,
    {
        "x": feature_np.astype(np.float32),
        "x_lens": feature_lens_np
    },
)

If we check whether the outputs are the same:

torch.allclose(encoder_out, torch.from_numpy(outputs[0]), atol=1e-6)
# True

With atol=1e-7 it returns False.
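A quick way to quantify the mismatch (not part of the script above) is to print the maximum absolute deviation between the two outputs:

import numpy as np

diff = encoder_out.detach().numpy() - outputs[0]
print(np.abs(diff).max())  # expected to land somewhere between 1e-7 and 1e-6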

I am also working on the Decoder and Joiner.

@csukuangfj
Collaborator Author

csukuangfj commented Jul 21, 2022

@EmreOzkose

Great work! Thanks!

For

  1. change ActivationBalancer() forward
  2. change DoubleSwish() forward:

I recommend using

if torch.jit.is_scripting() or torch.onnx.is_in_onnx_export():
  return x

(See https://pytorch.org/docs/stable/onnx.html#torch.onnx.is_in_onnx_export)
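Something like the following (a stand-in module for illustration, not the actual code in icefall's scaling.py) shows how the guard could be used without deleting the training-time code path:

import torch
from torch import Tensor

class ActivationBalancerStub(torch.nn.Module):
    # Stand-in for icefall's ActivationBalancer; the real module applies
    # ActivationBalancerFunction during training.
    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.onnx.is_in_onnx_export():
            # ActivationBalancer only changes gradients, so returning x
            # unchanged is exact for inference/export.
            return x
        # In the real module this branch would be:
        # return ActivationBalancerFunction.apply(x, self.channel_dim, ...)
        return x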

For

  3. Constant values raise an error during onnxruntime load.

Could you please make a PR to icefall with your fix?


Do you have any plan or roadmap for this issue?

Since you are starting with https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless3,
I would suggest first either modifying export.py to support exporting models to onnx format or creating a new file
onnx_export.py to do that.

I think you will get three files after exporting to onnx:

  • a file for the encoder
  • a file for the decoder
  • a file for the joiner

After exporting to onnx, I would recommend creating a file onnx_test.py to check that, for the same input, onnx produces the same output as PyTorch (within some numeric tolerance).
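A rough sketch of the kind of helper onnx_test.py could contain (assuming the ONNX input names match the module's forward arguments, as in the export script above; names and tolerances are placeholders):

import numpy as np
import onnxruntime as ort
import torch


def compare_torch_vs_onnx(torch_module, onnx_filename, inputs, atol=1e-5):
    # Run the PyTorch module and the exported ONNX model on the same inputs
    # and check that they agree within a tolerance.
    torch_out = torch_module(**{k: torch.from_numpy(v) for k, v in inputs.items()})
    if isinstance(torch_out, tuple):
        torch_out = torch_out[0]
    onnx_out = ort.InferenceSession(onnx_filename).run(None, inputs)[0]
    np.testing.assert_allclose(torch_out.detach().numpy(), onnx_out, atol=atol)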

After testing that the export goes well, I suggest creating a file like what
https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
does, e.g., onnx_pretrained.py. You can create a file onnx_beam_search.py to implement greedy_search_batch and use it in onnx_pretrained.py for decoding. If, for the same sound file, pretrained.py and onnx_pretrained.py produce the same output, we can be sure that everything works as expected in the ONNX export.
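To illustrate, a rough sketch of such a greedy search over the three exported models (input/output names, shapes, and file names here are assumptions, not a final API):

import numpy as np
import onnxruntime as ort


def greedy_search(encoder_out, decoder_sess, joiner_sess, blank_id=0, context_size=2):
    # encoder_out: (1, T, C) numpy array produced by the encoder ONNX model.
    hyp = [blank_id] * context_size
    decoder_out = decoder_sess.run(
        None, {"y": np.array([hyp[-context_size:]], dtype=np.int64)}
    )[0]
    for t in range(encoder_out.shape[1]):
        logits = joiner_sess.run(
            None,
            {"encoder_out": encoder_out[:, t : t + 1], "decoder_out": decoder_out},
        )[0]
        token = int(logits.argmax(axis=-1).item())
        if token != blank_id:
            hyp.append(token)
            decoder_out = decoder_sess.run(
                None, {"y": np.array([hyp[-context_size:]], dtype=np.int64)}
            )[0]
    return hyp[context_size:]


# Usage (file names are placeholders):
# decoder_sess = ort.InferenceSession("decoder.onnx")
# joiner_sess = ort.InferenceSession("joiner.onnx")
# tokens = greedy_search(encoder_out, decoder_sess, joiner_sess)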


Now it comes to the sherpa part. We have to select one runtime that supports onnx.
https://onnx.ai/supported-tools.html#deployModel lists lots of onnx runtimes.

I suggest that we start with https://github.com/microsoft/onnxruntime/. We can switch to other runtimes when we have more experience with microsoft/onnxruntime.


[EDITED]: We may choose to support compiling sherpa into some binaries/libraries for deployment, without depending on libtorch.

@EmreOzkose

I notice that I forgot to add a step above. I also replaced the torch.as_strided() function with my own implementation. When I tried to export with torch.as_strided(), I got this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [8], in <cell line: 1>()
----> 1 torch.onnx.export(model_encoder,
      2                   args=(features, feature_lens),
      3                   f="encoder_onnx_converted_modules_oldu.onnx",
      4                   input_names=["x", "x_lens"],
      5                   output_names=["output1"],
      6                   opset_version = 13,
      7                   dynamic_axes={
      8                       'x': {0 : 'batch_size'},
      9                       'x_lens': {0 : 'batch_size'},
     10                       'output1': {0 : 'batch_size'},
     11                   },
     12                   do_constant_folding=True,
     13                   verbose=True)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/__init__.py:316, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     38 r"""
     39 Exports a model into ONNX format. If ``model`` is not a
     40 :class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs
   (...)
    312     model to the file ``f`` even if this is raised.
    313 """
    315 from torch.onnx import utils
--> 316 return utils.export(model, args, f, export_params, verbose, training,
    317                     input_names, output_names, operator_export_type, opset_version,
    318                     _retain_param_name, do_constant_folding, example_outputs,
    319                     strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
    320                     custom_opsets, enable_onnx_checker, use_external_data_format)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:107, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    102 if use_external_data_format is not None:
    103     warnings.warn("`use_external_data_format' is deprecated and ignored. Will be removed in next "
    104                   "PyTorch release. The code will work as it is False if models are not larger than 2GB, "
    105                   "Otherwise set to False because of size limits imposed by Protocol Buffers.")
--> 107 _export(model, args, f, export_params, verbose, training, input_names, output_names,
    108         operator_export_type=operator_export_type, opset_version=opset_version,
    109         do_constant_folding=do_constant_folding, example_outputs=example_outputs,
    110         dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
    111         custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:714, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, use_external_data_format, onnx_shape_inference)
    710     dynamic_axes = {}
    711 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
    713 graph, params_dict, torch_out = \
--> 714     _model_to_graph(model, args, verbose, input_names,
    715                     output_names, operator_export_type,
    716                     example_outputs, val_do_constant_folding,
    717                     fixed_batch_size=fixed_batch_size,
    718                     training=training,
    719                     dynamic_axes=dynamic_axes)
    721 # TODO: Don't allocate a in-memory string for the protobuf
    722 defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:496, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
    492 graph, params, torch_out, module = _create_jit_graph(model, args)
    494 params_dict = _get_named_param_dict(graph, params)
--> 496 graph = _optimize_graph(graph, operator_export_type,
    497                         _disable_torch_constant_prop=_disable_torch_constant_prop,
    498                         fixed_batch_size=fixed_batch_size, params_dict=params_dict,
    499                         dynamic_axes=dynamic_axes, input_names=input_names,
    500                         module=module)
    501 from torch.onnx.symbolic_helper import _onnx_shape_inference
    502 if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:216, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    214     dynamic_axes = {} if dynamic_axes is None else dynamic_axes
    215     torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
--> 216 graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    217 torch._C._jit_pass_lint(graph)
    219 torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/__init__.py:373, in _run_symbolic_function(*args, **kwargs)
    371 def _run_symbolic_function(*args, **kwargs):
    372     from torch.onnx import utils
--> 373     return utils._run_symbolic_function(*args, **kwargs)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:1019, in _run_symbolic_function(g, block, n, inputs, env, operator_export_type)
   1017             return None
   1018         attrs = {k: n[k] for k in n.attributeNames()}
-> 1019         return symbolic_fn(g, *inputs, **attrs)
   1021 elif ns == "prim":
   1022     if op_name == "Constant" and not n.mustBeNone():

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/symbolic_helper.py:172, in parse_args.<locals>.decorator.<locals>.wrapper(g, *args, **kwargs)
    170 if len(kwargs) == 1:
    171     assert "_outputs" in kwargs
--> 172 return fn(g, *args, **kwargs)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:3088, in as_strided(g, self, sizes, strides, offset)
   3086         ind = g.op("Add", ind, tmp_ind)
   3087 if offset:
-> 3088     ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset])))
   3089 return g.op("Gather", self_1d, ind)

File ~/anaconda3/envs/k2_sherpa/lib/python3.8/site-packages/torch/onnx/utils.py:915, in _graph_op(g, opname, *raw_args, **kwargs)
    913 if _onnx_shape_inference:
    914     from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
--> 915     torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
    917 if outputs == 1:
    918     return n.output()

RuntimeError: required keyword attribute 'value' is undefined

Hence, I changed conformer.py:RelPositionMultiheadAttention().rel_shift():

return x.as_strided(
            (batch_size, num_heads, time1, time2),
            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
            storage_offset=storage_offset,
        )

to

return self.as_strided_explicit_4d(x, 
            (batch_size, num_heads, time1, time2), 
            (batch_stride, head_stride, time1_stride - n_stride, n_stride), 
            storage_offset
        )

where as_strided_explicit_4d is defined as

def as_strided_explicit_4d(self, x: torch.Tensor, size: tuple, stride: tuple, storage_offset: int):
    # torch.as_strided() alternative
    x_flattened = x.flatten()

    out_stack = []
    for i0 in range(size[0]):
        stack0 = []
        for i1 in range(size[1]):
            stack1 = []
            for i2 in range(size[2]):
                stack2 = []
                for i3 in range(size[3]):
                    idx = i0 * stride[0] + i1 * stride[1] + i2 * stride[2] + i3 * stride[3] + storage_offset
                    stack2.append(x_flattened[idx])
                stack1.append(stack2)
            stack0.append(stack1)
        out_stack.append(stack0)

    return torch.Tensor(out_stack).float().to(x.device)

@EmreOzkose

Hi, exporting the ONNX models is almost done: https://github.com/EmreOzkose/icefall/tree/onnx. When I tested with a sample wav, pretrained.py and onnx_pretrained.py give the same output. However, there are some issues:

  1. When I export the encoder with a sample wav and then do inference (using onnx_pretrained.py) with the same wav, everything goes well. However, the encoder ONNX model does not give the same output if the input wav differs from the one used for exporting the ONNX model. I think scaling (scaling.py) might cause this issue, but I am not sure.
  2. During beam search, I also had to export encoder_proj and decoder_proj, since they are used externally (e.g., here). Is this a convenient way to export?

@csukuangfj
Collaborator Author

@EmreOzkose
Please have a look at k2-fsa/icefall#501

@danpovey
Collaborator

I notice that I forgot to add a step above. I also replaced the torch.as_strided() function with my own implementation. When I tried to export with torch.as_strided(), I got this error: RuntimeError: required keyword attribute 'value' is undefined

This looks to me like there was a bug in the ONNX export code of the specific version of Torch you were using. The current code looks nothing like the code in your error message, so I am assuming the issue was fixed long ago.

@csukuangfj
Collaborator Author

However, the encoder ONNX model does not give the same output if the input wav differs from the one used for exporting the ONNX model. I think scaling (scaling.py) might cause this issue, but I am not sure.

The reason is that you are passing an nn.Module to torch.onnx.export, which will use torch.jit.trace.

However, in your re-implementation of the as_strided function, you are using Python for loops, which won't work well
with torch.jit.trace if you want to use a dynamic input shape (see the sketch below).
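A tiny example (unrelated to the model, just to illustrate the problem) of how tracing freezes Python loops:

import torch


def loop_sum(x):
    total = torch.zeros(())
    # The loop bound is a Python int taken from the example input at trace time,
    # so the traced graph always sums exactly that many elements.
    for i in range(x.shape[0]):
        total = total + x[i]
    return total


traced = torch.jit.trace(loop_sum, torch.ones(3))
print(traced(torch.ones(3)))  # tensor(3.)
print(traced(torch.ones(5)))  # still tensor(3.): only the first 3 elements are summed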

@danpovey
Collaborator

Incidentally, the use of as_strided() in the implementation may not be 100% necessary. You could look at the history of that code, at some point I implemented it that way as I felt it was either easier to understand or more efficient, but it should be equivalent to the previous implementation.
But you could also try another torch version to see if that bug has been solved.

@csukuangfj
Collaborator Author

        if torch.jit.is_tracing():
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(time1)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols

            x = x.reshape(-1, n)
            x = torch.gather(x, dim=1, index=indexes)
            x = x.reshape(batch_size, num_heads, time1, time1)
            return x
        else:
            # Note: TorchScript requires explicit arg for stride()
            batch_stride = x.stride(0)
            head_stride = x.stride(1)
            time1_stride = x.stride(2)
            n_stride = x.stride(3)
            return x.as_strided(
                (batch_size, num_heads, time1, time2),
                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
                storage_offset=n_stride * (time1 - 1),
            )

I have written a version of rel_shift using torch.gather (above) that supports ONNX export.
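A quick sanity check (an illustrative snippet, not from icefall) that the gather-based version matches as_strided for a contiguous input:

import torch

batch_size, num_heads, time1 = 2, 4, 5
n = 2 * time1 - 1
x = torch.randn(batch_size, num_heads, time1, n)

# as_strided version, with the strides/offset used in rel_shift
ref = x.as_strided(
    (batch_size, num_heads, time1, time1),
    (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (time1 - 1),
)

# gather version
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time1)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
out = torch.gather(x.reshape(-1, n), dim=1, index=indexes)
out = out.reshape(batch_size, num_heads, time1, time1)

assert torch.equal(ref, out)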

@EmreOzkose

Hi, after icefall, I also integrated the ONNX models into sherpa and pushed them to my fork. There are two issues:

  1. Greedy search is implemented in Python. I am working on a C++ version.
  2. Since I cannot call projection layers like joiner.encoder_proj during greedy search, I had to export the Joiner layers separately (joiner-proj-input-false / joiner-encoder-proj / joiner-decoder-proj). After that, the onnx-all-in-one model contains these models separately. Is this a proper approach for you?

I also had to change egs/librispeech/ASR/pruned_transducer_stateless3/export.py to add export of the Joiner's layers (a sketch of what exporting one such layer could look like is below).
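A hedged sketch of exporting one such projection layer on its own (the Linear below is a stand-in for model.joiner.encoder_proj; file and tensor names are placeholders):

import torch

encoder_proj = torch.nn.Linear(512, 512)  # stand-in for model.joiner.encoder_proj

torch.onnx.export(
    encoder_proj,
    args=(torch.randn(1, 512),),
    f="joiner-encoder-proj.onnx",
    input_names=["encoder_out"],
    output_names=["projected_encoder_out"],
    opset_version=13,
    dynamic_axes={"encoder_out": {0: "N"}, "projected_encoder_out": {0: "N"}},
)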

In my experiments, the RTF on CPU drops to roughly 1/3 of the jit model's value.
example file: icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav

jit model: 0.0937
onnx model: 0.0342
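RTF here means processing time divided by audio duration; a rough sketch of such a measurement (torchaudio is assumed for loading the wav; the actual benchmark code may differ):

import time

import torchaudio

wave, sample_rate = torchaudio.load("test_wavs/1089-134686-0001.wav")
audio_seconds = wave.shape[1] / sample_rate

start = time.perf_counter()
# ... run the encoder/decoder/joiner + greedy search on this utterance ...
elapsed = time.perf_counter() - start

print("RTF:", elapsed / audio_seconds)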

@csukuangfj
Collaborator Author

@EmreOzkose
Great work!

Supporting more than one inference framework in sherpa makes the code difficult to maintain.

Would you mind porting your ONNX-related C++ code to https://github.com/k2-fsa/sherpa-onnx?

Thanks!

@csukuangfj
Collaborator Author

Is this a proper approach for you?

Yes, that looks good to me.
