Skip to content

Commit

Permalink
Append output of val step to self.validation_step_outputs (NVIDIA#7530)…
Browse files Browse the repository at this point in the history
… (NVIDIA#7532)

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
Signed-off-by: Sasha Meister <ameister@nvidia.com>
  • Loading branch information
2 people authored and sashameister committed Oct 2, 2023
1 parent 0b75d82 commit 2fee8e4
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
model_output = torch.argmax(model_output, 1)

eval_tensors = {'preds': model_output, 'labels': labels}
return {'val_loss': val_loss, 'eval_tensors': eval_tensors}
output = {'val_loss': val_loss, 'eval_tensors': eval_tensors}
self.validation_step_outputs.append(output)
return output

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
"""
Called at the end of validation to aggregate outputs.
outputs: list of individual outputs of each validation step.
"""
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
preds = torch.cat([x['eval_tensors']['preds'] for x in outputs])
labels = torch.cat([x['eval_tensors']['labels'] for x in outputs])
avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
preds = torch.cat([x['eval_tensors']['preds'] for x in self.validation_step_outputs])
labels = torch.cat([x['eval_tensors']['labels'] for x in self.validation_step_outputs])

all_preds = []
all_labels = []
Expand Down

0 comments on commit 2fee8e4

Please sign in to comment.