From 96476af875ce50af43902cbb4aa8f7dbd4b8ff92 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 23 Sep 2024 08:21:57 -0700 Subject: [PATCH] Robustify prune_inferior_points tests against sorting order (#2548) Summary: Our nightly CI started failing, likely due to a sorting order change introduced in https://github.com/pytorch/pytorch/pull/127936 This change robustifies the tests against the point order (and also fixes a torch deprecation warning) Pull Request resolved: https://github.com/pytorch/botorch/pull/2548 Reviewed By: sdaulton Differential Revision: D63260870 Pulled By: Balandat --- botorch/acquisition/multi_objective/utils.py | 2 +- botorch/acquisition/utils.py | 7 ++++--- .../acquisition/multi_objective/test_utils.py | 13 +++++++------ test/acquisition/test_utils.py | 19 +++++++++---------- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/botorch/acquisition/multi_objective/utils.py b/botorch/acquisition/multi_objective/utils.py index 369c0e6a5c..30448b587b 100644 --- a/botorch/acquisition/multi_objective/utils.py +++ b/botorch/acquisition/multi_objective/utils.py @@ -154,7 +154,7 @@ def prune_inferior_points_multi_objective( probs = pareto_mask.to(dtype=X.dtype).mean(dim=0) idcs = probs.nonzero().view(-1) if idcs.shape[0] > max_points: - counts, order_idcs = torch.sort(probs, descending=True) + counts, order_idcs = torch.sort(probs, stable=True, descending=True) idcs = order_idcs[:max_points] effective_n_w = obj_vals.shape[-2] // X.shape[-2] idcs = (idcs / effective_n_w).long().unique() diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 198228409a..ae4f054321 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -335,15 +335,16 @@ def prune_inferior_points( marginalize_dim=marginalize_dim, ) if infeas.any(): - # set infeasible points to worse than worst objective - # across all samples + # set infeasible points to worse than worst objective across all samples + # Use clone() here to avoid deprecated `index_put_` on an expanded tensor + obj_vals = obj_vals.clone() obj_vals[infeas] = obj_vals.min() - 1 is_best = torch.argmax(obj_vals, dim=-1) idcs, counts = torch.unique(is_best, return_counts=True) if len(idcs) > max_points: - counts, order_idcs = torch.sort(counts, descending=True) + counts, order_idcs = torch.sort(counts, stable=True, descending=True) idcs = order_idcs[:max_points] return X[idcs] diff --git a/test/acquisition/multi_objective/test_utils.py b/test/acquisition/multi_objective/test_utils.py index acdfddbc95..786c72ad9c 100644 --- a/test/acquisition/multi_objective/test_utils.py +++ b/test/acquisition/multi_objective/test_utils.py @@ -130,13 +130,14 @@ def test_prune_inferior_points_multi_objective(self): X_pruned = prune_inferior_points_multi_objective( model=mm, X=X, ref_point=ref_point, max_frac=2 / 3 ) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue( - torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]]) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, ) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index c8a6484cca..c9552da886 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -270,11 +270,14 @@ def test_prune_inferior_points(self): with mock.patch.object(MockPosterior, "rsample", return_value=samples): mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0))) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, + ) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): @@ -289,11 +292,7 @@ def test_prune_inferior_points(self): device=self.device, dtype=dtype, ) - mm = MockModel( - MockPosterior( - samples=samples, - ) - ) + mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points( model=mm, X=X,