Skip to content

Commit 888f729

Browse files
committed
ONNX converter and onnxruntime based transcriber
1 parent 2e04ed3 commit 888f729

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

moonshine/tools/convert_to_onnx.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import sys
2+
import keras
3+
import moonshine
4+
from pathlib import Path
5+
6+
7+
def convert_and_store(model, input_signature, output_file):
    """Convert a Keras model (or callable) to ONNX and write it to disk.

    Args:
        model: Keras model or tf.function-decorated callable to convert.
        input_signature: List of tf.TensorSpec objects describing the inputs.
        output_file: Destination path for the serialized .onnx file.

    Raises:
        AssertionError: If tf2onnx requires external weight storage, which
            this tool does not support (all weights must fit in one file).
    """
    # Imported lazily so this module can be imported without tf2onnx installed.
    from tf2onnx.convert import from_keras
    import onnx

    onnx_model, external_storage_dict = from_keras(
        model, input_signature=input_signature
    )
    # tf2onnx returns a non-None storage dict when weights must be stored
    # outside the protobuf; we only handle fully self-contained models.
    # (Fixed: the message had a pointless f-string prefix with no placeholders.)
    assert external_storage_dict is None, "External storage for onnx not supported"
    onnx.save_model(onnx_model, output_file)
16+
17+
18+
def main():
    """CLI entry point: export every moonshine model component to ONNX.

    Reads the model name and output directory from argv, then writes
    preprocess.onnx, encode.onnx, uncached_decode.onnx and
    cached_decode.onnx into the output directory.
    """
    assert (
        len(sys.argv) == 3
    ), "Usage: convert_to_onnx.py <moonshine model name> <output directory name>"
    assert (
        keras.config.backend() == "tensorflow"
    ), "Should be run with the tensorflow backend"

    import tensorflow as tf

    model = moonshine.load_model(sys.argv[1])
    output_dir = sys.argv[2]
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Tensor specs shared by several of the exported graphs.
    audio_spec = tf.TensorSpec([None, None], dtype=tf.float32)
    features_spec = tf.TensorSpec([None, None, model.dim], dtype=tf.float32)
    seq_len_spec = tf.TensorSpec([1], dtype=tf.int32)
    token_spec = tf.TensorSpec([None, None], dtype=tf.int32)
    # One K/V cache tensor per attention head group: 4 per decoder layer.
    cache_specs = [
        tf.TensorSpec(
            [None, None, model.n_head, model.inner_dim // model.n_head],
            dtype=tf.float32,
        )
        for _ in range(model.dec_n_layers * 4)
    ]

    convert_and_store(
        model.preprocessor.preprocess,
        input_signature=[audio_spec],
        output_file=f"{output_dir}/preprocess.onnx",
    )

    convert_and_store(
        model.encoder.encoder,
        input_signature=[features_spec, seq_len_spec],
        output_file=f"{output_dir}/encode.onnx",
    )

    # First decoder step has no cache; subsequent steps take cache tensors.
    convert_and_store(
        model.decoder.uncached_call,
        input_signature=[token_spec, features_spec, seq_len_spec],
        output_file=f"{output_dir}/uncached_decode.onnx",
    )

    convert_and_store(
        model.decoder.cached_call,
        input_signature=[token_spec, features_spec, seq_len_spec] + cache_specs,
        output_file=f"{output_dir}/cached_decode.onnx",
    )
71+
72+
73+
# Run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()

moonshine/tools/onnx_model.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import onnxruntime
2+
import moonshine
3+
4+
5+
class MoonshineOnnxModel(object):
    """Moonshine speech-to-text model backed by onnxruntime sessions.

    Expects *models_dir* to contain the four .onnx files produced by
    convert_to_onnx.py: preprocess.onnx, encode.onnx, uncached_decode.onnx
    and cached_decode.onnx.
    """

    def __init__(self, models_dir):
        # One inference session per exported model component.
        self.preprocess = onnxruntime.InferenceSession(f"{models_dir}/preprocess.onnx")
        self.encode = onnxruntime.InferenceSession(f"{models_dir}/encode.onnx")
        self.uncached_decode = onnxruntime.InferenceSession(
            f"{models_dir}/uncached_decode.onnx"
        )
        self.cached_decode = onnxruntime.InferenceSession(
            f"{models_dir}/cached_decode.onnx"
        )
        self.tokenizer = moonshine.load_tokenizer()

    def generate(self, audio, max_len=None):
        """Greedily decode a token sequence for *audio*.

        Args:
            audio: Input accepted by moonshine.load_audio; treated as
                16 kHz samples when computing the default token budget.
            max_len: Maximum number of tokens to generate. Defaults to
                6 tokens per second of audio.

        Returns:
            A list containing one list of token ids (batch of size 1).
        """
        audio = moonshine.load_audio(audio, return_numpy=True)
        if max_len is None:
            # max 6 tokens per second of audio
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
        seq_len = [preprocessed.shape[-2]]

        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]

        # Token id 1 starts decoding and id 2 stops it (presumably BOS/EOS
        # in the tokenizer's vocabulary — confirm against the tokenizer).
        inputs = [[1]]
        seq_len = [1]

        tokens = [1]
        # First step runs without a cache; it returns the initial K/V cache.
        logits, *cache = self.uncached_decode.run(
            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
        )
        for _ in range(max_len):
            next_token = logits.squeeze().argmax()
            # Fixed idiom: append a single element rather than extend([x]).
            tokens.append(next_token)
            if next_token == 2:
                break

            seq_len[0] += 1
            inputs = [[next_token]]
            # Subsequent steps feed back the cache from the previous run as
            # extra inputs args_3..args_N.
            logits, *cache = self.cached_decode.run(
                [],
                dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
        return [tokens]

0 commit comments

Comments
 (0)