Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Nov 11, 2021
1 parent 2268299 commit 6702954
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
8 changes: 4 additions & 4 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def draw_semantic_segmentation_masks(
mask: Tensor,
alpha: float = 0.5,
colors: Optional[ColorMap] = None,
) -> np.ndarray:
) -> np.ndarray: # type: ignore[type-arg]
"""Overlay a semantic segmentation mask onto an image.
Args:
Expand All @@ -444,11 +444,11 @@ def draw_semantic_segmentation_masks(
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
classes = torch.unique(mask)
classes = torch.unique(mask) # type: ignore[attr-defined]
classes = classes[1:]
class_masks = mask == classes[:, None, None]
img = draw_segmentation_masks(
image=image, masks=class_masks, alpha=alpha, colors=colors
)
img = img.permute((1, 2, 0))
return img.numpy()
img = img.permute((1, 2, 0)).numpy()
return img # type: ignore[no-any-return]
10 changes: 8 additions & 2 deletions torchgeo/datasets/xview.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,16 @@ def plot(self, index: int, alpha: float = 0.5) -> plt.Figure:
"""
sample = self[index]
image1 = draw_semantic_segmentation_masks(
sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap
sample["image"][0],
sample["mask"][0],
alpha=alpha,
colors=self.colormap # type: ignore[arg-type]
)
image2 = draw_semantic_segmentation_masks(
sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap
sample["image"][1],
sample["mask"][1],
alpha=alpha,
colors=self.colormap # type: ignore[arg-type]
)
fig, (ax1, ax2) = plt.subplots(ncols=2)
fig.set_size_inches((25, 25))
Expand Down

0 comments on commit 6702954

Please sign in to comment.