Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Apr 21, 2023
1 parent e9935c3 commit ebb4252
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
4 changes: 2 additions & 2 deletions torchgeo/datamodules/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PASTISDataModule(pl.LightningDataModule, abc.ABC):
"""LightningDataModule implementation for the PASTIS dataset."""

# (S1A, S1D, S2)
band_means = torch.tensor( # type: ignore[attr-defined]
band_means = torch.tensor(
[
-10.930951118469238,
-17.348514556884766,
Expand All @@ -54,7 +54,7 @@ class PASTISDataModule(pl.LightningDataModule, abc.ABC):
1639.370361328125,
]
)
band_stds = torch.tensor( # type: ignore[attr-defined]
band_stds = torch.tensor(
[
3.285966396331787,
3.2129523754119873,
Expand Down
44 changes: 22 additions & 22 deletions torchgeo/datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import abc
import glob
import os
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -170,7 +170,7 @@ def __len__(self) -> int:
"""
return len(self.files)

def _load_image(self, index: str) -> Tensor:
def _load_image(self, index: int) -> Tensor:
"""Load a single time-series.
Args:
Expand All @@ -187,10 +187,10 @@ def _load_image(self, index: str) -> Tensor:
path = self.files[index][self.bands]
array = np.load(path)

tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.from_numpy(array)
return tensor

def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def _load_target(self, index: int) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
"""Load the target mask for a single image.
Args:
Expand Down Expand Up @@ -259,16 +259,16 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
Returns:
data and label at that index
"""
image = self.load_image(index)
mask = self.load_target(index)
image = self._load_image(index)
mask = self._load_target(index)
sample = {"image": image, "mask": mask}

if self.transforms is not None:
sample = self.transforms(sample)

return sample

def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def _load_target(self, index: int) -> Tensor:
"""Load the target mask for a single image.
Args:
Expand All @@ -278,7 +278,7 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
the target mask
"""
array = np.load(self.files[index]["semantic"])
tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.from_numpy(array)
return tensor


Expand All @@ -294,8 +294,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]:
Returns:
data and label at that index
"""
image = self.load_image(index)
mask, boxes, labels = self.load_target(index)
image = self._load_image(index)
mask, boxes, labels = self._load_target(index)
sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels}

if self.transforms is not None:
Expand All @@ -315,15 +315,15 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
mask_array = np.load(self.files[index]["semantic"])
instance_array = np.load(self.files[index]["instance"])

mask_tensor = torch.from_numpy(mask_array) # type: ignore[attr-defined]
instance_tensor = torch.from_numpy(instance_array) # type: ignore[attr-defined]
mask_tensor = torch.from_numpy(mask_array)
instance_tensor = torch.from_numpy(instance_array)

# Convert from HxWxC to CxHxW
mask_tensor = mask_tensor.permute((2, 0, 1))
instance_tensor = instance_tensor.permute((2, 0, 1))

# Convert instance mask of N instances to N binary instance masks
instance_ids = torch.unique(instance_tensor) # type: ignore[attr-defined]
instance_ids = torch.unique(instance_tensor)
# Exclude a mask for unknown/background
instance_ids = instance_ids[instance_ids != 0]
instance_ids = instance_ids[:, None, None]
Expand All @@ -333,21 +333,21 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
labels_list = []
for mask in masks:
label = mask_tensor[mask[None, :, :]]
label = torch.unique(label)[0] # type: ignore[attr-defined]
label = torch.unique(label)[0]
labels_list.append(label)

# Get bounding boxes for each instance
boxes_list = []
for mask in masks:
pos = torch.where(mask) # type: ignore[attr-defined]
xmin = torch.min(pos[1]) # type: ignore[attr-defined]
xmax = torch.max(pos[1]) # type: ignore[attr-defined]
ymin = torch.min(pos[0]) # type: ignore[attr-defined]
ymax = torch.max(pos[0]) # type: ignore[attr-defined]
pos = torch.where(mask)
xmin = torch.min(pos[1])
xmax = torch.max(pos[1])
ymin = torch.min(pos[0])
ymax = torch.max(pos[0])
boxes_list.append([xmin, ymin, xmax, ymax])

masks = masks.to(torch.uint8) # type: ignore[attr-defined]
boxes = torch.tensor(boxes_list).to(torch.float) # type: ignore[attr-defined]
labels = torch.tensor(labels_list).to(torch.long) # type: ignore[attr-defined]
masks = masks.to(torch.uint8)
boxes = torch.tensor(boxes_list).to(torch.float)
labels = torch.tensor(labels_list).to(torch.long)

return masks, boxes, labels

0 comments on commit ebb4252

Please sign in to comment.