diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
index eeb48f35ffa7..21e977ec494d 100644
--- a/nemo/utils/cast_utils.py
+++ b/nemo/utils/cast_utils.py
@@ -70,8 +70,11 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, x):
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        if torch.is_autocast_enabled() and x.dtype != torch.float32:
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        else:
+            ret = self.mod.forward(x)
         return ret
 
 
@@ -81,7 +84,10 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, *args):
-        from_dtype = args[0].dtype
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
-        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+        if torch.is_autocast_enabled():
+            from_dtype = args[0].dtype
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+            return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+        else:
+            return self.mod.forward(*args)
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index cc0ce744a9a6..9fa2bc239eb8 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -440,22 +440,16 @@ def script_module(m: nn.Module):
 
 def replace_for_export(model: nn.Module) -> nn.Module:
     """
-    Top-level function to replace default set of modules in model
+    Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
     NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
     Args:
         model : top level module
-        replace_1D_2D : include 1D -> 2D replacements
     Returns:
         model, possibly modified in-place
     """
     from nemo.collections.tts.modules.submodules import MaskedInstanceNorm1d
 
     default_replacements = {
-        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
-        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
-        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
-        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
-        "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
         "MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax),
     }
 
@@ -463,3 +457,19 @@ def replace_for_export(model: nn.Module) -> nn.Module:
     replace_modules(model, default_replacements)
     # This one has to be the last
     replace_modules(model, script_replacements)
+
+
+def add_casts_around_norms(model: nn.Module):
+    """
+    Function to put additional to/from float32 casts around operations known to require full precision.
+    It was used with an extra post-parse script to have TRT preserve extra precision when --fp16 needed.
+    Should not be needed with TRT 8.6.1 or later.
+    """
+    default_cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
+    }
+    replace_modules(model, default_cast_replacements)
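
A minimal usage sketch of the new conditional-cast behavior, assuming the wrapper edited in the first hunk is CastToFloat from nemo.utils.cast_utils (export_utils references it by that name); the LayerNorm module and tensor shapes below are hypothetical and only illustrate when the float32 round-trip is applied.

# Hypothetical sketch, not part of the diff: after this change CastToFloat
# applies the float32 round-trip only while autocast is active; otherwise it
# forwards the input unchanged.
import torch
import torch.nn as nn

from nemo.utils.cast_utils import CastToFloat

norm = CastToFloat(nn.LayerNorm(64)).cuda()
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)

with torch.cuda.amp.autocast():
    # Autocast enabled and half-precision input: the wrapper disables autocast,
    # runs LayerNorm in float32, then casts the result back to float16.
    y_amp = norm(x)

# Autocast disabled: the new else-branch is a plain passthrough.
y_plain = norm(x.float())

print(y_amp.dtype, y_plain.dtype)  # torch.float16 torch.float32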