From e4c89efecca5cb355cdc7a8d7aa866d9c29559c2 Mon Sep 17 00:00:00 2001
From: Jan Lasek
Date: Fri, 24 Nov 2023 12:59:50 +0100
Subject: [PATCH] Resolve dtype with utils_funcs.py

Signed-off-by: Jan Lasek
---
 nemo/collections/nlp/parts/utils_funcs.py              |  8 +++++++-
 .../nlp_language_modeling/convert_hf_llama_to_nemo.py  | 11 ++---------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py
index 5185c6cf9b5a..2ec77faf91f5 100644
--- a/nemo/collections/nlp/parts/utils_funcs.py
+++ b/nemo/collections/nlp/parts/utils_funcs.py
@@ -12,7 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ['list2str', 'tensor2list', 'plot_confusion_matrix', 'get_classification_report']
+__all__ = [
+    'torch_dtype_from_precision',
+    'list2str',
+    'tensor2list',
+    'plot_confusion_matrix',
+    'get_classification_report',
+]
 
 import os
 import time
diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
index d1453aeee972..380a7bab2eeb 100644
--- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -41,6 +41,7 @@
     NLPSaveRestoreConnector,
     PipelineMixedPrecisionPlugin,
 )
+from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
 from nemo.utils import logging
 
 
@@ -170,15 +171,6 @@ def convert(args):
     else:
         plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
 
-    if precision == 32:
-        dtype = torch.float32
-    elif precision in [16, "16", "16-mixed"]:
-        dtype = torch.float16
-    elif precision in ["bf16", "bf16-mixed"]:
-        dtype = torch.bfloat16
-    else:
-        dtype = torch.float32  # fallback
-
     nemo_config.precision = precision
 
     print(f"nemo_config: {nemo_config}")
@@ -315,6 +307,7 @@ def convert(args):
     model._save_restore_connector = NLPSaveRestoreConnector()
 
     # cast to target precision and disable cpu init
+    dtype = torch_dtype_from_precision(precision)
     model = model.to(dtype=dtype)
     model.cfg.use_cpu_initialization = False
 