[Enhance] Add type hint in focal_loss.py.
WINDSKY45 committed May 26, 2022
1 parent a22843c commit b07bff6
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions mmcv/ops/focal_loss.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.autograd import Function
@@ -15,7 +17,8 @@
class SigmoidFocalLossFunction(Function):

@staticmethod
def symbolic(g, input, target, gamma, alpha, weight, reduction):
def symbolic(g, input: torch.Tensor, target: torch.Tensor, gamma: float,
alpha: float, weight: torch.Tensor, reduction: str):
return g.op(
'mmcv::MMCVSigmoidFocalLoss',
input,
@@ -27,12 +30,13 @@ def symbolic(g, input, target, gamma, alpha, weight, reduction):

@staticmethod
def forward(ctx,
input,
target,
gamma=2.0,
alpha=0.25,
weight=None,
reduction='mean'):
input: torch.Tensor,
target: Union[torch.Tensor, torch.LongTensor,
torch.cuda.LongTensor],
gamma: float = 2.0,
alpha: float = 0.25,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean') -> torch.Tensor:

assert isinstance(
target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
@@ -64,7 +68,7 @@ def forward(ctx,

@staticmethod
@once_differentiable
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
input, target, weight = ctx.saved_tensors

grad_input = input.new_zeros(input.size())
@@ -88,14 +92,18 @@ def backward(ctx, grad_output):

class SigmoidFocalLoss(nn.Module):

def __init__(self, gamma, alpha, weight=None, reduction='mean'):
def __init__(self,
gamma: float,
alpha: float,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean') -> None:
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.register_buffer('weight', weight)
self.reduction = reduction

def forward(self, input, target):
def forward(self, input: torch.Tensor, target: torch.Tensor):
return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
self.weight, self.reduction)

@@ -110,7 +118,8 @@ def __repr__(self):
class SoftmaxFocalLossFunction(Function):

@staticmethod
def symbolic(g, input, target, gamma, alpha, weight, reduction):
def symbolic(g, input: torch.Tensor, target: torch.Tensor, gamma: float,
alpha: float, weight: torch.Tensor, reduction: str):
return g.op(
'mmcv::MMCVSoftmaxFocalLoss',
input,
@@ -122,12 +131,13 @@ def symbolic(g, input, target, gamma, alpha, weight, reduction):

@staticmethod
def forward(ctx,
input,
target,
input: torch.Tensor,
target: Union[torch.Tensor, torch.LongTensor,
torch.cuda.LongTensor],
gamma=2.0,
alpha=0.25,
weight=None,
reduction='mean'):
reduction='mean') -> torch.Tensor:

assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
assert input.dim() == 2
@@ -169,7 +179,7 @@ def forward(ctx,
return output

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> tuple:
input_softmax, target, weight = ctx.saved_tensors
buff = input_softmax.new_zeros(input_softmax.size(0))
grad_input = input_softmax.new_zeros(input_softmax.size())
@@ -194,14 +204,18 @@ def backward(ctx, grad_output):

class SoftmaxFocalLoss(nn.Module):

def __init__(self, gamma, alpha, weight=None, reduction='mean'):
def __init__(self,
gamma: float,
alpha: float,
weight: Optional[torch.Tensor] = None,
reduction: str = 'mean') -> None:
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.register_buffer('weight', weight)
self.reduction = reduction

def forward(self, input, target):
def forward(self, input: torch.Tensor, target: torch.Tensor):
return softmax_focal_loss(input, target, self.gamma, self.alpha,
self.weight, self.reduction)

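For reference, a minimal usage sketch consistent with the annotated SigmoidFocalLoss signatures above. The shapes, example values, and GPU placement are assumptions, not part of the commit; a CUDA build is assumed since the focal-loss kernels are compiled ops.

import torch
from mmcv.ops import SigmoidFocalLoss, sigmoid_focal_loss

if torch.cuda.is_available():  # assumed: op is run on GPU tensors
    pred = torch.randn(8, 4, device='cuda', requires_grad=True)  # (N, num_classes) logits, torch.Tensor
    target = torch.randint(0, 4, (8,), device='cuda')            # class indices, a cuda LongTensor

    # Functional form: positional args matching forward(input, target, gamma, alpha, weight, reduction)
    loss = sigmoid_focal_loss(pred, target, 2.0, 0.25)

    # Module form: gamma/alpha are floats, weight an Optional[torch.Tensor], reduction a str
    criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25, reduction='mean')
    criterion(pred, target).backward()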

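A similar sketch for the softmax variant, this time also passing the Optional[torch.Tensor] weight argument. The per-class weight length (num_classes) and the CUDA device are assumptions; per the asserts in forward, target must be a (cuda.)LongTensor of class indices and input must be 2-D logits.

import torch
from mmcv.ops import SoftmaxFocalLoss, softmax_focal_loss

if torch.cuda.is_available():  # assumed: compiled op, GPU tensors
    logits = torch.randn(8, 4, device='cuda', requires_grad=True)  # 2-D input, per assert input.dim() == 2
    labels = torch.randint(0, 4, (8,), device='cuda')              # LongTensor of class indices

    # Functional form with default weight/reduction, positional args
    loss = softmax_focal_loss(logits, labels, 2.0, 0.25)

    # Module form: an assumed per-class weight vector of length num_classes
    class_weight = torch.ones(4, device='cuda')
    criterion = SoftmaxFocalLoss(gamma=2.0, alpha=0.25, weight=class_weight, reduction='sum')
    criterion(logits, labels).backward()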