Description
🐛 Describe the bug
It is now possible to include torchaudio.transforms.Spectrogram in a model and successfully export the model to an ONNX program. However, when loading and running the exported model, I cannot use any batch size other than the one used during export, even though the batch dimension was declared dynamic. I've tried several approaches. Here is an example based on #113067 (comment).
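For reference, the output shape the model should produce for any batch size can be worked out by hand. This is a minimal sketch, assuming torchaudio's defaults `center=True` and `onesided=True` (the example does not override them); `AmplitudeToDB` keeps the shape unchanged. The helper name `expected_spectrogram_shape` is hypothetical:

```python
def expected_spectrogram_shape(batch, num_samples, n_fft=512, hop_length=256):
    """Expected Spectrogram output shape for a (batch, num_samples) input.

    Assumes center=True and onesided=True (torchaudio defaults). win_length
    does not affect the shape because the window is zero-padded to n_fft.
    """
    freq_bins = n_fft // 2 + 1  # one-sided spectrum
    padded = num_samples + 2 * (n_fft // 2)  # center=True pads both ends
    frames = 1 + (padded - n_fft) // hop_length
    return (batch, freq_bins, frames)


for batch in (2, 3):
    print(expected_spectrogram_shape(batch, 32000 * 10))
```

For the 320000-sample input used here this gives `(2, 257, 1251)`, matching the assertion in the example, so a truly dynamic batch dimension should yield `(3, 257, 1251)` for three samples.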
Create and save the ONNX model:
import torchvision
import torchaudio
import torch

# define a pytorch model
class SpecMaker(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transforms = torchvision.transforms.Compose(
            [
                torchaudio.transforms.Spectrogram(
                    n_fft=512,
                    win_length=512,
                    hop_length=256,
                ),
                torchaudio.transforms.AmplitudeToDB(top_db=100),
            ]
        )

    def forward(self, x):
        return self.transforms(x)

specmodel = SpecMaker()
input = torch.rand(32000 * 10)
spec = specmodel(input)
input_batch = torch.stack([input, input])
spec_batch = specmodel(input_batch)  # just testing pytorch model works as expected
assert spec_batch.shape == torch.Size([2, 257, 1251])

onnx_program = torch.onnx.export(
    specmodel,
    (input_batch,),
    dynamic_shapes=[{0: "dim_x"}],
    report=True,
    dynamo=True,
)
onnx_program.save("specmodel2.onnx")

Load the ONNX model and attempt to run it with a different batch size:
import onnx, onnxruntime
import torch

onnx_model = onnx.load("specmodel2.onnx")
onnx.checker.check_model(onnx_model)

input = torch.rand(32000 * 10)
# input = torch.tensor((opso.birds).trim(0, 10).samples)
# ^ `opso` is not imported in this snippet; the random signal above has the same
#   320000-sample length reported in the error, so this line can be skipped.

# what if it's batched?
input_batched = torch.stack([input, input, input])  # works with 2 samples, fails with 3

EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession("specmodel2.onnx", providers=EP_list)

def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    )

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_batched)}
ort_outs = ort_session.run(None, ort_inputs)

Error:
---------------------------------------------------------------------------
RuntimeException                          Traceback (most recent call last)
Cell In[8], line 14
     12 # compute ONNX Runtime output prediction
     13 ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_batched)}
---> 14 ort_outs = ort_session.run(None, ort_inputs)

File ~/miniconda3/envs/bmz_dev/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:266, in Session.run(self, output_names, input_feed, run_options)
    264 output_names = [output.name for output in self._outputs_meta]
    265 try:
--> 266 return self._sess.run(output_names, input_feed, run_options)
    267 except C.EPFail as err:
    268 if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'node_Reshape_5' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, onnxruntime::TensorShapeVector &, bool) input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{3,320000}, requested shape:{1,2,320000}
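The message itself shows the mismatch: ONNX Reshape must preserve the element count, and the input `{3,320000}` holds 960000 elements while the requested `{1,2,320000}` holds only 640000. The requested shape appears to contain the batch size 2 used at export time. Below is a minimal sketch of that size check, assuming the default `allowzero=0` semantics of the ONNX Reshape operator (a `0` copies the input dimension at the same index, and a single `-1` is inferred). `onnx_reshape_target` is a hypothetical helper, not part of onnx or onnxruntime:

```python
from math import prod


def onnx_reshape_target(input_shape, shape):
    """Resolve an ONNX Reshape target shape against input_shape.

    Assumes allowzero=0: a 0 copies the input dimension at the same index,
    and at most one -1 is inferred from the remaining element count.
    Raises ValueError when the element counts cannot match.
    """
    resolved = []
    infer_index = None
    for i, dim in enumerate(shape):
        if dim == 0:
            if i >= len(input_shape):
                raise ValueError(f"0 at index {i} has no matching input dimension")
            resolved.append(input_shape[i])
        elif dim == -1:
            if infer_index is not None:
                raise ValueError("at most one -1 is allowed")
            infer_index = i
            resolved.append(1)  # placeholder, filled in below
        else:
            resolved.append(dim)

    total = prod(input_shape)
    known = prod(resolved)
    if infer_index is not None:
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {tuple(input_shape)} -> {tuple(shape)}")
        resolved[infer_index] = total // known
    elif known != total:
        raise ValueError(
            f"cannot reshape {tuple(input_shape)} ({total} elements) "
            f"to {tuple(resolved)} ({known} elements)"
        )
    return tuple(resolved)


# The shape requested by node_Reshape_5 in the traceback above.
requested = (1, 2, 320000)
print(onnx_reshape_target((2, 320000), requested))  # batch of 2: sizes match
try:
    onnx_reshape_target((3, 320000), requested)  # batch of 3: sizes differ
except ValueError as err:
    print(err)
print(onnx_reshape_target((3, 320000), (1, -1, 320000)))  # inferred batch dim
```

With a batch of 2 the target resolves, which is consistent with the comment that two samples work; with 3 it fails. A target with an inferred dimension such as `(1, -1, 320000)` would accept any batch size.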
Am I loading and using the ONNX model incorrectly, or is this an issue with exporting dynamic shapes when the model contains an STFT? (There are several related issues, including #113067 and #139246, and PR #145080.)
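One plausible source of the `{1,2,320000}` shape, offered as a guess rather than a diagnosis: as far as I can tell from the source, `torchaudio.functional.spectrogram` packs its input to `(-1, num_samples)` before calling `torch.stft`, and with `center=True` `torch.stft` views a 1-D or 2-D input as `[1] * (3 - input.dim()) + list(input.size())` before reflect-padding the last dimension by `n_fft // 2`. The sketch below is pure shape arithmetic mirroring that branch (hypothetical helper `stft_center_pad_shapes`, no torch needed). For the exported batch of 2, the intermediate view is exactly `[1, 2, 320000]`, so if the exporter records that view with concrete sizes rather than the symbolic batch dimension, the graph would only accept batch 2.

```python
def stft_center_pad_shapes(input_shape, n_fft):
    """Shapes produced by torch.stft's center padding, as described above.

    Returns (view_shape, padded_shape, restored_shape). Hypothetical helper
    doing shape arithmetic only; it is not a torch API.
    """
    signal_dim = len(input_shape)
    if signal_dim not in (1, 2):
        raise ValueError("torch.stft expects a 1-D or 2-D input")
    # Prepend singleton dims so the input is 3-D for reflect padding.
    view_shape = [1] * (3 - signal_dim) + list(input_shape)
    pad = n_fft // 2
    # Only the last (time) dimension is padded, by n_fft // 2 on each side.
    padded_shape = view_shape[:-1] + [view_shape[-1] + 2 * pad]
    # Drop the prepended singleton dims again.
    restored_shape = padded_shape[-signal_dim:]
    return view_shape, padded_shape, restored_shape


print(stft_center_pad_shapes((2, 320000), n_fft=512))
print(stft_center_pad_shapes((3, 320000), n_fft=512))
```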
Versions
Collecting environment information...
PyTorch version: 2.7.0.dev20250301
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.1
[pip3] onnxscript==0.2.1
[pip3] optree==0.14.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.7.0.dev20250301
[pip3] torch-audiomentations==0.11.0
[pip3] torch_pitch_shift==1.2.5
[pip3] torchaudio==2.6.0.dev20250301
[pip3] torchmetrics==1.2.0
[pip3] torchview==0.2.6
[pip3] torchvision==0.22.0.dev20250301
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.14.0 pypi_0 pypi
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] torch 2.7.0.dev20250301 pypi_0 pypi
[conda] torch-audiomentations 0.11.0 pypi_0 pypi
[conda] torch-pitch-shift 1.2.5 pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250301 pypi_0 pypi
[conda] torchmetrics 1.2.0 pypi_0 pypi
[conda] torchview 0.2.6 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250301 pypi_0 pypi