From ebb425200ec4f5bb27de076561b8d712b61966e6 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Wed, 6 Jul 2022 12:00:21 -0500 Subject: [PATCH] fix mypy errors --- torchgeo/datamodules/pastis.py | 4 ++-- torchgeo/datasets/pastis.py | 44 +++++++++++++++++----------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/torchgeo/datamodules/pastis.py b/torchgeo/datamodules/pastis.py index d3cdf80e791..e084168ac0e 100644 --- a/torchgeo/datamodules/pastis.py +++ b/torchgeo/datamodules/pastis.py @@ -34,7 +34,7 @@ class PASTISDataModule(pl.LightningDataModule, abc.ABC): """LightningDataModule implementation for the PASTIS dataset.""" # (S1A, S1D, S2) - band_means = torch.tensor( # type: ignore[attr-defined] + band_means = torch.tensor( [ -10.930951118469238, -17.348514556884766, @@ -54,7 +54,7 @@ class PASTISDataModule(pl.LightningDataModule, abc.ABC): 1639.370361328125, ] ) - band_stds = torch.tensor( # type: ignore[attr-defined] + band_stds = torch.tensor( [ 3.285966396331787, 3.2129523754119873, diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 86188dad35b..148d495aef8 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -6,7 +6,7 @@ import abc import glob import os -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -170,7 +170,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, index: str) -> Tensor: + def _load_image(self, index: int) -> Tensor: """Load a single time-series. Args: @@ -187,10 +187,10 @@ def _load_image(self, index: str) -> Tensor: path = self.files[index][self.bands] array = np.load(path) - tensor = torch.from_numpy(array) # type: ignore[attr-defined] + tensor = torch.from_numpy(array) return tensor - def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _load_target(self, index: int) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]: """Load the target mask for a single image. Args: @@ -259,8 +259,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: Returns: data and label at that index """ - image = self.load_image(index) - mask = self.load_target(index) + image = self._load_image(index) + mask = self._load_target(index) sample = {"image": image, "mask": mask} if self.transforms is not None: @@ -268,7 +268,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: return sample - def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def _load_target(self, index: int) -> Tensor: """Load the target mask for a single image. Args: @@ -278,7 +278,7 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: the target mask """ array = np.load(self.files[index]["semantic"]) - tensor = torch.from_numpy(array) # type: ignore[attr-defined] + tensor = torch.from_numpy(array) return tensor @@ -294,8 +294,8 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: Returns: data and label at that index """ - image = self.load_image(index) - mask, boxes, labels = self.load_target(index) + image = self._load_image(index) + mask, boxes, labels = self._load_target(index) sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels} if self.transforms is not None: @@ -315,15 +315,15 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]: mask_array = np.load(self.files[index]["semantic"]) instance_array = np.load(self.files[index]["instance"]) - mask_tensor = torch.from_numpy(mask_array) # type: ignore[attr-defined] - instance_tensor = torch.from_numpy(instance_array) # type: ignore[attr-defined] + mask_tensor = torch.from_numpy(mask_array) + instance_tensor = torch.from_numpy(instance_array) # Convert from HxWxC to CxHxW mask_tensor = mask_tensor.permute((2, 0, 1)) instance_tensor = instance_tensor.permute((2, 0, 1)) # Convert instance mask of N instances to N binary instance masks - instance_ids = torch.unique(instance_tensor) # type: ignore[attr-defined] + instance_ids = torch.unique(instance_tensor) # Exclude a mask for unknown/background instance_ids = instance_ids[instance_ids != 0] instance_ids = instance_ids[:, None, None] @@ -333,21 +333,21 @@ def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]: labels_list = [] for mask in masks: label = mask_tensor[mask[None, :, :]] - label = torch.unique(label)[0] # type: ignore[attr-defined] + label = torch.unique(label)[0] labels_list.append(label) # Get bounding boxes for each instance boxes_list = [] for mask in masks: - pos = torch.where(mask) # type: ignore[attr-defined] - xmin = torch.min(pos[1]) # type: ignore[attr-defined] - xmax = torch.max(pos[1]) # type: ignore[attr-defined] - ymin = torch.min(pos[0]) # type: ignore[attr-defined] - ymax = torch.max(pos[0]) # type: ignore[attr-defined] + pos = torch.where(mask) + xmin = torch.min(pos[1]) + xmax = torch.max(pos[1]) + ymin = torch.min(pos[0]) + ymax = torch.max(pos[0]) boxes_list.append([xmin, ymin, xmax, ymax]) - masks = masks.to(torch.uint8) # type: ignore[attr-defined] - boxes = torch.tensor(boxes_list).to(torch.float) # type: ignore[attr-defined] - labels = torch.tensor(labels_list).to(torch.long) # type: ignore[attr-defined] + masks = masks.to(torch.uint8) + boxes = torch.tensor(boxes_list).to(torch.float) + labels = torch.tensor(labels_list).to(torch.long) return masks, boxes, labels