Skip to content

Commit

Permalink
Eliminated test cross-contamination
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Jun 27, 2024
1 parent a762cc2 commit f804d26
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
14 changes: 10 additions & 4 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,11 +1015,14 @@ def __init__(

self.ignore_collections = ignore_collections

def __call__(self, wrapped):
return self.wrapped_call(wrapped)

def unwrapped_call(self, wrapped):
return wrapped

@wrapt.decorator(enabled=is_typecheck_enabled)
def __call__(self, wrapped, instance: Typing, args, kwargs):
def wrapped_call(self, wrapped, instance: Typing, args, kwargs):
"""
Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`.
By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing.
Expand Down Expand Up @@ -1130,6 +1133,9 @@ def disable_semantic_checks():
typecheck.set_semantic_check_enabled(enabled=True)

@staticmethod
def disable_wrapping():
typecheck.set_typecheck_enabled(False)
typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call
def enable_wrapping(enabled: bool = True):
typecheck.set_typecheck_enabled(enabled)
if enabled:
typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call
else:
typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call
2 changes: 2 additions & 0 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def _export(
if dynamic_axes is None:
dynamic_axes = self.dynamic_shapes_for_export(use_dynamo)
if use_dynamo:
typecheck.enable_wrapping(enabled=False)
# https://github.com/pytorch/pytorch/issues/126339
with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n")
Expand Down Expand Up @@ -279,6 +280,7 @@ def _export(
else:
raise ValueError(f'Encountered unknown export format {format}.')
finally:
typecheck.enable_wrapping(enabled=True)
typecheck.set_typecheck_enabled(enabled=True)
if forward_method:
type(self).forward = old_forward_method
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/nlp/test_nlp_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Has to be applied before first import of NeMo
from nemo.core.classes import typecheck

typecheck.disable_wrapping()
typecheck.enable_wrapping(enabled=False)

from nemo.collections import nlp as nemo_nlp
from nemo.collections.nlp.models import IntentSlotClassificationModel
Expand Down

0 comments on commit f804d26

Please sign in to comment.