Not doing CastToFloat by default (#6524) (#6563)
* Not doing CastToFloat by default

* Added docstring

* Dummy commit

---------

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Co-authored-by: Boris Fomitchev <borisfom@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
3 people authored May 10, 2023
1 parent c21f299 commit e6ee331
Showing 2 changed files with 29 additions and 13 deletions.
18 changes: 12 additions & 6 deletions nemo/utils/cast_utils.py
@@ -70,8 +70,11 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, x):
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        if torch.is_autocast_enabled() and x.dtype != torch.float32:
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        else:
+            ret = self.mod.forward(x)
         return ret
 
 
@@ -81,7 +84,10 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, *args):
-        from_dtype = args[0].dtype
-        with torch.cuda.amp.autocast(enabled=False):
-            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
-        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+        if torch.is_autocast_enabled():
+            from_dtype = args[0].dtype
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+            return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+        else:
+            return self.mod.forward(*args)
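
For reference, a minimal usage sketch of the changed wrapper behavior (a hypothetical example, assuming CastToFloat is importable from the nemo/utils/cast_utils.py module edited above; the LayerNorm model and import path are illustrative, not part of the commit): the cast to float32 now happens only when autocast is active and the input is not already float32, otherwise the wrapped module runs as-is.

import torch
import torch.nn as nn

from nemo.utils.cast_utils import CastToFloat  # assumed import path for the class edited above

# Hypothetical example: keep a LayerNorm in float32 while the surrounding network runs under autocast.
norm = CastToFloat(nn.LayerNorm(8)).cuda()
x = torch.randn(2, 8, device="cuda", dtype=torch.float16)

with torch.cuda.amp.autocast():
    y = norm(x)        # autocast on, half input: cast to float32 inside, result cast back to float16
print(y.dtype)         # torch.float16

y = norm(x.float())    # autocast off (or input already float32): no extra casts, module called directly
print(y.dtype)         # torch.float32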
24 changes: 17 additions & 7 deletions nemo/utils/export_utils.py
@@ -440,26 +440,36 @@ def script_module(m: nn.Module):
 
 def replace_for_export(model: nn.Module) -> nn.Module:
     """
-    Top-level function to replace default set of modules in model
+    Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
     NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
     Args:
         model : top level module
         replace_1D_2D : include 1D -> 2D replacements
     Returns:
         model, possibly modified in-place
     """
-    from nemo.collections.tts.modules.submodules import MaskedInstanceNorm1d
 
     default_replacements = {
-        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
-        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
-        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
-        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
-        "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
         "MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax),
     }
 
     replace_modules(model, default_Apex_replacements)
     replace_modules(model, default_replacements)
     # This one has to be the last
     replace_modules(model, script_replacements)
+
+
+def add_casts_around_norms(model: nn.Module):
+    """
+    Function to put additional to/from float32 casts around operations known to require full precision.
+    It was used with an extra post-parse script to have TRT preserve extra precision when --fp16 needed.
+    Should not be needed with TRT 8.6.1 or later.
+    """
+    default_cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
+    }
+    replace_modules(model, default_cast_replacements)
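
With the cast wrappers removed from replace_for_export, export code that still needs them (for example FP16 TensorRT builds with TRT older than 8.6.1) must opt in by calling add_casts_around_norms explicitly. A minimal opt-in sketch, assuming the import path matches the file edited above and using a hypothetical toy model (not part of the commit):

import torch.nn as nn

from nemo.utils.export_utils import add_casts_around_norms, replace_for_export  # assumed import path

# Hypothetical toy model with a norm layer that may need full precision under FP16 export.
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16), nn.ReLU())

replace_for_export(model)      # default export replacements only; no CastToFloat wrapping by default anymore
add_casts_around_norms(model)  # explicit opt-in: wraps the norm layers in to/from float32 casts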
