Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Jul 24, 2021
1 parent 42d8778 commit e12c0c8
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/classification/test_binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ class TestBinnedRecallAtPrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision):
def test_binned_recall_at_precision(
self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision
):
# rounding will simulate binning for both implementations
preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6

Expand Down Expand Up @@ -112,8 +114,8 @@ class TestBinnedAveragePrecision(MetricTester):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("thresholds", (10, 301, None, torch.linspace(0.0, 1.0, 101)))
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, thresholds):
@pytest.mark.parametrize("thresholds", (301, torch.linspace(0.0, 1.0, 101)))
def test_binned_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, thresholds):
# rounding will simulate binning for both implementations
preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6

Expand Down

0 comments on commit e12c0c8

Please sign in to comment.