skip non working cpu + half tests
SkafteNicki committed on Aug 20, 2022 · 1 parent 961cf9f · commit ab1d4ed
Showing 2 changed files with 10 additions and 10 deletions.
tests/unittests/classification/test_average_precision.py (6 changes: 3 additions & 3 deletions)

@@ -99,7 +99,7 @@ def test_binary_average_precision_differentiability(self, input):
     def test_binary_average_precision_dtype_cpu(self, input, dtype):
         preds, target = input
         if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
-            pytest.xfail(reason="torch.flip not support before pytorch v1.8")
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if (preds < 0).any() and dtype == torch.half:
             pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
         self.run_precision_test_cpu(

@@ -214,7 +214,7 @@ def test_multiclass_average_precision_differentiability(self, input):
     def test_multiclass_average_precision_dtype_cpu(self, input, dtype):
         preds, target = input
         if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
-            pytest.xfail(reason="torch.flip not support before pytorch v1.8")
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
             pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
         self.run_precision_test_cpu(

@@ -326,7 +326,7 @@ def test_multiclass_average_precision_differentiability(self, input):
     def test_multilabel_average_precision_dtype_cpu(self, input, dtype):
         preds, target = input
         if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
-            pytest.xfail(reason="torch.flip not support before pytorch v1.8")
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
            pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
         self.run_precision_test_cpu(
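All three hunks in this file apply the same guard: with half-precision inputs on CPU, torch.flip is not available before PyTorch 1.8 (per the xfail reason), so the dtype test is marked as an expected failure rather than crashing. Below is a minimal, self-contained sketch of that pattern; it is illustrative only, the test name and tensor are made up, and the version flag is derived here from torch.__version__, whereas the real tests import _TORCH_GREATER_EQUAL_1_8 from torchmetrics.utilities.imports.

import pytest
import torch
from packaging.version import Version

# Stand-in for torchmetrics.utilities.imports._TORCH_GREATER_EQUAL_1_8 (assumption for this sketch).
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__) >= Version("1.8.0")


@pytest.mark.parametrize("dtype", [torch.half, torch.double])
def test_flip_dtype_cpu(dtype):  # hypothetical test, not part of the commit
    if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
        # Same guard as in the diff: torch.flip lacks cpu + half support before 1.8,
        # so report an expected failure instead of a hard error.
        pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
    x = torch.rand(10, dtype=dtype)
    assert torch.flip(x, dims=(0,)).shape == x.shape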
tests/unittests/classification/test_recall_at_fixed_precision.py (14 changes: 7 additions & 7 deletions)

@@ -31,7 +31,7 @@
     multiclass_recall_at_fixed_precision,
     multilabel_recall_at_fixed_precision,
 )
-from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8
 from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
 from unittests.helpers import seed_all
 from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index

@@ -118,8 +118,8 @@ def test_binary_recall_at_fixed_precision_differentiability(self, input):
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
     def test_binary_recall_at_fixed_precision_dtype_cpu(self, input, dtype):
         preds, target = input
-        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
-            pytest.xfail(reason="half support of core ops not support before pytorch v1.6")
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if (preds < 0).any() and dtype == torch.half:
             pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
         self.run_precision_test_cpu(

@@ -235,8 +235,8 @@ def test_multiclass_recall_at_fixed_precision_differentiability(self, input):
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
     def test_multiclass_recall_at_fixed_precision_dtype_cpu(self, input, dtype):
         preds, target = input
-        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
-            pytest.xfail(reason="half support of core ops not support before pytorch v1.6")
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
             pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
         self.run_precision_test_cpu(

@@ -347,8 +347,8 @@ def test_multiclass_recall_at_fixed_precision_differentiability(self, input):
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
     def test_multilabel_recall_at_fixed_precision_dtype_cpu(self, input, dtype):
         preds, target = input
-        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
-            pytest.xfail(reason="half support of core ops not support before pytorch v1.6")
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_8:
+            pytest.xfail(reason="torch.flip not support before pytorch v1.8 for cpu + half precision")
         if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
             pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
         self.run_precision_test_cpu(
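This file gets the same treatment, and in addition the version flag import is bumped from _TORCH_GREATER_EQUAL_1_6 to _TORCH_GREATER_EQUAL_1_8 so the guard matches the new xfail reason. The second guard in these hunks covers a separate limitation: when preds are not all inside (0, 1) they are treated as logits and a softmax is applied, which historically lacked half-precision support on CPU. A hedged sketch of that check follows; it is self-contained, uses a made-up test name and data, and is not taken from the commit.

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.half, torch.double])
def test_softmax_dtype_cpu(dtype):  # hypothetical test, not part of the commit
    preds = torch.randn(4, 3)  # logits outside (0, 1), so a softmax would be applied
    if dtype == torch.half and not ((0 < preds) & (preds < 1)).all():
        # Same guard as in the diff: softmax in half precision on CPU is not supported
        # on older PyTorch builds, so the combination is xfailed rather than asserted.
        pytest.xfail(reason="half support for torch.softmax on cpu not implemented")
    probs = torch.softmax(preds.to(dtype), dim=-1)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(4, dtype=dtype), atol=1e-2)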
