Skip to content

Commit

Permalink
update test NDCG
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Aug 11, 2023
1 parent 8d99cc6 commit 6c5f332
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions tests/cornac/metrics/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,21 @@ def test_ndcg(self):

self.assertEqual(
1,
ndcg.compute(
gt_pos=np.asarray([0]), gt_neg=np.asarray([]), pd_rank=np.asarray([0])
),
ndcg.compute(gt_pos=np.asarray([0]), pd_rank=np.asarray([0])),
)

gt_pos = np.asarray([0, 2]) # [1, 3]
gt_neg = np.asarray([1]) # [2]
pd_rank = np.asarray([0, 2, 1]) # [1, 3, 2]
self.assertEqual(1, ndcg.compute(gt_pos, gt_neg, pd_rank))
self.assertEqual(1, ndcg.compute(gt_pos, pd_rank))

ndcg_2 = NDCG(k=2)
self.assertEqual(ndcg_2.k, 2)

gt_pos = np.asarray([2]) # [3]
gt_neg = np.asarray([0, 1]) # [1, 2]
pd_rank = np.asarray([1, 2, 0]) # [2, 3, 1]
self.assertEqual(
0.63,
float("{:.2f}".format(ndcg_2.compute(gt_pos, gt_neg, pd_rank))),
float("{:.2f}".format(ndcg_2.compute(gt_pos, pd_rank))),
)

def test_ncrr(self):
Expand Down Expand Up @@ -229,18 +225,18 @@ def test_auc(self):
self.assertEqual(auc.name, "AUC")

item_indices = np.arange(4)
gt_pos = np.array([2, 3]) # [0, 0, 1, 1]
gt_pos = np.array([2, 3]) # [0, 0, 1, 1]
pd_scores = np.array([0.1, 0.4, 0.35, 0.8])
auc_score = auc.compute(item_indices, pd_scores, gt_pos)
self.assertEqual(0.75, auc_score)

item_indices = np.arange(4)
gt_pos = np.array([1, 3]) # [0, 1, 0, 1]
gt_pos = np.array([1, 3]) # [0, 1, 0, 1]
pd_scores = np.array([0.1, 0.4, 0.35, 0.8])
auc_score = auc.compute(item_indices, pd_scores, gt_pos)
self.assertEqual(1.0, auc_score)

gt_pos = np.array([2]) # [0, 0, 1, 0]
gt_pos = np.array([2]) # [0, 0, 1, 0]
gt_neg = np.array([1, 1, 0, 0])
pd_scores = np.array([0.1, 0.4, 0.35, 0.8])
auc_score = auc.compute(item_indices, pd_scores, gt_pos, gt_neg)
Expand All @@ -253,17 +249,17 @@ def test_map(self):
self.assertEqual(mAP.name, "MAP")

item_indices = np.arange(3)
gt_pos = np.array([0]) # [1, 0, 0]
gt_pos = np.array([0]) # [1, 0, 0]
pd_scores = np.array([0.75, 0.5, 1])
self.assertEqual(0.5, mAP.compute(item_indices, pd_scores, gt_pos))

item_indices = np.arange(3)
gt_pos = np.array([2]) # [0, 0, 1]
gt_pos = np.array([2]) # [0, 0, 1]
pd_scores = np.array([1, 0.2, 0.1])
self.assertEqual(1 / 3, mAP.compute(item_indices, pd_scores, gt_pos))

item_indices = np.arange(10)
gt_pos = np.array([1, 3, 5]) # [0, 1, 0, 1, 0, 1, 0, 0, 0, 0]
gt_pos = np.array([1, 3, 5]) # [0, 1, 0, 1, 0, 1, 0, 0, 0, 0]
pd_scores = np.linspace(0.0, 1.0, len(item_indices))[::-1]
self.assertEqual(0.5, mAP.compute(item_indices, pd_scores, gt_pos))

Expand Down

0 comments on commit 6c5f332

Please sign in to comment.