
torch.onnx.export with torchaudio Spectrogram doesn't support dynamic batch size #148687

@sammlapp

🐛 Describe the bug

It is now possible to include torchaudio.transforms.Spectrogram in a model and successfully export the model to an ONNX program. However, when running the exported model, I cannot use a batch size other than the one used at export time. I've tried several approaches; here is an example based on #113067 (comment).

Create and save the ONNX model:

```python
import torchvision
import torchaudio
import torch


# Define a PyTorch model: raw audio -> spectrogram -> decibel scale
class SpecMaker(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transforms = torchvision.transforms.Compose(
            [
                torchaudio.transforms.Spectrogram(
                    n_fft=512,
                    win_length=512,
                    hop_length=256,
                ),
                torchaudio.transforms.AmplitudeToDB(top_db=100),
            ]
        )

    def forward(self, x):
        return self.transforms(x)


specmodel = SpecMaker()
input = torch.rand(32000 * 10)
spec = specmodel(input)
input_batch = torch.stack([input, input])
spec_batch = specmodel(input_batch)  # just testing that the PyTorch model works as expected

assert spec_batch.shape == torch.Size([2, 257, 1251])

onnx_program = torch.onnx.export(
    specmodel,
    (input_batch,),
    dynamic_shapes=[{0: "dim_x"}],
    report=True,
    dynamo=True,
)

onnx_program.save("specmodel2.onnx")
```
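The expected shape in the `assert` follows directly from the STFT parameters. A minimal sketch, assuming torchaudio's `Spectrogram` defaults of `center=True` and `onesided=True`; `spectrogram_shape` is a hypothetical helper for illustration and does not call torchaudio:

```python
def spectrogram_shape(batch: int, num_samples: int, n_fft: int, hop_length: int) -> tuple:
    """Hypothetical helper: (batch, freq, frames) of a one-sided, centered STFT."""
    # A one-sided FFT of length n_fft keeps n_fft // 2 + 1 frequency bins.
    n_freq = n_fft // 2 + 1
    # center=True pads n_fft // 2 samples on each side, which yields
    # 1 + num_samples // hop_length frames.
    n_frames = 1 + num_samples // hop_length
    return (batch, n_freq, n_frames)


print(spectrogram_shape(batch=2, num_samples=32000 * 10, n_fft=512, hop_length=256))
# → (2, 257, 1251)
```

Only the first entry depends on the batch size, so in principle it is the only dimension the exported graph needs to treat as symbolic.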

Load the ONNX model and attempt to run it with a different batch size:

```python
import onnx
import onnxruntime
import torch

onnx_model = onnx.load("specmodel2.onnx")
onnx.checker.check_model(onnx_model)

# 10 s of audio at 32 kHz (320,000 samples). The original run loaded a real clip
# via opensoundscape (`opso`), which is not imported here; the failure depends
# only on the input shape, so random audio of the same length should reproduce it.
input = torch.rand(32000 * 10)

# Batch the clip: works with 2 samples (the export-time batch size), fails with 3
input_batched = torch.stack([input, input, input])

EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]

ort_session = onnxruntime.InferenceSession("specmodel2.onnx", providers=EP_list)


def to_numpy(tensor):
    # Detach if the tensor tracks gradients, then move to CPU and convert
    return (
        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    )


# Compute the ONNX Runtime output
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_batched)}
ort_outs = ort_session.run(None, ort_inputs)
```

Error:

```
---------------------------------------------------------------------------
RuntimeException                          Traceback (most recent call last)
Cell In[8], line 14
     12 # compute ONNX Runtime output prediction
     13 ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_batched)}
---> 14 ort_outs = ort_session.run(None, ort_inputs)

File ~/miniconda3/envs/bmz_dev/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:266, in Session.run(self, output_names, input_feed, run_options)
    264     output_names = [output.name for output in self._outputs_meta]
    265 try:
--> 266     return self._sess.run(output_names, input_feed, run_options)
    267 except C.EPFail as err:
    268     if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'node_Reshape_5' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:47 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, onnxruntime::TensorShapeVector &, bool) input_shape_size == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{3,320000}, requested shape:{1,2,320000}
```
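The requested shape `{1,2,320000}` has 640,000 elements, which matches the `{2,320000}` batch used at export time but not the `{3,320000}` batch passed at runtime; this suggests the batch size was recorded as a constant in that Reshape node. A minimal pure-Python sketch of the shape-resolution rule, assuming the ONNX `Reshape` semantics with `allowzero=0` (a `0` copies the input dimension at the same index, a single `-1` is inferred); `resolve_reshape` is a hypothetical helper, not an onnxruntime API:

```python
import math


def resolve_reshape(input_shape, target_shape):
    """Hypothetical helper mimicking ONNX Reshape shape resolution (allowzero=0)."""
    resolved = []
    infer_index = None
    for i, dim in enumerate(target_shape):
        if dim == 0:
            # 0 means "copy the input dimension at the same index".
            resolved.append(input_shape[i])
        elif dim == -1:
            if infer_index is not None:
                raise ValueError("at most one -1 is allowed")
            infer_index = i
            resolved.append(1)  # placeholder, filled in below
        else:
            resolved.append(dim)

    total = math.prod(input_shape)
    if infer_index is not None:
        known = math.prod(resolved)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {input_shape} -> {target_shape}")
        resolved[infer_index] = total // known
    # The element counts must match, which is the check that fails in the report.
    if math.prod(resolved) != total:
        raise ValueError(f"input shape {input_shape} cannot be reshaped to {target_shape}")
    return resolved


# A constant target shape only works for the batch size it was traced with.
print(resolve_reshape([2, 320000], [1, 2, 320000]))   # → [1, 2, 320000]
# A target that infers the batch dimension at runtime works for any batch size.
print(resolve_reshape([3, 320000], [1, -1, 320000]))  # → [1, 3, 320000]

try:
    resolve_reshape([3, 320000], [1, 2, 320000])
except ValueError as err:
    print(err)  # the same element-count mismatch onnxruntime reports
```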

Am I loading or running the ONNX model incorrectly, or is this an issue with dynamic-shape export when the model contains an STFT? (There are several related issues, including #113067 and #139246, and PR #145080.)
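For reference, `dynamic_shapes=[{0: "dim_x"}]` in the export call follows the `torch.export` convention: one entry per positional input, each a dict mapping an axis index to a dynamic-dimension name. The pure-Python sketch below (the helper `symbolic_shape` is hypothetical and does not call torch) spells out the input signature that this spec describes, i.e. what the exported graph would be expected to accept:

```python
def symbolic_shape(example_shape, dynamic_axes):
    """Hypothetical helper: replace the axes marked dynamic with their names."""
    return [dynamic_axes.get(axis, size) for axis, size in enumerate(example_shape)]


example_inputs = [(2, 32000 * 10)]  # shape of input_batch in the report
dynamic_shapes = [{0: "dim_x"}]     # one spec per positional input
for shape, spec in zip(example_inputs, dynamic_shapes):
    print(symbolic_shape(shape, spec))  # → ['dim_x', 320000]
```

Under this spec only axis 0 should vary, so a node that hard-codes `2` on that axis, as the Reshape error suggests, contradicts the requested signature.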

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250301
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.1
[pip3] onnxscript==0.2.1
[pip3] optree==0.14.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.7.0.dev20250301
[pip3] torch-audiomentations==0.11.0
[pip3] torch_pitch_shift==1.2.5
[pip3] torchaudio==2.6.0.dev20250301
[pip3] torchmetrics==1.2.0
[pip3] torchview==0.2.6
[pip3] torchvision==0.22.0.dev20250301
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.14.0 pypi_0 pypi
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] torch 2.7.0.dev20250301 pypi_0 pypi
[conda] torch-audiomentations 0.11.0 pypi_0 pypi
[conda] torch-pitch-shift 1.2.5 pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250301 pypi_0 pypi
[conda] torchmetrics 1.2.0 pypi_0 pypi
[conda] torchview 0.2.6 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250301 pypi_0 pypi

Labels: module: onnx, triaged
