Skip to content

Commit

Permalink
Fixed the missing type hint.
Browse files Browse the repository at this point in the history
  • Loading branch information
WINDSKY45 committed May 30, 2022
1 parent 245a3de commit 1cf5dde
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mmcv/ops/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __init__(self,
self.register_buffer('weight', weight)
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.LongTensor):
def forward(self, input: torch.Tensor,
target: torch.LongTensor) -> torch.Tensor:
return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
self.weight, self.reduction)

Expand Down Expand Up @@ -215,7 +216,8 @@ def __init__(self,
self.register_buffer('weight', weight)
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.LongTensor):
def forward(self, input: torch.Tensor,
target: torch.LongTensor) -> torch.Tensor:
return softmax_focal_loss(input, target, self.gamma, self.alpha,
self.weight, self.reduction)

Expand Down

0 comments on commit 1cf5dde

Please sign in to comment.