diff --git a/tests/test_tools.py b/tests/test_tools.py index 3fe452a7..bbbe9103 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -528,31 +528,40 @@ def test_validate_ctf_pass(self): model, image_shape=(1, 1, *model.image_shape), device=DEVICE ) + # Metric validation tests def test_validate_metric_inputs(self): - metric = lambda x: x + def identity_metric(x): + return x + with pytest.raises(TypeError, match="metric should be callable and accept two"): - po.tools.validate.validate_metric(metric, device=DEVICE) + po.tools.validate.validate_metric(identity_metric, device=DEVICE) def test_validate_metric_output_shape(self): - metric = lambda x, y: x - y + def difference_metric(x, y): + return x - y + with pytest.raises( ValueError, match="metric should return a scalar value but output" ): - po.tools.validate.validate_metric(metric, device=DEVICE) + po.tools.validate.validate_metric(difference_metric, device=DEVICE) def test_validate_metric_identical(self): - metric = lambda x, y: (x + y).mean() + def mean_metric(x, y): + return (x + y).mean() + with pytest.raises( ValueError, match="metric should return <= 5e-7 on two identical" ): - po.tools.validate.validate_metric(metric, device=DEVICE) + po.tools.validate.validate_metric(mean_metric, device=DEVICE) def test_validate_metric_nonnegative(self): - metric = lambda x, y: (x - y).sum() + def sum_metric(x, y): + return (x - y).sum() + with pytest.raises( ValueError, match="metric should always return non-negative" ): - po.tools.validate.validate_metric(metric, device=DEVICE) + po.tools.validate.validate_metric(sum_metric, device=DEVICE) @pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True) def test_remove_grad(self, model):