Skip to content

Type annotations for torchvision.ops #2331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions torchvision/ops/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import torch
from torch import Tensor
from torch.jit.annotations import List
from torch.jit.annotations import List, Tuple


def _cat(tensors, dim=0):
# type: (List[Tensor], int) -> Tensor
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in a list
"""
Expand All @@ -15,8 +14,7 @@ def _cat(tensors, dim=0):
return torch.cat(tensors, dim)


def convert_boxes_to_roi_format(boxes):
# type: (List[Tensor]) -> Tensor
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
concat_boxes = _cat([b for b in boxes], dim=0)
temp = []
for i, b in enumerate(boxes):
Expand All @@ -26,7 +24,7 @@ def convert_boxes_to_roi_format(boxes):
return rois


def check_roi_boxes_shape(boxes):
def check_roi_boxes_shape(boxes: Tensor):
if isinstance(boxes, (list, tuple)):
for _tensor in boxes:
assert _tensor.size(1) == 4, \
Expand Down
21 changes: 11 additions & 10 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import torchvision


def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float) -> Tensor
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
Expand Down Expand Up @@ -41,8 +40,12 @@ def nms(boxes, scores, iou_threshold):


@torch.jit._script_if_tracing
def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
def batched_nms(
boxes: Tensor,
scores: Tensor,
idxs: Tensor,
iou_threshold: float,
) -> Tensor:
"""
Performs non-maximum suppression in a batched fashion.

Expand Down Expand Up @@ -83,8 +86,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
return keep


def remove_small_boxes(boxes, min_size):
# type: (Tensor, float) -> Tensor
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
Remove boxes which contains at least one side smaller than min_size.

Expand All @@ -102,8 +104,7 @@ def remove_small_boxes(boxes, min_size):
return keep


def clip_boxes_to_image(boxes, size):
# type: (Tensor, Tuple[int, int]) -> Tensor
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
"""
Clip boxes so that they lie inside an image of size `size`.

Expand Down Expand Up @@ -132,7 +133,7 @@ def clip_boxes_to_image(boxes, size):
return clipped_boxes.reshape(boxes.shape)


def box_area(boxes):
def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
(x1, y1, x2, y2) coordinates.
Expand All @@ -149,7 +150,7 @@ def box_area(boxes):

# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Return intersection-over-union (Jaccard index) of boxes.

Expand Down
30 changes: 23 additions & 7 deletions torchvision/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
from torch.jit.annotations import Optional, Tuple


def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
def deform_conv2d(
input: Tensor,
offset: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
) -> Tensor:
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks

Expand Down Expand Up @@ -80,8 +87,17 @@ class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=True):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
):
super(DeformConv2d, self).__init__()

if in_channels % groups != 0:
Expand All @@ -107,14 +123,14 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,

self.reset_parameters()

def reset_parameters(self) -> None:
    """
    Re-initialise ``self.weight`` and, when present, ``self.bias`` in place.

    The weight is filled with ``init.kaiming_uniform_`` using
    ``a=sqrt(5)``. If ``self.bias`` is not ``None`` it is drawn uniformly
    from ``[-1 / sqrt(fan_in), 1 / sqrt(fan_in)]``, where ``fan_in`` is
    computed from the weight tensor.
    """
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is None:
        return

    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
    # A weight with an empty input dimension has fan_in == 0; dividing by
    # sqrt(0) would raise ZeroDivisionError, so leave the bias untouched.
    if fan_in == 0:
        return

    bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)

def forward(self, input, offset):
def forward(self, input: Tensor, offset: Tensor) -> Tensor:
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
Expand All @@ -125,7 +141,7 @@ def forward(self, input, offset):
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)

def __repr__(self):
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += '{in_channels}'
s += ', {out_channels}'
Expand Down
78 changes: 47 additions & 31 deletions torchvision/ops/feature_pyramid_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,31 @@
import torch.nn.functional as F
from torch import nn, Tensor

from torch.jit.annotations import Tuple, List, Dict
from torch.jit.annotations import Tuple, List, Dict, Optional


class ExtraFPNBlock(nn.Module):
    """
    Base class for the extra block in the FPN.

    Arguments:
        results (List[Tensor]): the result of the FPN
        x (List[Tensor]): the original feature maps
        names (List[str]): the names for each one of the
            original feature maps

    Returns:
        results (List[Tensor]): the extended set of results
            of the FPN
        names (List[str]): the extended set of names for the results
    """
    def forward(
        self,
        results: List[Tensor],
        x: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        """
        Extend the FPN outputs with additional feature maps.

        Arguments:
            results (List[Tensor]): the result of the FPN
            x (List[Tensor]): the original feature maps
            names (List[str]): the names for each one of the
                original feature maps

        Returns:
            results (List[Tensor]): the extended set of results
                of the FPN
            names (List[str]): the extended set of names for the results

        The base implementation is only a placeholder: it does nothing and
        returns ``None``. Concrete blocks override this method.
        """
        pass


class FeaturePyramidNetwork(nn.Module):
Expand Down Expand Up @@ -44,7 +68,12 @@ class FeaturePyramidNetwork(nn.Module):
>>> ('feat3', torch.Size([1, 5, 8, 8]))]

"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
):
super(FeaturePyramidNetwork, self).__init__()
self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList()
Expand All @@ -66,8 +95,7 @@ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
assert isinstance(extra_blocks, ExtraFPNBlock)
self.extra_blocks = extra_blocks

def get_result_from_inner_blocks(self, x, idx):
# type: (Tensor, int) -> Tensor
def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
Expand All @@ -85,8 +113,7 @@ def get_result_from_inner_blocks(self, x, idx):
i += 1
return out

def get_result_from_layer_blocks(self, x, idx):
# type: (Tensor, int) -> Tensor
def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
Expand All @@ -104,8 +131,7 @@ def get_result_from_layer_blocks(self, x, idx):
i += 1
return out

def forward(self, x):
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.

Expand Down Expand Up @@ -140,31 +166,16 @@ def forward(self, x):
return out


class ExtraFPNBlock(nn.Module):
"""
Base class for the extra block in the FPN.

Arguments:
results (List[Tensor]): the result of the FPN
x (List[Tensor]): the original feature maps
names (List[str]): the names for each one of the
original feature maps

Returns:
results (List[Tensor]): the extended set of results
of the FPN
names (List[str]): the extended set of names for the results
"""
def forward(self, results, x, names):
pass


class LastLevelMaxPool(ExtraFPNBlock):
    """
    Applies a max_pool2d on top of the last feature map
    """
    def forward(
        self,
        x: List[Tensor],
        y: List[Tensor],
        names: List[str],
    ) -> Tuple[List[Tensor], List[str]]:
        """
        Append one extra, max-pooled level to the FPN outputs.

        Arguments:
            x (List[Tensor]): the feature maps produced so far; the last
                entry is pooled and the result is appended to this list
            y (List[Tensor]): accepted for interface compatibility with
                ExtraFPNBlock; not used by this block
            names (List[str]): the names matching ``x``; ``"pool"`` is
                appended to this list

        Returns:
            (x, names): the same two list objects that were passed in,
                each one element longer

        Note that both ``x`` and ``names`` are modified in place.
        """
        names.append("pool")
        # kernel_size=1 with stride=2 keeps every second element along the
        # pooled dimensions of the last feature map.
        x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
        return x, names
Expand All @@ -174,7 +185,7 @@ class LastLevelP6P7(ExtraFPNBlock):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
Expand All @@ -183,7 +194,12 @@ def __init__(self, in_channels, out_channels):
nn.init.constant_(module.bias, 0)
self.use_P5 = in_channels == out_channels

def forward(self, p, c, names):
def forward(
self,
p: List[Tensor],
c: List[Tensor],
names: List[str],
) -> Tuple[List[Tensor], List[str]]:
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
Expand Down
25 changes: 20 additions & 5 deletions torchvision/ops/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import warnings
import torch
from torch import Tensor, Size
from torch.jit.annotations import List, Optional, Tuple


class Conv2d(torch.nn.Conv2d):
Expand Down Expand Up @@ -46,7 +48,12 @@ class FrozenBatchNorm2d(torch.nn.Module):
are fixed
"""

def __init__(self, num_features, eps=0., n=None):
def __init__(
self,
num_features: Tuple[int, ...],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would have expected num_features to be an int, not a tuple with a varying number of elements.

eps: float = 0.,
n: Optional[Tuple[int, ...]] = None,
):
# n=None for backward-compatibility
if n is not None:
warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
Expand All @@ -59,8 +66,16 @@ def __init__(self, num_features, eps=0., n=None):
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
Expand All @@ -69,7 +84,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
Expand All @@ -80,5 +95,5 @@ def forward(self, x):
bias = b - rm * scale
return x * scale + bias

def __repr__(self) -> str:
    """Return ``ClassName(N)`` where N is the length of ``self.weight``'s first dimension."""
    return f"{self.__class__.__name__}({self.weight.shape[0]})"
3 changes: 1 addition & 2 deletions torchvision/ops/new_empty_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from torch import Tensor


def _new_empty_tensor(x, shape):
# type: (Tensor, List[int]) -> Tensor
def _new_empty_tensor(x: Tensor, shape: List[int]) -> Tensor:
"""
Arguments:
input (Tensor): input tensor
Expand Down
Loading