Skip to content

Commit

Permalink
move floating
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Mar 21, 2022
1 parent a049748 commit bc2ce5a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchmetrics/functional/classification/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def _label_ranking_average_precision_update(
n_preds, n_labels = neg_preds.shape
for i in range(n_preds):
relevant = target[i] == 1
ranking = _rank_data(neg_preds[i][relevant])
ranking = _rank_data(neg_preds[i][relevant]).float()
if len(ranking) > 0 and len(ranking) < n_labels:
rank = _rank_data(neg_preds[i])[relevant]
score_idx = (ranking / rank).float().mean()
rank = _rank_data(neg_preds[i])[relevant].float()
score_idx = (ranking / rank).mean()
else:
score_idx = 1.0

Expand Down

0 comments on commit bc2ce5a

Please sign in to comment.