Skip to content

Commit

Permalink
Fix plotting in datamodules when dataset is a Subset (#2003)
Browse files Browse the repository at this point in the history
* Fix plotting in datamodules when dataset is a Subset

* Guessing at how to fix isort
  • Loading branch information
calebrob6 authored Apr 17, 2024
1 parent 8b24ca1 commit 2f46c73
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lightning.pytorch import LightningDataModule
from matplotlib.figure import Figure
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate
from torch.utils.data import DataLoader, Dataset, Subset, default_collate

from ..datasets import GeoDataset, NonGeoDataset, stack_samples
from ..samplers import (
Expand Down Expand Up @@ -157,6 +157,8 @@ def plot(self, *args: Any, **kwargs: Any) -> Figure | None:
"""
fig: Figure | None = None
dataset = self.dataset or self.val_dataset
if isinstance(dataset, Subset):
dataset = dataset.dataset
if dataset is not None:
if hasattr(dataset, "plot"):
fig = dataset.plot(*args, **kwargs)
Expand Down

0 comments on commit 2f46c73

Please sign in to comment.