Skip to content

Commit

Permalink
lambda functions assigned to variables are replaced by acutal functio…
Browse files Browse the repository at this point in the history
…n definitions in test validate metrics. Resolved E731 linting error
  • Loading branch information
hmd101 committed Oct 16, 2024
1 parent 5a0751f commit 38ac6c3
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,31 +528,40 @@ def test_validate_ctf_pass(self):
model, image_shape=(1, 1, *model.image_shape), device=DEVICE
)

# Metric validation tests
def test_validate_metric_inputs(self):
metric = lambda x: x
def identity_metric(x):
return x

with pytest.raises(TypeError, match="metric should be callable and accept two"):
po.tools.validate.validate_metric(metric, device=DEVICE)
po.tools.validate.validate_metric(identity_metric, device=DEVICE)

def test_validate_metric_output_shape(self):
metric = lambda x, y: x - y
def difference_metric(x, y):
return x - y

with pytest.raises(
ValueError, match="metric should return a scalar value but output"
):
po.tools.validate.validate_metric(metric, device=DEVICE)
po.tools.validate.validate_metric(difference_metric, device=DEVICE)

def test_validate_metric_identical(self):
metric = lambda x, y: (x + y).mean()
def mean_metric(x, y):
return (x + y).mean()

with pytest.raises(
ValueError, match="metric should return <= 5e-7 on two identical"
):
po.tools.validate.validate_metric(metric, device=DEVICE)
po.tools.validate.validate_metric(mean_metric, device=DEVICE)

def test_validate_metric_nonnegative(self):
metric = lambda x, y: (x - y).sum()
def sum_metric(x, y):
return (x - y).sum()

with pytest.raises(
ValueError, match="metric should always return non-negative"
):
po.tools.validate.validate_metric(metric, device=DEVICE)
po.tools.validate.validate_metric(sum_metric, device=DEVICE)

@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True)
def test_remove_grad(self, model):
Expand Down

0 comments on commit 38ac6c3

Please sign in to comment.