Skip to content

Commit 32c06fa

Browse files
arendu and pre-commit-ci[bot]
authored and committed
remove auto generated examples (NVIDIA#7510)
* explicitly remove autogenerated examples for data parallel evaluation Signed-off-by: arendu <adithyare@nvidia.com> * mark autogenerated and remove it for test Signed-off-by: arendu <adithyare@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu <adithyare@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister <sasha.meister.work@gmail.com>
1 parent 1be2b40 commit 32c06fa

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,8 +171,13 @@ def __getitem__(self, idx):
171171
# idx may < 0 because we pad_samples_to_global_batch_size, e.g. id = -1
172172
if idx < 0:
173173
idx = len(self) + idx
174+
auto_gen_idx = True
175+
else:
176+
auto_gen_idx = False
174177
try:
175178
example = self.indexed_dataset[idx]
179+
if auto_gen_idx:
180+
example['__AUTOGENERATED__'] = True
176181
except Exception as e:
177182
logging.error(f"Error while loading example {idx} from dataset {self.file_path}")
178183
raise e

nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -492,7 +492,6 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
492492
)
493493

494494
# Remove duplicate examples due to distributed sampler.
495-
inp_label_set = set()
496495
deduplicated_outputs = {
497496
'preds': [],
498497
'labels': [],
@@ -505,14 +504,16 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
505504
for pred, label, input, metadata in zip(
506505
batch['preds'], batch['labels'], batch['inputs'], batch['metadata']
507506
):
508-
key = input + label
509507
total_size += 1
510-
if key not in inp_label_set:
511-
inp_label_set.add(key)
508+
if not metadata.get("__AUTOGENERATED__", False):
512509
deduplicated_outputs['preds'].append(pred)
513510
deduplicated_outputs['labels'].append(label)
514511
deduplicated_outputs['inputs'].append(input)
515512
deduplicated_outputs['metadata'].append(metadata)
513+
else:
514+
logging.info(
515+
f"skipping autogenerated example example {input} prediction {pred} label {label}"
516+
)
516517

517518
# Compute metric score
518519
metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name

0 commit comments

Comments
 (0)