diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 81f8e6b3e14c..60f842dbfb68 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1015,11 +1015,14 @@ def __init__( self.ignore_collections = ignore_collections + def __call__(self, wrapped): + return self.wrapped_call(wrapped) + def unwrapped_call(self, wrapped): return wrapped @wrapt.decorator(enabled=is_typecheck_enabled) - def __call__(self, wrapped, instance: Typing, args, kwargs): + def wrapped_call(self, wrapped, instance: Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. @@ -1130,6 +1133,9 @@ def disable_semantic_checks(): typecheck.set_semantic_check_enabled(enabled=True) @staticmethod - def disable_wrapping(): - typecheck.set_typecheck_enabled(False) - typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call + def enable_wrapping(enabled: bool = True): + typecheck.set_typecheck_enabled(enabled) + if enabled: + typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call + else: + typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index d1773cedbaa3..aab09d42d907 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -225,6 +225,7 @@ def _export( if dynamic_axes is None: dynamic_axes = self.dynamic_shapes_for_export(use_dynamo) if use_dynamo: + typecheck.enable_wrapping(enabled=False) # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n") @@ -279,6 +280,7 @@ def _export( else: raise ValueError(f'Encountered unknown export format {format}.') finally: + typecheck.enable_wrapping(enabled=True) typecheck.set_typecheck_enabled(enabled=True) if forward_method: type(self).forward = old_forward_method diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index 119093b703de..dbd5b3ac4427 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -25,7 +25,7 @@ # Has to be applied before first import of NeMo from nemo.core.classes import typecheck -typecheck.disable_wrapping() +typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel