diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 4420318dd416..7df0ae9fb689 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -198,7 +198,7 @@ def output_names(self): return get_io_names(otypes, self.disabled_deployment_output_names) def forward_for_export( - self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): """ This forward is used when we need to export the model to ONNX format. @@ -217,12 +217,12 @@ def forward_for_export( """ enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward) if cache_last_channel is None: - encoder_output = enc_fun(audio_signal=input, length=length) + encoder_output = enc_fun(audio_signal=audio_signal, length=length) if isinstance(encoder_output, tuple): encoder_output = encoder_output[0] else: encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun( - audio_signal=input, + audio_signal=audio_signal, length=length, cache_last_channel=cache_last_channel, cache_last_time=cache_last_time, diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 23ab5469e60c..ba5489839db4 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()), } - def forward_for_export(self, processed_signal, processed_signal_len): - encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + def forward_for_export(self, audio_signal, length): + encoded, length = self.encoder(audio_signal=audio_signal, length=length) logits, embs = self.decoder(encoder_output=encoded, length=length) return logits, embs diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 03b94ae0b209..25348dae95f3 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -876,7 +876,8 @@ def forward(self, encoder_output, length=None): embs = [] for layer in self.emb_layers: - pool, emb = layer(pool), layer[: self.emb_id](pool) + emb = layer[: self.emb_id](pool) + pool = layer(pool) embs.append(emb) pool = pool.squeeze(-1) diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index e53f6299b08a..c2beb3918ead 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths): mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device) mask = ~mask # 0 represents value, 1 represents pad x = x.float() # For stable AMP, SE must be computed at fp32. - x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0 + x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0 y = self._se_pool_step(x, mask) # [B, C, 1] y = y.transpose(1, -1) # [B, 1, C] y = self.fc(y) # [B, 1, C] diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index a7aa9e17b1fd..fe7f040287cc 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1005,7 +1005,7 @@ def __init__( self.ignore_collections = ignore_collections def __call__(self, wrapped): - return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) def unwrapped_call(self, wrapped): return wrapped diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5e7d5522765c..fe6cbce5bcfa 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -14,18 +14,20 @@ from abc import ABC from typing import Dict, List, Optional, Union +import onnx import torch from pytorch_lightning.core.module import _jit_is_scripting from nemo.core.classes import typecheck from nemo.core.neural_types import NeuralType from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names -from nemo.utils import logging +from nemo.utils import logging, monkeypatched from nemo.utils.export_utils import ( ExportFormat, augment_filename, get_export_format, parse_input_example, + rename_onnx_io, replace_for_export, verify_runtime, verify_torchscript, @@ -177,7 +179,7 @@ def _export( with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting(): if input_example is None: - input_example = self.input_module.input_example() + input_example = self.input_module.input_example(max_batch=2) # Remove i/o examples from args we propagate to enclosed Exportables my_args.pop('output') @@ -191,7 +193,9 @@ def _export( input_list, input_dict = parse_input_example(input_example) input_names = self.input_names output_names = self.output_names - output_example = tuple(self.forward(*input_list, **input_dict)) + output_example = self.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) if check_trace: if isinstance(check_trace, bool): @@ -219,16 +223,49 @@ def _export( # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) - dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) + if use_dynamo: + dynamic_shapes = {} + batch = torch.export.Dim("batch", max=128) + for name, dims in dynamic_axes.items(): + ds = {} + for d in dims: + if d == 0: + ds[d] = batch + # this currently fails, https://github.com/pytorch/pytorch/issues/126127 + # else: + # ds[d] = torch.export.Dim(name + '__' + str(d)) + dynamic_shapes[name] = ds + else: + dynamic_shapes = dynamic_axes if use_dynamo: - options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes) - ex_model = torch.export.export( - jitted_model, tuple(input_list), kwargs=input_dict, strict=False - ) - ex_model = ex_model.run_decompositions() - ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) - ex.save(output, model_state=jitted_model.state_dict()) - input_names = None + import onnxscript + + # https://github.com/microsoft/onnxscript/issues/1544 + onnxscript.optimizer.constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 * 64 + + # https://github.com/pytorch/pytorch/issues/126339 + with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): + print("Running export.export, dynamic shapes:\n", dynamic_shapes) + + ex_model = torch.export.export( + jitted_model, + tuple(input_list), + kwargs=input_dict, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + ex_model = ex_model.run_decompositions() + + print("Running torch.onnx.dynamo_export ...") + + options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) + ex_module = ex_model.module() + ex = torch.onnx.dynamo_export(ex_module, *input_list, **input_dict, export_options=options) + ex.save(output) # , model_state=ex_module.state_dict()) + del ex + # Rename I/O after save - don't want to risk modifying ex._model_proto + rename_onnx_io(output, input_names, output_names) + # input_names=None else: torch.onnx.export( jitted_model, diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index ebf892927723..a1e59646ae13 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -21,6 +21,7 @@ avoid_float16_autocast_context, cast_all, cast_tensor, + monkeypatched, ) from nemo.utils.dtype import str_to_dtype from nemo.utils.nemo_logging import Logger as _Logger diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 21e977ec494d..d59189cc912e 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import torch @@ -91,3 +91,12 @@ def forward(self, *args): return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) else: return self.mod.forward(*args) + + +@contextmanager +def monkeypatched(object, name, patch): + """ Temporarily monkeypatches an object. """ + pre_patched_value = getattr(object, name) + setattr(object, name, patch) + yield object + setattr(object, name, pre_patched_value) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 58256659bfc5..eda7abd9fe49 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -177,6 +177,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) output_example = model.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance) status = "SUCCESS" if all_good else "FAIL" @@ -221,10 +223,12 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): try: if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): this_good = False - except Exception: # there may ne size mismatch and it may be OK + except Exception: # there may be size mismatch and it may be OK this_good = False if not this_good: - logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + logging.info( + f"onnxruntime results mismatch! PyTorch(expected, {expected.shape}):\n{expected}\nONNXruntime, {tout.shape}:\n{tout}" + ) all_good = False return all_good @@ -479,3 +483,25 @@ def add_casts_around_norms(model: nn.Module): "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll), } replace_modules(model, default_cast_replacements) + + +def rename_onnx_io(output, input_names, output_names): + onnx_model = onnx.load(output) + rename_map = {} + for inp, name in zip(onnx_model.graph.input, input_names): + rename_map[inp.name] = name + for out, name in zip(onnx_model.graph.output, output_names): + rename_map[out.name] = name + for n in onnx_model.graph.node: + for inp in range(len(n.input)): + if n.input[inp] in rename_map: + n.input[inp] = rename_map[n.input[inp]] + for out in range(len(n.output)): + if n.output[out] in rename_map: + n.output[out] = rename_map[n.output[out]] + + for i in range(len(onnx_model.graph.input)): + onnx_model.graph.input[i].name = input_names[i] + for i in range(len(onnx_model.graph.output)): + onnx_model.graph.output[i].name = output_names[i] + onnx.save(onnx_model, output) diff --git a/tests/collections/asr/test_asr_exportables.py b/tests/collections/asr/test_asr_exportables.py index 6bb669a70a24..9377f49aa1b6 100644 --- a/tests/collections/asr/test_asr_exportables.py +++ b/tests/collections/asr/test_asr_exportables.py @@ -30,10 +30,6 @@ from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ -# from nemo.core.classes import typecheck -# typecheck.enable_wrapping(enabled=False) - - NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) @@ -56,6 +52,8 @@ def test_EncDecCTCModel_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -68,6 +66,8 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -78,6 +78,8 @@ def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model): model.export(output=filename) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -88,6 +90,9 @@ def test_EncDecCitrinetModel_export_to_onnx(self, citrinet_model): model.export(output=filename) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @@ -127,6 +132,9 @@ def test_EncDecCitrinetModel_limited_SE_export_to_onnx(self, citrinet_model): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -136,7 +144,7 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): with tempfile.TemporaryDirectory() as tmpdir: fn = 'citri_rnnt.onnx' filename = os.path.join(tmpdir, fn) - files, descr = model.export(output=filename, verbose=False) + files, descr = model.export(output=filename, dynamic_axes={}, verbose=False) encoder_filename = os.path.join(tmpdir, 'encoder-' + fn) assert files[0] == encoder_filename @@ -145,6 +153,10 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert len(onnx_model.graph.input) == 2 assert len(onnx_model.graph.output) == 2 + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'outputs' + assert onnx_model.graph.output[1].name == 'encoded_lengths' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -159,12 +171,21 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(onnx_model.graph.input) == (1 + (len(input_examples) - 1) + num_states) + assert onnx_model.graph.input[0].name == 'encoder_outputs' + assert onnx_model.graph.input[1].name == 'targets' + assert onnx_model.graph.input[2].name == 'target_length' if num_states > 0: for idx, ip in enumerate(onnx_model.graph.input[3:]): assert ip.name == "input_" + state_name + '_' + str(idx + 1) assert len(onnx_model.graph.output) == (len(input_examples) - 1) + num_states + assert onnx_model.graph.output[0].name == 'outputs' + assert onnx_model.graph.output[1].name == 'prednet_lengths' + + if num_states > 0: + for idx, op in enumerate(onnx_model.graph.output[2:]): + assert op.name == "output_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -185,6 +206,8 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): assert ts_encoder is not None arguments = ts_encoder.forward.schema.arguments[1:] # First value is `self` + assert arguments[0].name == 'audio_signal' + assert arguments[1].name == 'length' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -202,6 +225,13 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(ts_decoder_joint_args) == (1 + (len(input_examples) - 1) + num_states) + assert ts_decoder_joint_args[0].name == 'encoder_outputs' + assert ts_decoder_joint_args[1].name == 'targets' + assert ts_decoder_joint_args[2].name == 'target_length' + + if num_states > 0: + for idx, ip in enumerate(ts_decoder_joint_args[3:]): + assert ip.name == "input_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -235,6 +265,8 @@ def test_EncDecCTCModel_adapted_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logprobs' def setup_method(self): self.preprocessor = { @@ -638,8 +670,3 @@ def squeezeformer_model(): ) conformer_model = EncDecCTCModel(cfg=modelConfig) return conformer_model - - -if __name__ == "__main__": - t = TestExportable() - t.test_EncDecClassificationModel_export_to_onnx(speech_classification_model()) diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index 3181e1ce0c46..c0b97caea4ed 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -20,9 +20,6 @@ import torch import wget from omegaconf import DictConfig, OmegaConf -from nemo.core.classes import typecheck - -typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 2569d708e235..67f016b0c2af 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -18,10 +18,6 @@ import torch from omegaconf import OmegaConf -from nemo.core.classes import typecheck - -typecheck.enable_wrapping(enabled=False) - from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel from nemo.utils.app_state import AppState