Skip to content

Commit

Permalink
minor adjustments to the VITS recipes for onnx runtime (#1405)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr authored Dec 7, 2023
1 parent b87ed26 commit bda72f8
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions egs/ljspeech/TTS/vits/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,16 @@ def export_model_onnx(

torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"alpha",
"noise_scale_dur",
],
output_names=["audio"],
dynamic_axes={
Expand Down
4 changes: 2 additions & 2 deletions egs/ljspeech/TTS/vits/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Ten
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
},
)[0]
return torch.from_numpy(out)
Expand Down
4 changes: 2 additions & 2 deletions egs/vctk/TTS/vits/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,17 @@ def export_model_onnx(

torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"alpha",
"noise_scale_dur",
"speaker",
"alpha",
],
output_names=["audio"],
dynamic_axes={
Expand Down
6 changes: 3 additions & 3 deletions egs/vctk/TTS/vits/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ def __call__(
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: speaker.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
self.model.get_inputs()[5].name: speaker.numpy(),
},
)[0]
return torch.from_numpy(out)
Expand Down

0 comments on commit bda72f8

Please sign in to comment.