Skip to content

Commit

Permalink
use batch_size when logging metrics to prevent lightning "can't infer …
Browse files Browse the repository at this point in the history
…batch_size" error
  • Loading branch information
aaprasad committed Apr 22, 2024
1 parent c351008 commit 4fed899
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions biogtr/models/gtr_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module containing training, validation and inference logic."""

import torch
from biogtr.inference.tracker import Tracker
from biogtr.inference import metrics
Expand Down Expand Up @@ -85,7 +86,7 @@ def training_step(
A dict containing the train loss plus any other metrics specified
"""
result = self._shared_eval_step(train_batch[0], mode="train")
self.log_metrics(result, "train")
self.log_metrics(result, len(train_batch[0]), "train")

return result

Expand All @@ -103,7 +104,7 @@ def validation_step(
A dict containing the val loss plus any other metrics specified
"""
result = self._shared_eval_step(val_batch[0], mode="val")
self.log_metrics(result, "val")
self.log_metrics(result, len(val_batch[0]), "val")

return result

Expand All @@ -119,7 +120,7 @@ def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
A dict containing the val loss plus any other metrics specified
"""
result = self._shared_eval_step(test_batch[0], mode="test")
self.log_metrics(result, "test")
self.log_metrics(result, len(test_batch[0]), "test")

return result

Expand Down Expand Up @@ -206,13 +207,20 @@ def configure_optimizers(self) -> dict:
},
}

def log_metrics(self, result: dict, batch_size: int, mode: str) -> None:
    """Log metrics computed during evaluation.

    The batch size is passed explicitly to ``self.log`` so that
    Lightning does not have to infer it from the logged value
    (which raises a "can't infer batch_size" error for dict metrics).

    Args:
        result: A dict mapping metric names to values to be logged.
        batch_size: The size of the batch used to compute the metrics.
        mode: One of {'train', 'test', 'val'}. Used as a prefix while logging.
    """
    # Skip logging entirely when there are no metrics (empty dict or None).
    if result:
        for metric, val in result.items():
            self.log(
                f"{mode}_{metric}",
                val,
                batch_size=batch_size,  # explicit so Lightning needn't infer it
                on_step=True,
                on_epoch=True,
            )

0 comments on commit 4fed899

Please sign in to comment.