diff --git a/torchmetrics/functional/classification/ranking.py b/torchmetrics/functional/classification/ranking.py index e6b67f2fc74..65834726a85 100644 --- a/torchmetrics/functional/classification/ranking.py +++ b/torchmetrics/functional/classification/ranking.py @@ -117,10 +117,10 @@ def _label_ranking_average_precision_update( n_preds, n_labels = neg_preds.shape for i in range(n_preds): relevant = target[i] == 1 - ranking = _rank_data(neg_preds[i][relevant]) + ranking = _rank_data(neg_preds[i][relevant]).float() if len(ranking) > 0 and len(ranking) < n_labels: - rank = _rank_data(neg_preds[i])[relevant] - score_idx = (ranking / rank).float().mean() + rank = _rank_data(neg_preds[i])[relevant].float() + score_idx = (ranking / rank).mean() else: score_idx = 1.0