diff --git a/pyproject.toml b/pyproject.toml index 5b0a12c4830..a8fbf4929b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -267,8 +267,6 @@ filterwarnings = [ "ignore:Skipping device Apple Paravirtual device that does not support Metal 2.0:UserWarning", # Unexpected warnings, worth investigating - # Lightning is having trouble inferring the batch size for ChesapeakeCVPRDataModule and CycloneDataModule for some reason - "ignore:Trying to infer the `batch_size` from an ambiguous collection:UserWarning", # https://github.com/pytest-dev/pytest/issues/11461 "ignore::pytest.PytestUnraisableExceptionWarning", ] diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index b0918d1140b..e7c12f9822f 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -242,7 +242,7 @@ def training_step( ] loss_dict = self(x, y) train_loss: Tensor = sum(loss_dict.values()) - self.log_dict(loss_dict) + self.log_dict(loss_dict, batch_size=batch_size) return train_loss def validation_step( @@ -267,7 +267,7 @@ def validation_step( # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 metrics.pop("val_classes", None) - self.log_dict(metrics) + self.log_dict(metrics, batch_size=batch_size) if ( batch_idx < 10 @@ -321,7 +321,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 metrics.pop("test_classes", None) - self.log_dict(metrics) + self.log_dict(metrics, batch_size=batch_size) def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0