Skip to content

Commit

Permalink
Merge pull request ASUS-AICS#351 from ntumlgroup/fix-recall
Browse files Browse the repository at this point in the history
Fixed recall for no label instances
  • Loading branch information
Eleven1Liu authored Jan 18, 2024
2 parents 55f0f49 + 46fd233 commit d1ef6b9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions libmultilabel/linear/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
dcg = _DCG_argsort(argsort_preds, target, self.top_k)
idcg = _IDCG(target, self.top_k)
ndcg_score = dcg / idcg
# by convention, ndcg is 0 for zero label instances
self.score += np.nan_to_num(ndcg_score, nan=0.0).sum()
self.num_sample += argsort_preds.shape[0]

Expand Down Expand Up @@ -95,6 +96,7 @@ def update(self, preds: np.ndarray, target: np.ndarray):
def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
top_k_idx = argsort_preds[:, -self.top_k :]
num_relevant = np.take_along_axis(target, top_k_idx, axis=-1).sum(axis=-1) # (batch_size, )
# by convention, rprecision is 0 for zero label instances
self.score += np.nan_to_num(num_relevant / np.minimum(self.top_k, target.sum(axis=-1)), nan=0.0).sum()
self.num_sample += argsort_preds.shape[0]

Expand Down Expand Up @@ -167,7 +169,8 @@ def update(self, preds: np.ndarray, target: np.ndarray):
def update_argsort(self, argsort_preds: np.ndarray, target: np.ndarray):
top_k_idx = argsort_preds[:, -self.top_k :]
num_relevant = np.take_along_axis(target, top_k_idx, -1).sum(axis=-1)
self.score += np.nan_to_num(num_relevant / target.sum(axis=-1), nan=1.0).sum()
# by convention, recall is 0 for zero label instances
self.score += np.nan_to_num(num_relevant / target.sum(axis=-1), nan=0.0).sum()
self.num_sample += argsort_preds.shape[0]

def compute(self) -> float:
Expand Down Expand Up @@ -210,14 +213,15 @@ def update(self, preds: np.ndarray, target: np.ndarray):
def compute(self) -> float:
prev_settings = np.seterr("ignore")

# F1 is 0 for the cases where there are no positive instances
if self.average == "macro":
score = np.nansum(2 * self.tp / (2 * self.tp + self.fp + self.fn)) / self.num_classes
elif self.average == "micro":
score = np.nan_to_num(2 * np.sum(self.tp) / np.sum(2 * self.tp + self.fp + self.fn))
score = np.nan_to_num(2 * np.sum(self.tp) / np.sum(2 * self.tp + self.fp + self.fn), nan=0.0)
elif self.average == "another-macro":
macro_prec = np.nansum(self.tp / (self.tp + self.fp)) / self.num_classes
macro_recall = np.nansum(self.tp / (self.tp + self.fn)) / self.num_classes
score = np.nan_to_num(2 * macro_prec * macro_recall / (macro_prec + macro_recall))
score = np.nan_to_num(2 * macro_prec * macro_recall / (macro_prec + macro_recall), nan=0.0)

np.seterr(**prev_settings)
return score
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = libmultilabel
version = 0.5.0
version = 0.5.1
author = LibMultiLabel Team
license = MIT License
license_file = LICENSE
Expand Down

0 comments on commit d1ef6b9

Please sign in to comment.