Skip to content

Commit

Permalink
Add safety checks for 'data' key in MegatronGPTModel cfg (#8991)
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: Marc Romeyn <marcromeyn@gmail.com>
  • Loading branch information
HuiyingLi authored and marcromeyn committed Apr 22, 2024
1 parent 39a278d commit e821aac
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
self.return_output_tensors = cfg.data.get('return_output_tensors', False)
self.validation_drop_last = cfg.data.get('validation_drop_last', True)
self.sample_weight = cfg.data.get('sample_weight', 'token')
data_cfg = cfg.get('data', {})
self.return_output_tensors = data_cfg.get('return_output_tensors', False)
self.validation_drop_last = data_cfg.get('validation_drop_last', True)
self.sample_weight = data_cfg.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)

self.inference_params = None
Expand Down

0 comments on commit e821aac

Please sign in to comment.