From a3c298db34f3286191c3e79f98672c46491affb0 Mon Sep 17 00:00:00 2001 From: Huiying Date: Sat, 20 Apr 2024 10:06:36 -0700 Subject: [PATCH] Add safety checks for 'data' key in MegatronGPTModel cfg (#8991) Signed-off-by: HuiyingLi --- .../nlp/models/language_modeling/megatron_gpt_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index a660af46f13d2..e5e48cdc10da8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -367,9 +367,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None - self.return_output_tensors = cfg.data.get('return_output_tensors', False) - self.validation_drop_last = cfg.data.get('validation_drop_last', True) - self.sample_weight = cfg.data.get('sample_weight', 'token') + data_cfg = cfg.get('data', {}) + self.return_output_tensors = data_cfg.get('return_output_tensors', False) + self.validation_drop_last = data_cfg.get('validation_drop_last', True) + self.sample_weight = data_cfg.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) self.inference_params = None