4 changes: 0 additions & 4 deletions mypy.ini
@@ -17,10 +17,6 @@ ignore_errors = True
 
 ignore_errors=True
 
-[mypy-torchvision.models.detection.anchor_utils]
-
-ignore_errors = True
-
 [mypy-torchvision.models.detection.backbone_utils]
 
 ignore_errors = True
33 changes: 17 additions & 16 deletions torchvision/models/detection/anchor_utils.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn, Tensor
@@ -34,16 +34,17 @@ class AnchorGenerator(nn.Module):
 
     def __init__(
         self,
-        sizes=((128, 256, 512),),
-        aspect_ratios=((0.5, 1.0, 2.0),),
-    ):
+        sizes: Tuple[Tuple[int, ...]] = ((128, 256, 512),),
+        aspect_ratios: Tuple[Tuple[float, ...]] = ((0.5, 1.0, 2.0),),
+    ) -> None:
 
         super(AnchorGenerator, self).__init__()
 
         if not isinstance(sizes[0], (list, tuple)):
             # TODO change this
-            sizes = tuple((s,) for s in sizes)
+            sizes = tuple((s,) for s in sizes)  # type: ignore[assignment]
Contributor Author:

Can't do much here: at runtime we check whether sizes is a flat tuple and rewrap it, but mypy only sees the declared nested annotation, so the reassignment confuses it.
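As an illustration, a minimal standalone sketch of the same pattern (the normalize_sizes helper is hypothetical, not from this PR): mypy types the variable by its declared nested annotation, so the runtime rewrap looks like an incompatible assignment.

    from typing import Tuple

    def normalize_sizes(
        sizes: Tuple[Tuple[int, ...], ...] = ((128, 256, 512),),
    ) -> Tuple[Tuple[int, ...], ...]:
        # At runtime a caller may pass a flat tuple of ints, e.g. (128, 256, 512).
        # mypy, however, only knows the declared nested type, so it reports an
        # "Incompatible types in assignment" [assignment] error on the rewrap
        # below unless we suppress it.
        if not isinstance(sizes[0], (list, tuple)):
            sizes = tuple((s,) for s in sizes)  # type: ignore[assignment]
        return sizes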

         if not isinstance(aspect_ratios[0], (list, tuple)):
-            aspect_ratios = (aspect_ratios,) * len(sizes)
+            aspect_ratios = (aspect_ratios,) * len(sizes)  # type: ignore[assignment]
 
         assert len(sizes) == len(aspect_ratios)

@@ -59,11 +60,11 @@ def __init__(
     # This method assumes aspect ratio = height / width for an anchor.
     def generate_anchors(
         self,
-        scales: List[int],
-        aspect_ratios: List[float],
+        scales: Tuple[int, ...],
+        aspect_ratios: Tuple[float, ...],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ):
+    ) -> Tensor:
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
         w_ratios = 1 / h_ratios
 
         ws = (w_ratios[:, None] * scales[None, :]).view(-1)
         hs = (h_ratios[:, None] * scales[None, :]).view(-1)
 
         base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
         return base_anchors.round()

-    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device) -> None:
         self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
 
-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
         return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
 
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
@@ -116,7 +117,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]):
         return anchors
 
     def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
-        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        grid_sizes = [list(feature_map.shape[-2:]) for feature_map in feature_maps]
Contributor Author:

.shape[-2:] gives us a torch.Size object and we need List[int].

Contributor:

Similarly, the debugger shows:

grid_size = [torch.Size([200, 200]), torch.Size([100, 100]), torch.Size([50, 50]), torch.Size([25, 25]), torch.Size([13, 13])]
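To make the mismatch concrete, a small illustration (tensor shape chosen arbitrarily): torch.Size is a tuple subclass, so it never satisfies a List[int] annotation, while list(...) produces exactly that.

    import torch

    feature_map = torch.zeros(1, 256, 25, 25)  # arbitrary NCHW tensor

    size_obj = feature_map.shape[-2:]    # torch.Size([25, 25]): a tuple subclass
    grid = list(feature_map.shape[-2:])  # [25, 25]: a plain List[int]

    assert isinstance(size_obj, tuple) and not isinstance(size_obj, list)
    assert grid == [25, 25]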

         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         strides = [
 
@@ -162,7 +163,7 @@ def __init__(
         scales: Optional[List[float]] = None,
         steps: Optional[List[int]] = None,
         clip: bool = True,
-    ):
+    ) -> None:
         super().__init__()
         if steps is not None:
             assert len(aspect_ratios) == len(steps)
 
@@ -204,7 +205,7 @@ def _generate_wh_pairs(
         _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
         return _wh_pairs

-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
         # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
         return [2 + 2 * len(r) for r in self.aspect_ratios]

@@ -247,8 +248,8 @@ def __repr__(self) -> str:
         return s.format(**self.__dict__)
 
     def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
-        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
-        image_size = image_list.tensors.shape[-2:]
+        grid_sizes = [list(feature_map.shape[-2:]) for feature_map in feature_maps]
+        image_size = list(image_list.tensors.shape[-2:])
Contributor Author:

Same as above: we need to convert the torch.Size object to List[int].

Contributor:

The debugger shows that:

grid_sizes = [torch.Size([38, 38]), torch.Size([19, 19]), torch.Size([10, 10]), torch.Size([5, 5]), torch.Size([3, 3]), torch.Size([1, 1])]
image_size = torch.Size([300, 300])

Contributor Author:

Yes, but we have annotated grid_sizes as List[List[int]] and image_size as List[int], so the conversion above was a workaround.

I'm not sure: should we change the original annotations of grid_sizes and image_size?

Sorry, I'm not very clear on what should be done. 😕
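One possible answer, sketched under the assumption that only the annotations change: because torch.Size subclasses tuple, annotating these parameters as Tuple[int, ...] would let .shape[-2:] pass the type check directly, with no list() conversion (use_size is a hypothetical function, not this PR's API).

    from typing import Tuple

    import torch

    def use_size(image_size: Tuple[int, ...]) -> int:
        # torch.Size is declared as a subclass of Tuple[int, ...] in the torch
        # stubs, so passing .shape[-2:] here type-checks without list().
        h, w = image_size
        return h * w

    images = torch.zeros(2, 3, 300, 300)
    print(use_size(images.shape[-2:]))  # 90000

The PR instead keeps the List[int] annotations and converts at the call sites, which matches the workaround discussed above and leaves the existing signatures untouched.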

         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
         default_boxes = default_boxes.to(device)