Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
oke-aditya and SkafteNicki authored Sep 24, 2021
1 parent 10a4fde commit 857a19f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def training_step(self, batch, batch_idx):
loss_dict = self.model(images, targets)
loss = sum(loss for loss in loss_dict.values())
self.log("loss", loss, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
images, targets = batch
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/models/detection/retinanet/retainanet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ def training_step(self, batch, batch_idx):
loss_dict = self.model(images, targets)
loss = sum(loss for loss in loss_dict.values())
self.log("loss", loss, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
images, targets = batch
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
self.log("val_iou", iou, prog_bar=True)
return {"val_iou": iou}

def validation_epoch_end(self, outs):
avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
Expand Down

0 comments on commit 857a19f

Please sign in to comment.