diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index c4cccc75789c..4df6a1bb577b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -227,7 +227,7 @@ def _build_loss_mask(self, processed_example): input_ids = processed_example['input_ids'] answer_start_idx = processed_example['answer_start_idx'] if self.answer_only_loss: - loss_mask = [float(idx > answer_start_idx) for idx in range(len(input_ids))] + loss_mask = [float(idx >= answer_start_idx) for idx in range(len(input_ids))] else: loss_mask = [1.0] * len(input_ids)