From 9cd977fe8fa987a27acca68d8932802882e55717 Mon Sep 17 00:00:00 2001
From: Abhishree
Date: Mon, 14 Aug 2023 13:39:37 -0700
Subject: [PATCH] Add bf16-mixed and 16-mixed in module.py

Signed-off-by: Abhishree
---
 nemo/collections/nlp/modules/common/megatron/module.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py
index caa424cc01b3..42bda16df221 100644
--- a/nemo/collections/nlp/modules/common/megatron/module.py
+++ b/nemo/collections/nlp/modules/common/megatron/module.py
@@ -263,13 +263,13 @@ def __init__(self, config: ModelParallelConfig, module, precision, share_token_e
         super().__init__(config=config, share_token_embeddings=share_token_embeddings)
         self.precision = precision
 
-        if precision == 'bf16':
+        if precision in ['bf16', 'bf16-mixed']:
             self.add_module('module', module.bfloat16())
 
             def float16_converter(val):
                 return val.bfloat16()
 
-        elif int(precision) == 16:
+        elif precision in [16, '16', '16-mixed']:
             self.add_module('module', module.half())
 
             def float16_converter(val):
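
Note on the change: the patch widens the precision check so that the '-mixed'
precision strings are accepted alongside the bare values; the previous
int(precision) cast would raise ValueError on a string like '16-mixed'. Below
is a minimal, self-contained sketch of the patched dispatch, assuming plain
PyTorch; the helper name wrap_for_half_precision is hypothetical and not part
of the patch.

    import torch

    def wrap_for_half_precision(module: torch.nn.Module, precision):
        # Mirrors the patched branches: both 'bf16' and 'bf16-mixed' cast the
        # module's parameters to bfloat16, while 16 (int or str) and
        # '16-mixed' cast them to float16. (Helper name is hypothetical.)
        if precision in ['bf16', 'bf16-mixed']:
            return module.bfloat16()
        elif precision in [16, '16', '16-mixed']:
            return module.half()
        raise ValueError(f'Unsupported precision: {precision!r}')

    layer = torch.nn.Linear(4, 4)
    print(wrap_for_half_precision(layer, 'bf16-mixed').weight.dtype)  # torch.bfloat16
    print(wrap_for_half_precision(layer, '16-mixed').weight.dtype)    # torch.float16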