Skip to content

Commit

Permalink
Add type hints in ops/assign_score_withk.py (#2023)
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnh2 authored May 29, 2022
1 parent de90c7a commit c70fafe
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions mmcv/ops/assign_score_withk.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader
Expand Down Expand Up @@ -27,11 +30,11 @@ class AssignScoreWithK(Function):

@staticmethod
def forward(ctx,
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
scores: torch.Tensor,
point_features: torch.Tensor,
center_features: torch.Tensor,
knn_idx: torch.Tensor,
aggregate: str = 'sum') -> torch.Tensor:
"""
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
Expand Down Expand Up @@ -78,7 +81,9 @@ def forward(ctx,
return output

@staticmethod
def backward(ctx, grad_out):
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
"""
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
Expand Down

0 comments on commit c70fafe

Please sign in to comment.