Skip to content

Commit

Permalink
Merge pull request #251 from LabForComputationalVision/test_stop_crit…
Browse files Browse the repository at this point in the history
…_fix

fix for test_stop_criterion
  • Loading branch information
billbrod authored Feb 21, 2024
2 parents b2d4ffe + 028ad9e commit fe3dd68
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tests/test_geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,5 @@ def test_stop_criterion(self, einstein_small_seq, model):
moog = po.synth.Geodesic(einstein_small_seq[:1], einstein_small_seq[-1:],
model, 5)
moog.synthesize(max_iter=10, stop_criterion=.06, stop_iters_to_check=1)
assert len(moog.pixel_change_norm) == 6, "Didn't stop when hit criterion! (or optimization changed)"
assert (abs(moog.pixel_change_norm[-1:]) < .06).all(), "Didn't stop when hit criterion!"
assert (abs(moog.pixel_change_norm[:-1]) > .06).all(), "Stopped after hit criterion!"
3 changes: 2 additions & 1 deletion tests/test_mad.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,5 @@ def test_stop_criterion(self, einstein_img):
po.tools.set_seed(0)
mad = po.synth.MADCompetition(einstein_img, po.metric.mse, dis_ssim, 'min')
mad.synthesize(max_iter=15, stop_criterion=1e-3, stop_iters_to_check=5)
assert len(mad.losses) == 12, "Didn't stop when hit criterion! (or optimization changed)"
assert abs(mad.losses[-5]-mad.losses[-1]) < 1e-3, "Didn't stop when hit criterion!"
assert abs(mad.losses[-6]-mad.losses[-2]) > 1e-3, "Stopped after hit criterion!"
11 changes: 3 additions & 8 deletions tests/test_metamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,6 @@ def test_stop_criterion(self, einstein_img, model):
po.tools.set_seed(0)
met = po.synth.Metamer(einstein_img, model)
# takes different numbers of iter to converge on GPU and CPU
if DEVICE.type == 'cuda':
max_iter = 30
check_iter = 25
else:
max_iter = 10
check_iter = 8
met.synthesize(max_iter=max_iter, stop_criterion=1e-5, stop_iters_to_check=5)
assert len(met.losses) == check_iter, "Didn't stop when hit criterion! (or optimization changed)"
met.synthesize(max_iter=30, stop_criterion=1e-5, stop_iters_to_check=5)
assert abs(met.losses[-5]-met.losses[-1]) < 1e-5, "Didn't stop when hit criterion!"
assert abs(met.losses[-6]-met.losses[-2]) > 1e-5, "Stopped after hit criterion!"

0 comments on commit fe3dd68

Please sign in to comment.