diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 832f4cc4d6a..b7f407ec3ef 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -13,10 +13,9 @@ from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose -from torchgeo.datasets.utils import collate_patches_per_tile -from torchgeo.samplers.utils import _to_tuple - from ..datasets import Potsdam2D +from ..datasets.utils import collate_patches_per_tile +from ..samplers.utils import _to_tuple from .utils import dataset_split diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 0acaa561269..19d85b74508 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -99,6 +99,7 @@ from .usavars import USAVars from .utils import ( BoundingBox, + collate_patches_per_tile, concat_samples, merge_samples, stack_samples, @@ -210,4 +211,5 @@ "merge_samples", "stack_samples", "unbind_samples", + "collate_patches_per_tile", ) diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index ebefe3e7d59..6e6e78046ee 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -229,7 +229,8 @@ def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]: 'train_batch_size' * 'num_patches_per_tile' """ r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] - r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") + if "mask" in r_batch: + r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w") r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w") return r_batch