Skip to content

Commit

Permalink
add reference for the best practice of batch dot product
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Jan 3, 2024
1 parent cb35327 commit c96657c
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions libmultilabel/nn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,18 @@ def update(self, preds, target):
self.num_sample += preds.shape[0]

def compute(self):
score = self.score / self.num_sample
return score
return self.score / self.num_sample

def _dcg(self, preds, target, discount):
_, sorted_top_k_idx = torch.topk(preds, k=self.top_k)
gains = target.take_along_dim(sorted_top_k_idx, dim=1)
dcg = (gains * discount).sum(dim=1)
return dcg
# best practice for batch dot product: https://discuss.pytorch.org/t/dot-product-batch-wise/9746/11
return (gains * discount).sum(dim=1)

def _idcg(self, target, discount):
"""optimized idcg for multilabel classification"""
cum_discount = discount.cumsum(dim=0)
idx = target.sum(dim=1).clamp(max=self.top_k) - 1
idcg = cum_discount[idx]
return idcg
return cum_discount[idx]


class RPrecision(Metric):
Expand Down

0 comments on commit c96657c

Please sign in to comment.