Skip to content

Commit

Permalink
Removed some unnecessary type hint.
Browse files Browse the repository at this point in the history
  • Loading branch information
WINDSKY45 committed May 31, 2022
1 parent 1cf5dde commit 52111dc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mmcv/ops/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self,
gamma: float,
alpha: float,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean') -> None:
reduction: str = 'mean'):
super().__init__()
self.gamma = gamma
self.alpha = alpha
Expand Down Expand Up @@ -209,7 +209,7 @@ def __init__(self,
gamma: float,
alpha: float,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean') -> None:
reduction: str = 'mean'):
super().__init__()
self.gamma = gamma
self.alpha = alpha
Expand Down
4 changes: 3 additions & 1 deletion mmcv/ops/knn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
from torch.autograd import Function

Expand All @@ -19,7 +21,7 @@ class KNN(Function):
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
center_xyz: Optional[torch.Tensor] = None,
transposed: bool = False) -> torch.Tensor:
"""
Args:
Expand Down
12 changes: 10 additions & 2 deletions mmcv/ops/masked_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class MaskedConv2dFunction(Function):

@staticmethod
def symbolic(g, features, mask, weight, bias, padding, stride):
def symbolic(g, features: torch.Tensor, mask: torch.Tensor,
weight: torch.Tensor, bias: torch.Tensor, padding: int,
stride: int):
return g.op(
'mmcv::MMCVMaskedConv2d',
features,
Expand All @@ -27,7 +29,13 @@ def symbolic(g, features, mask, weight, bias, padding, stride):
stride_i=stride)

@staticmethod
def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
def forward(ctx,
features: torch.Tensor,
mask: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
padding: int = 0,
stride: int = 1) -> torch.Tensor:
assert mask.dim() == 3 and mask.size(0) == 1
assert features.dim() == 4 and features.size(0) == 1
assert features.size()[2:] == mask.size()[1:]
Expand Down

0 comments on commit 52111dc

Please sign in to comment.