Skip to content

Commit

Permalink
suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh authored and adamjstewart committed Dec 30, 2022
1 parent dd30f0a commit a5a007c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
5 changes: 2 additions & 3 deletions torchgeo/datamodules/potsdam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

from torchgeo.datasets.utils import collate_patches_per_tile
from torchgeo.samplers.utils import _to_tuple

from ..datasets import Potsdam2D
from ..datasets.utils import collate_patches_per_tile
from ..samplers.utils import _to_tuple
from .utils import dataset_split


Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from .usavars import USAVars
from .utils import (
BoundingBox,
collate_patches_per_tile,
concat_samples,
merge_samples,
stack_samples,
Expand Down Expand Up @@ -210,4 +211,5 @@
"merge_samples",
"stack_samples",
"unbind_samples",
"collate_patches_per_tile",
)
3 changes: 2 additions & 1 deletion torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def collate_patches_per_tile(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
'train_batch_size' * 'num_patches_per_tile'
"""
r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call]
r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w")
if "mask" in r_batch:
r_batch["image"] = rearrange(r_batch["image"], "b t c h w -> (b t) c h w")
r_batch["mask"] = rearrange(r_batch["mask"], "b t h w -> (b t) h w")
return r_batch

Expand Down

0 comments on commit a5a007c

Please sign in to comment.