From 8047aa185b4ba8b13112bdec63418a46781988ac Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 19 Nov 2024 16:22:36 -0500 Subject: [PATCH 01/20] wip: topk ic generation --- botorch/optim/initializers.py | 83 ++++++++++++++++++++++++++++++++--- botorch/utils/sampling.py | 8 ++-- 2 files changed, 79 insertions(+), 12 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index af0f918f4a..ad074a4e85 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -328,14 +328,24 @@ def gen_batch_initial_conditions( init_kwargs = {} device = bounds.device bounds_cpu = bounds.cpu() - if "eta" in options: - init_kwargs["eta"] = options.get("eta") - if options.get("nonnegative") or is_nonnegative(acq_function): + + if options.get("topk"): + init_func = initialize_q_batch_topk + init_func_opts = ["sorted", "largest"] + elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg - if "alpha" in options: - init_kwargs["alpha"] = options.get("alpha") + init_func_opts = ["alpha", "eta"] else: init_func = initialize_q_batch + init_func_opts = ["eta"] + + for opt in init_func_opts: + # default value of "largest" to "acq_function.maximize" if it exists + if opt == "largest" and hasattr(acq_function, "maximize"): + init_kwargs[opt] = acq_function.maximize + + if opt in options: + init_kwargs[opt] = options.get(opt) q = 1 if q is None else q # the dimension the samples are drawn from @@ -363,7 +373,7 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd + X_rnd = unnormalize(X_rnd_nlzd, bounds_cpu) else: X_rnd = sample_q_batches_from_polytope( n=n, @@ -375,7 +385,8 @@ def gen_batch_initial_conditions( equality_constraints=equality_constraints, inequality_constraints=inequality_constraints, ) - # sample points around best + + # sample additional points around best if sample_around_best: X_best_rnd = sample_points_around_best( acq_function=acq_function, @@ -395,6 +406,8 @@ def gen_batch_initial_conditions( ) # Keep X on CPU for consistency & to limit GPU memory usage. X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu() + + # Append the fixed fantasies to the randomly generated points if fixed_X_fantasies is not None: if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]): raise BotorchTensorDimensionError( @@ -411,6 +424,9 @@ def gen_batch_initial_conditions( ], dim=-2, ) + + # Evaluate the acquisition function on `X_rnd` using `batch_limit` + # sized chunks. 
with torch.no_grad(): if batch_limit is None: batch_limit = X_rnd.shape[0] @@ -423,16 +439,22 @@ def gen_batch_initial_conditions( ], dim=0, ) + + # Downselect the initial conditions based on the acquisition function values batch_initial_conditions, _ = init_func( X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs ) batch_initial_conditions = batch_initial_conditions.to(device=device) + + # Return the initial conditions if no warnings were raised if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws): return batch_initial_conditions + if factor < max_factor: factor += 1 if seed is not None: seed += 1 # make sure to sample different X_rnd + warnings.warn( "Unable to find non-zero acquisition function values - initial conditions " "are being selected randomly.", @@ -1057,6 +1079,53 @@ def initialize_q_batch_nonneg( return X[idcs], acq_vals[idcs] +def initialize_q_batch_topk( + X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True +) -> tuple[Tensor, Tensor]: + r"""Take the top `n` initial conditions for candidate generation. + + Args: + X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim. + feature space. Typically, these are generated using qMC. + acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this + is the value of the batch acquisition function to be maximized. + n: The number of initial condition to be generated. Must be less than `b`. + + Returns: + - An `n x q x d` tensor of `n` `q`-batch initial conditions. + - An `n` tensor of the corresponding acquisition values. + + Example: + >>> # To get `n=10` starting points of q-batch size `q=3` + >>> # for model with `d=6`: + >>> qUCB = qUpperConfidenceBound(model, beta=0.1) + >>> X_rnd = torch.rand(500, 3, 6) + >>> X_init, acq_init = initialize_q_batch_topk(X=X_rnd, acq_vals=qUCB(X_rnd), n=10) + """ + n_samples = X.shape[0] + if n > n_samples: + raise RuntimeError( + f"n ({n}) cannot be larger than the number of " + f"provided samples ({n_samples})" + ) + elif n == n_samples: + return X, acq_vals + + Ystd = acq_vals.std(dim=0) + if torch.any(Ystd == 0): + warnings.warn( + "All acquisition values for raw samples points are the same for " + "at least one batch. 
Choosing initial conditions at random.", + BadInitialCandidatesWarning, + stacklevel=3, + ) + idcs = torch.randperm(n=n_samples, device=X.device)[:n] + return X[idcs], acq_vals[idcs] + + idcs = acq_vals.topk(n, largest=largest, sorted=sorted).indices + return X[idcs], acq_vals[idcs] + + def sample_points_around_best( acq_function: AcquisitionFunction, n_discrete_points: int, diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 52fe54fbb2..a508320299 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -98,14 +98,12 @@ def draw_sobol_samples( batch_shape = batch_shape or torch.Size() batch_size = int(torch.prod(torch.tensor(batch_shape))) d = bounds.shape[-1] - lower = bounds[0] - rng = bounds[1] - bounds[0] sobol_engine = SobolEngine(q * d, scramble=True, seed=seed) - samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype) - samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device) + samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype) + samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return lower + rng * samples_raw + return unnormalize(samples_raw, bounds) def draw_sobol_normal_samples( From f5a8d64f279dc9e7e752c93ea2589cc93cf966f7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 20 Nov 2024 11:41:29 -0500 Subject: [PATCH 02/20] tests: add tests --- botorch/optim/__init__.py | 7 ++++++- test/optim/test_initializers.py | 36 ++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index f4abe3fd87..6bb32b6658 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -22,7 +22,11 @@ LinearHomotopySchedule, LogLinearHomotopySchedule, ) -from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg +from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topk, +) from botorch.optim.optimize import ( gen_batch_initial_conditions, optimize_acqf, @@ -43,6 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", + "initialize_q_batch_topk", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 09be6f2326..7cf7621eca 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -30,8 +30,10 @@ from botorch.exceptions.warnings import BotorchWarning from botorch.models import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topk, gen_batch_initial_conditions, gen_one_shot_hvkg_initial_conditions, gen_one_shot_kg_initial_conditions, @@ -155,6 +157,38 @@ def test_initialize_q_batch(self): with self.assertRaises(RuntimeError): initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + def test_initialize_q_batch_topk(self): + for dtype in (torch.float, torch.double): + # basic test + X = torch.rand(5, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, 3, 4])) + 
self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # make sure things work with constant inputs + acq_vals = torch.ones(5, device=self.device, dtype=dtype) + ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + self.assertEqual(ics.device, X.device) + self.assertEqual(ics.dtype, X.dtype) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10) + def test_initialize_q_batch_largeZ(self): for dtype in (torch.float, torch.double): # testing large eta*Z From a022462d84a8f63a977e539498c75467887cc98c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 23 Nov 2024 19:13:15 -0500 Subject: [PATCH 03/20] fix: micro-optimization suggestion from review --- botorch/optim/initializers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index ad074a4e85..520818fdcf 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -1100,7 +1100,10 @@ def initialize_q_batch_topk( >>> # for model with `d=6`: >>> qUCB = qUpperConfidenceBound(model, beta=0.1) >>> X_rnd = torch.rand(500, 3, 6) - >>> X_init, acq_init = initialize_q_batch_topk(X=X_rnd, acq_vals=qUCB(X_rnd), n=10) + >>> X_init, acq_init = initialize_q_batch_topk( + ... X=X_rnd, acq_vals=qUCB(X_rnd), n=10 + ... 
) + """ n_samples = X.shape[0] if n > n_samples: @@ -1122,8 +1125,8 @@ def initialize_q_batch_topk( idcs = torch.randperm(n=n_samples, device=X.device)[:n] return X[idcs], acq_vals[idcs] - idcs = acq_vals.topk(n, largest=largest, sorted=sorted).indices - return X[idcs], acq_vals[idcs] + topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted) + return X[topk_idcs], topk_out def sample_points_around_best( From e75239d939606f710fc3671a58385615601fa538 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 25 Nov 2024 10:33:48 -0500 Subject: [PATCH 04/20] fix: don't use unnormalize due to unexpected behaviour with constant bounds --- botorch/optim/initializers.py | 2 +- botorch/utils/sampling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 520818fdcf..0908c78f39 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -373,7 +373,7 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = unnormalize(X_rnd_nlzd, bounds_cpu) + X_rnd = X_rnd_nlzd * (bounds_cpu[1] - bounds_cpu[0]) + bounds_cpu[0] else: X_rnd = sample_q_batches_from_polytope( n=n, diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index a508320299..9ca48a5668 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -103,7 +103,7 @@ def draw_sobol_samples( samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return unnormalize(samples_raw, bounds) + return bounds[0] + (bounds[1] - bounds[0]) * samples_raw def draw_sobol_normal_samples( From 8e274227e51adec300fe1bbd9da7ecee5d47ccf9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 25 Nov 2024 10:39:19 -0500 Subject: [PATCH 05/20] doc: initialize_q_batch_topk -> initialize_q_batch_topn --- botorch/optim/__init__.py | 4 ++-- botorch/optim/initializers.py | 8 ++++---- test/optim/test_initializers.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index 6bb32b6658..5156bba684 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -25,7 +25,7 @@ from botorch.optim.initializers import ( initialize_q_batch, initialize_q_batch_nonneg, - initialize_q_batch_topk, + initialize_q_batch_topn, ) from botorch.optim.optimize import ( gen_batch_initial_conditions, @@ -47,7 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", - "initialize_q_batch_topk", + "initialize_q_batch_topn", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 0908c78f39..0d91d08bea 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -329,8 +329,8 @@ def gen_batch_initial_conditions( device = bounds.device bounds_cpu = bounds.cpu() - if options.get("topk"): - init_func = initialize_q_batch_topk + if options.get("topn"): + init_func = initialize_q_batch_topn init_func_opts = ["sorted", "largest"] elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg @@ -1079,7 +1079,7 @@ def initialize_q_batch_nonneg( return X[idcs], acq_vals[idcs] -def initialize_q_batch_topk( +def initialize_q_batch_topn( X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, 
sorted: bool = True ) -> tuple[Tensor, Tensor]: r"""Take the top `n` initial conditions for candidate generation. @@ -1100,7 +1100,7 @@ def initialize_q_batch_topk( >>> # for model with `d=6`: >>> qUCB = qUpperConfidenceBound(model, beta=0.1) >>> X_rnd = torch.rand(500, 3, 6) - >>> X_init, acq_init = initialize_q_batch_topk( + >>> X_init, acq_init = initialize_q_batch_topn( ... X=X_rnd, acq_vals=qUCB(X_rnd), n=10 ... ) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 7cf7621eca..155e333cf9 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -31,13 +31,13 @@ from botorch.models import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP from botorch.optim.initializers import ( - initialize_q_batch, - initialize_q_batch_nonneg, - initialize_q_batch_topk, gen_batch_initial_conditions, gen_one_shot_hvkg_initial_conditions, gen_one_shot_kg_initial_conditions, gen_value_function_initial_conditions, + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topn, sample_perturbed_subset_dims, sample_points_around_best, sample_q_batches_from_polytope, @@ -157,12 +157,12 @@ def test_initialize_q_batch(self): with self.assertRaises(RuntimeError): initialize_q_batch(X=X, acq_vals=acq_vals, n=10) - def test_initialize_q_batch_topk(self): + def test_initialize_q_batch_topn(self): for dtype in (torch.float, torch.double): # basic test X = torch.rand(5, 3, 4, device=self.device, dtype=dtype) acq_vals = torch.rand(5, device=self.device, dtype=dtype) - ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) self.assertEqual(ics_X.shape, torch.Size([2, 3, 4])) self.assertEqual(ics_X.device, X.device) self.assertEqual(ics_X.dtype, X.dtype) @@ -170,24 +170,24 @@ def test_initialize_q_batch_topk(self): self.assertEqual(ics_acq_vals.device, acq_vals.device) self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) # ensure nothing happens if we want all samples - ics_X, ics_acq_vals = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=5) + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5) self.assertTrue(torch.equal(X, ics_X)) self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) # make sure things work with constant inputs acq_vals = torch.ones(5, device=self.device, dtype=dtype) - ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) self.assertEqual(ics.shape, torch.Size([2, 3, 4])) self.assertEqual(ics.device, X.device) self.assertEqual(ics.dtype, X.dtype) # ensure raises correct warning acq_vals = torch.zeros(5, device=self.device, dtype=dtype) with warnings.catch_warnings(record=True) as w: - ics, _ = initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=2) + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) self.assertEqual(len(w), 1) self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) self.assertEqual(ics.shape, torch.Size([2, 3, 4])) with self.assertRaises(RuntimeError): - initialize_q_batch_topk(X=X, acq_vals=acq_vals, n=10) + initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10) def test_initialize_q_batch_largeZ(self): for dtype in (torch.float, torch.double): From 662caf132d7e24cd825e84b27dfb9d9ff8e1f9f8 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 26 Nov 2024 11:38:20 -0500 Subject: [PATCH 06/20] tests: achive full coverage --- test/optim/test_initializers.py | 86 
+++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 155e333cf9..d8b571ad91 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -280,6 +280,86 @@ def test_gen_batch_initial_conditions(self): torch.all(batch_initial_conditions[..., idx] == val) ) + def test_gen_batch_initial_conditions_topn(self): + bounds = torch.stack([torch.zeros(2), torch.ones(2)]) + mock_acqf = MockAcquisitionFunction() + mock_acqf.objective = lambda y: y.squeeze(-1) + mock_acqf.maximize = True # Add maximize attribute + for dtype in (torch.float, torch.double): + bounds = bounds.to(device=self.device, dtype=dtype) + mock_acqf.X_baseline = bounds # for testing sample_around_best + mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) + for ( + topn, + largest, + is_sorted, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in product( + [True, False], + [True, False, None], + [True, False], + [None, 1234], + [None, 1], + [None, {0: 0.5}], + [True, False], + ): + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter( + "ignore", category=BadInitialCandidatesWarning + ) + options = { + "topn": topn, + "sorted": is_sorted, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + } + if largest is not None: + options["largest"] = largest + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options=options, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) + def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) bounds = torch.stack([torch.zeros(d), torch.ones(d)]) @@ -1471,3 +1551,9 @@ def test_sample_points_around_best(self): self.assertTrue( ((X_rnd.unsqueeze(0) == X_train.unsqueeze(1)).all(dim=-1)).sum() == 0 ) + + +if __name__ == "__main__": + import pytest + + pytest.main([__file__]) From 75eea37bd75776294fc276cfc27f97fdb1d0af7a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 26 Nov 2024 11:45:21 -0500 Subject: [PATCH 07/20] clean: remote debug snippet --- test/optim/test_initializers.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index d8b571ad91..65cde7e183 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -1551,9 +1551,3 @@ def test_sample_points_around_best(self): self.assertTrue( ((X_rnd.unsqueeze(0) == 
X_train.unsqueeze(1)).all(dim=-1)).sum() == 0 ) - - -if __name__ == "__main__": - import pytest - - pytest.main([__file__]) From ec4d7f80292a46b824b0aa49cdf623b6bd877264 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 27 Nov 2024 17:23:46 -0500 Subject: [PATCH 08/20] fea: add InfeasibleTranforms from vizier --- botorch/models/transforms/outcome.py | 120 ++++++++++++++++- test/models/transforms/test_outcome.py | 171 +++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 1 deletion(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 6f93c668a4..d8ca29e6c9 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -276,20 +276,22 @@ def forward( "the `batch_shape` argument to `Standardize`, but got " f"Y.shape[:-2]={Y.shape[:-2]}." ) + if Y.size(-1) != self._m: raise RuntimeError( f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected " f"{self._m}." ) + if Y.shape[-2] < 1: raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") - elif Y.shape[-2] == 1: stdvs = torch.ones( (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device ) else: stdvs = Y.std(dim=-2, keepdim=True) + stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)) means = Y.mean(dim=-2, keepdim=True) if self._outputs is not None: @@ -823,3 +825,119 @@ def untransform_posterior(self, posterior: Posterior) -> TransformedPosterior: posterior=posterior, sample_transform=lambda x: x.sign() * x.abs().expm1(), ) + + +def _nanmax( + tensor: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor | tuple[Tensor, Tensor]: + min_value = torch.finfo(tensor.dtype).min + if dim is None: + return tensor.nan_to_num(min_value).max() + return tensor.nan_to_num(min_value).max(dim=dim, keepdim=keepdim) + + +def _nanmin( + tensor: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor | tuple[Tensor, Tensor]: + max_value = torch.finfo(tensor.dtype).max + if dim is None: + return tensor.nan_to_num(max_value).min() + return tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim) + + +class InfeasibleTransform(OutcomeTransform): + """Transforms infeasible (NaN) values to feasible values.""" + + def __init__(self, batch_shape: torch.Size | None = None) -> None: + """Transforms infeasible (NaN) values to feasible values. + + Args: + batch_shape: The batch shape of the outcomes. + """ + super().__init__() + self._is_trained = False + self.register_buffer("_shift", None) + self.register_buffer("warped_bad_value", torch.tensor(float("nan"))) + + self._batch_shape = batch_shape + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes by handling NaN values. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + if self.training: + if Y.shape[:-2] != self._batch_shape: + raise RuntimeError( + f"Expected Y.shape[:-2] to be {self._batch_shape}, matching " + "the `batch_shape` argument to `Standardize`, but got " + f"Y.shape[:-2]={Y.shape[:-2]}." + ) + + if Y.shape[-2] < 1: + raise ValueError(f"Can't standardize with no observations. 
{Y.shape=}.") + + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + labels_range = _nanmax(Y, dim=-2).values - _nanmin(Y, dim=-2).values + warped_bad_value = _nanmin(Y, dim=-2).values - (0.5 * labels_range + 1) + num_feasible = Y.shape[-2] - torch.isnan(Y).sum(dim=-2) + + # Estimate the relative frequency of feasible points + p_feasible = (0.5 + num_feasible) / (1 + Y.numel()) + + self.warped_bad_value = warped_bad_value + self._shift = -torch.nanmean(Y, dim=-2) * p_feasible - warped_bad_value * ( + 1 - p_feasible + ) + + self._is_trained = torch.tensor(True) + + # Expand warped_bad_value to match Y's shape + expanded_bad_value = self.warped_bad_value.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + expanded_shift = self._shift.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + Y = torch.where(torch.isnan(Y), expanded_bad_value, Y) + Y = torch.where(~torch.isnan(Y), Y + expanded_shift, Y) + return Y, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). + """ + if not self._is_trained: + raise RuntimeError( + "forward() needs to be called before untransform() is called." + ) + + # Expand shift to match Y's shape + expanded_shift = self._shift.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + Y -= expanded_shift + return Y, Yvar diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 49fa23862f..4960dab4e5 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -9,8 +9,11 @@ import torch from botorch.models.transforms.outcome import ( + _nanmax, + _nanmin, Bilog, ChainedOutcomeTransform, + InfeasibleTransform, Log, OutcomeTransform, Power, @@ -51,6 +54,70 @@ def forward(self, Y, Yvar): pass +class TestNanMax(BotorchTestCase): + def test_nanmax_basic(self): + tensor = torch.tensor([1.0, float("nan"), 3.0, 2.0]) + result = _nanmax(tensor) + expected = torch.tensor(3.0) + self.assertEqual(result, expected) + + def test_nanmax_with_dim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmax(tensor, dim=1) + expected = torch.tensor([1.0, 3.0]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmax_with_keepdim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmax(tensor, dim=1, keepdim=True) + expected = torch.tensor([[1.0], [3.0]]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmax_all_nan(self): + tensor = torch.tensor([float("nan"), float("nan")]) + result = _nanmax(tensor) + expected = torch.tensor(torch.finfo(tensor.dtype).min) + self.assertEqual(result, expected) + + def test_nanmax_no_nan(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = _nanmax(tensor) + expected = torch.tensor(3.0) + self.assertEqual(result, expected) + + +class TestNanMin(BotorchTestCase): + def test_nanmin_basic(self): + tensor = torch.tensor([1.0, float("nan"), 3.0, 2.0]) + result = _nanmin(tensor) + expected = torch.tensor(1.0) + self.assertEqual(result, expected) + + def 
test_nanmin_with_dim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmin(tensor, dim=1) + expected = torch.tensor([1.0, 2.0]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmin_with_keepdim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmin(tensor, dim=1, keepdim=True) + expected = torch.tensor([[1.0], [2.0]]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmin_all_nan(self): + tensor = torch.tensor([float("nan"), float("nan")]) + result = _nanmin(tensor) + expected = torch.tensor(torch.finfo(tensor.dtype).max) + self.assertEqual(result, expected) + + def test_nanmin_no_nan(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = _nanmin(tensor) + expected = torch.tensor(1.0) + self.assertEqual(result, expected) + + class TestOutcomeTransforms(BotorchTestCase): def test_abstract_base_outcome_transform(self): with self.assertRaises(TypeError): @@ -817,3 +884,107 @@ def test_bilog(self, seed=0): Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]], None) self.assertTrue(torch.equal(Y_tf_subset, Y_tf[..., [0]])) self.assertIsNone(Yvar_tf_subset) + + +class TestInfeasibleTransform(BotorchTestCase): + def test_infeasible_transform_init(self): + """Test initialization of InfeasibleTransform.""" + batch_shape = torch.Size([2, 3]) + transform = InfeasibleTransform(batch_shape=batch_shape) + assert transform._batch_shape == batch_shape + assert not transform._is_trained + assert transform._shift is None + assert torch.isnan(transform.warped_bad_value) + + def test_infeasible_transform_forward(self): + """Test forward transformation with NaN values.""" + batch_shape = torch.Size([2]) + transform = InfeasibleTransform(batch_shape=batch_shape) + + # Create test data with NaN values + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + Y_orig = Y.clone() + + # Test forward pass in training mode + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transform is now trained + assert transform._is_trained + assert transform._shift is not None + assert not torch.isnan(transform.warped_bad_value).all() + + # Check that NaN values are replaced with warped_bad_value + assert not torch.isnan(Y_tf).any() + + # Test forward pass in eval mode + transform.eval() + Y_tf_eval, _ = transform.forward(Y_orig, None) + + # Check that NaN values are replaced consistently + assert not torch.isnan(Y_tf_eval).any() + + def test_infeasible_transform_untransform(self): + """Test untransform functionality.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + # Should raise error if not trained + with self.assertRaises(RuntimeError): + transform.untransform(torch.tensor([1.0, 2.0]), None) + + # Train the transform first + batch_shape = torch.Size([2]) + transform = InfeasibleTransform(batch_shape=batch_shape) + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Test untransform + Y_untf, _ = transform.untransform(Y_tf, None) + + # Check that values are properly untransformed + assert torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4) + + # test the unwarped_bad_value + assert torch.allclose(transform.warped_bad_value[:, 0], Y_untf[..., 0, 0]) + + def test_infeasible_transform_batch_shape_validation(self): + """Test batch shape validation.""" + transform = InfeasibleTransform(batch_shape=torch.Size([2])) + + # Wrong batch shape should raise error + with 
self.assertRaises(RuntimeError): + transform.forward(torch.randn(3, 4, 2), None) + + def test_infeasible_transform_empty_input(self): + """Test handling of empty input.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + # Empty input should raise error + with self.assertRaises(ValueError): + transform.forward(torch.tensor([]).reshape(0, 1), None) + + def test_infeasible_transform_all_nan(self): + """Test handling of all-NaN input.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + Y = torch.tensor([[float("nan"), float("nan")]]) + transform.train() + with self.assertRaises(RuntimeError): + transform.forward(Y, None) + + def test_infeasible_transform_no_nan(self): + """Test handling of input with no NaN values.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + Y = torch.tensor([[1.0, 2.0, 3.0]]) + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transformation preserves finite values + assert not torch.isnan(Y_tf).any() + Y_untf, _ = transform.untransform(Y_tf, None) + assert torch.allclose(Y_untf, Y, rtol=1e-4) From 351c3f87c861ae870833f2a4ba9c83bc1adc4bc2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 27 Nov 2024 18:54:46 -0500 Subject: [PATCH 09/20] fea: add the logwarp transform from vizier --- botorch/models/transforms/outcome.py | 112 +++++++++++++++++++++++++ test/models/transforms/test_outcome.py | 89 ++++++++++++++++++++ 2 files changed, 201 insertions(+) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index d8ca29e6c9..32af5ec5b7 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -913,6 +913,7 @@ def forward( ) Y = torch.where(torch.isnan(Y), expanded_bad_value, Y) Y = torch.where(~torch.isnan(Y), Y + expanded_shift, Y) + # TODO: Handle Yvar return Y, Yvar def untransform( @@ -940,4 +941,115 @@ def untransform( *Y.shape[:-2], Y.shape[-2], -1 ) Y -= expanded_shift + # TODO: Handle Yvar return Y, Yvar + + +class LogWarperTransform(OutcomeTransform): + """Warps an array of labels to highlight the difference between good values. + + Note that this warping is performed on finite values of the array and NaNs are + untouched. + """ + + def __init__( + self, batch_shape: torch.Size | None = None, offset: float = 1.5 + ) -> None: + """Initialize transform. + + Args: + offset: Offset parameter for the log transformation. Must be > 0. + """ + super().__init__() + if offset <= 0: + raise ValueError("offset must be positive") + self._is_trained = False + self._batch_shape = batch_shape + self.register_buffer("offset", torch.tensor(offset)) + self.register_buffer("_labels_min", torch.tensor(float("nan"))) + self.register_buffer("_labels_max", torch.tensor(float("nan"))) + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + if self.training: + if Y.shape[:-2] != self._batch_shape: + raise RuntimeError( + f"Expected Y.shape[:-2] to be {self._batch_shape}, matching " + "the `batch_shape` argument to `Standardize`, but got " + f"Y.shape[:-2]={Y.shape[:-2]}." 
+ ) + + if Y.shape[-2] < 1: + raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") + + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + self._labels_min = _nanmin(Y, dim=-2).values + self._labels_max = _nanmax(Y, dim=-2).values + + self._is_trained = torch.tensor(True) + + expanded_labels_min = self._labels_min.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + + # Calculate normalized difference + norm_diff = (expanded_labels_max - Y) / ( + expanded_labels_max - expanded_labels_min + ) + Y_transformed = 0.5 - ( + torch.log1p(norm_diff * (self.offset - 1)) / torch.log(self.offset) + ) + + # TODO: Handle Yvar + return Y_transformed, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). + """ + if not self._is_trained: + raise RuntimeError("forward() needs to be called before untransform()") + + expanded_labels_min = self._labels_min.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand( + *Y.shape[:-2], Y.shape[-2], -1 + ) + + Y_untransformed = expanded_labels_max - ( + (torch.exp(torch.log(self.offset) * (0.5 - Y)) - 1) + * (expanded_labels_max - expanded_labels_min) + / (self.offset - 1) + ) + + return Y_untransformed, Yvar diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 4960dab4e5..bcc8906ca9 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -15,6 +15,7 @@ ChainedOutcomeTransform, InfeasibleTransform, Log, + LogWarperTransform, OutcomeTransform, Power, Standardize, @@ -988,3 +989,91 @@ def test_infeasible_transform_no_nan(self): assert not torch.isnan(Y_tf).any() Y_untf, _ = transform.untransform(Y_tf, None) assert torch.allclose(Y_untf, Y, rtol=1e-4) + + +class TestLogWarperTransform(BotorchTestCase): + def test_log_warper_transform_init(self): + """Test initialization of LogWarperTransform.""" + batch_shape = torch.Size([2, 3]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + self.assertEqual(transform._batch_shape, batch_shape) + self.assertEqual(transform.offset.item(), 2.0) + + # Test invalid offset + with self.assertRaisesRegex(ValueError, "offset must be positive"): + LogWarperTransform(offset=0.0) + with self.assertRaisesRegex(ValueError, "offset must be positive"): + LogWarperTransform(offset=-1.0) + + def test_log_warper_transform_forward(self): + """Test forward transformation.""" + batch_shape = torch.Size([2]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + + # Create test data with NaN values + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + Y_orig = Y.clone() + + # Test forward pass in training mode + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transform is now trained + labels_min = transform._labels_min.clone() + labels_max = transform._labels_max.clone() + + assert 
transform._is_trained + assert torch.isfinite(labels_min).all() + assert torch.isfinite(labels_max).all() + assert (torch.isnan(Y_tf) == torch.isnan(Y_orig)).all() + + # Test forward pass in eval mode + transform.eval() + Y_tf_eval, _ = transform.forward(Y_tf, None) + + # Check that NaN values are replaced consistently + assert (torch.isnan(Y_tf_eval) == torch.isnan(Y_tf)).all() + assert torch.allclose(labels_min, transform._labels_min) + assert torch.allclose(labels_max, transform._labels_max) + + def test_log_warper_transform_untransform(self): + """Test untransform functionality.""" + batch_shape = torch.Size([2]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + + # Should raise error if not trained + with self.assertRaises(RuntimeError): + transform.untransform(torch.tensor([1.0, 2.0]), None) + + # Train the transform first + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Test untransform + Y_untf, _ = transform.untransform(Y_tf, None) + + # Check that values are properly untransformed + assert torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4) + + # test the nan values don't change + assert torch.isnan(Y_untf[..., 0, 0]).all() + + def test_log_warper_transform_batch_shape_validation(self): + """Test batch shape validation.""" + transform = LogWarperTransform(offset=2.0, batch_shape=torch.Size([2])) + + # Wrong batch shape should raise error + with self.assertRaises(RuntimeError): + transform.forward(torch.randn(3, 4, 2), None) + + def test_log_warper_transform_empty_input(self): + """Test handling of empty input.""" + transform = LogWarperTransform(offset=2.0, batch_shape=torch.Size([])) + + # Empty input should raise error + with self.assertRaises(ValueError): + transform.forward(torch.tensor([]).reshape(0, 1), None) From a656685f1f6eb26fff738594ca0985096190cb15 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 29 Nov 2024 16:37:08 -0500 Subject: [PATCH 10/20] wip half rank --- botorch/models/transforms/outcome.py | 145 +++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 21 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 32af5ec5b7..f6a838fd9f 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -24,6 +24,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict +from itertools import product import torch from botorch.models.transforms.utils import ( @@ -845,6 +846,19 @@ def _nanmin( return tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim) +def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: + """Utility for common output transform checks.""" + if Y.shape[:-2] != batch_shape: + raise RuntimeError( + f"Expected Y.shape[:-2] to be {batch_shape}, matching " + "the `batch_shape` argument to `Standardize`, but got " + f"Y.shape[:-2]={Y.shape[:-2]}." + ) + + if Y.shape[-2] < 1: + raise ValueError(f"Can't transform with no observations. {Y.shape=}.") + + class InfeasibleTransform(OutcomeTransform): """Transforms infeasible (NaN) values to feasible values.""" @@ -876,17 +890,9 @@ def forward( - The transformed outcome observations. - The transformed observation noise (if applicable). """ - if self.training: - if Y.shape[:-2] != self._batch_shape: - raise RuntimeError( - f"Expected Y.shape[:-2] to be {self._batch_shape}, matching " - "the `batch_shape` argument to `Standardize`, but got " - f"Y.shape[:-2]={Y.shape[:-2]}." 
- ) - - if Y.shape[-2] < 1: - raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") + _check_batched_output(Y, self._batch_shape) + if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") @@ -984,23 +990,14 @@ def forward( - The transformed outcome observations. - The transformed observation noise (if applicable). """ - if self.training: - if Y.shape[:-2] != self._batch_shape: - raise RuntimeError( - f"Expected Y.shape[:-2] to be {self._batch_shape}, matching " - "the `batch_shape` argument to `Standardize`, but got " - f"Y.shape[:-2]={Y.shape[:-2]}." - ) - - if Y.shape[-2] < 1: - raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") + _check_batched_output(Y, self._batch_shape) + if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") self._labels_min = _nanmin(Y, dim=-2).values self._labels_max = _nanmax(Y, dim=-2).values - self._is_trained = torch.tensor(True) expanded_labels_min = self._labels_min.unsqueeze(-2).expand( @@ -1053,3 +1050,109 @@ def untransform( ) return Y_untransformed, Yvar + + +class HalfRankTransform(OutcomeTransform): + """Warps half of the outcomes to fit into a Gaussian distribution. + + This transform warps values below the median to follow a Gaussian distribution while + leaving values above the median unchanged. NaN values are preserved. + """ + + def __init__(self, batch_shape: torch.Size | None = None) -> None: + """Initialize transform. + + Args: + outputs: Which of the outputs to transform. If omitted, all outputs + will be transformed. + """ + super().__init__() + self._batch_shape = batch_shape + self._is_trained = False + self.register_buffer("_original_labels", torch.tensor([])) + self.register_buffer("_warped_labels", torch.tensor([])) + self.register_buffer("_original_label_median", torch.tensor(float("nan"))) + + def _get_std_above_median(self, unique_y: Tensor, y_median: Tensor) -> Tensor: + # Estimate std of good half + good_half = unique_y[unique_y >= y_median] + std = torch.sqrt(((good_half - y_median) ** 2).mean()) + + if std == 0: + std = torch.sqrt(((unique_y - y_median) ** 2).mean()) + + if torch.isnan(std): + std = torch.abs(unique_y - y_median).mean() + + return std + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). 
+ """ + _check_batched_output(Y, self._batch_shape) + + if self.training: + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + Y_transformed = Y.clone() + + # Compute median for each batch + Y_medians = torch.nanmedian(Y, dim=-2) + + for dim in range(Y.shape[-1]): + for batch_idx in product((range(n) for n in self._batch_shape)): + y_median = Y_medians[dim] + y = Y[*batch_idx, :, dim] + + # Get finite values and their ranks for each batch + is_finite_mask = ~torch.isnan(y) + ranks = torch.zeros_like(y) + + unique_y, unique_idx = torch.unique( + y[is_finite_mask], return_index=True + ) + + for i, val in enumerate(unique_y): + ranks[y == val] = i + 1 + + ranks = torch.where(is_finite_mask, ranks, len(unique_y) + 1) + + # Transform values below median + below_median_mask = y < y_median + + # Calculate rank quantiles + dedup_median_index = torch.searchsorted(unique_y, y_median) + denominator = dedup_median_index + 0.5 * ( + unique_y[dedup_median_index] == y_median + ) + rank_quantile = 0.5 * (ranks[below_median_mask] - 0.5) / denominator + + y_above_median_std = self._get_std_above_median(unique_y, y_median) + + # Apply transformation + rank_ppf = ( + torch.erfinv(2 * rank_quantile - 1) + * y_above_median_std + * torch.sqrt(torch.tensor(2.0)) + ) + Y_transformed[*batch_idx, below_median_mask, dim] = ( + rank_ppf + y_median + ) + + # TODO: what do I need to save? + + self._is_trained = torch.tensor(True) + return Y_transformed, Yvar From 88a2e5d25b1da2f7e0834c1127d613104a6d1f2a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 09:51:01 -0500 Subject: [PATCH 11/20] fea: use unnormalize in more places but add flag to turn off the constant bound adjustment --- botorch/optim/initializers.py | 4 +- botorch/utils/feasible_volume.py | 5 +- botorch/utils/sampling.py | 2 +- botorch/utils/transforms.py | 14 +- test/optim/test_initializers.py | 262 ++++++++++++++++--------------- 5 files changed, 153 insertions(+), 134 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index 0d91d08bea..fbf975cedc 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -373,7 +373,9 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = X_rnd_nlzd * (bounds_cpu[1] - bounds_cpu[0]) + bounds_cpu[0] + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) else: X_rnd = sample_q_batches_from_polytope( n=n, diff --git a/botorch/utils/feasible_volume.py b/botorch/utils/feasible_volume.py index f3b8d2fb76..2608c03c2a 100644 --- a/botorch/utils/feasible_volume.py +++ b/botorch/utils/feasible_volume.py @@ -11,7 +11,7 @@ import botorch.models.model as model import torch from botorch.logging import _get_logger -from botorch.utils.sampling import manual_seed +from botorch.utils.sampling import manual_seed, unnormalize from torch import Tensor @@ -164,9 +164,10 @@ def estimate_feasible_volume( seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() with manual_seed(seed=seed): - box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand( + samples_nlzd = torch.rand( (nsample_feature, bounds.size(1)), dtype=dtype, device=device ) + box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False) features, p_feature = get_feasible_samples( samples=box_samples, inequality_constraints=inequality_constraints diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 
9ca48a5668..f914dea24d 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -103,7 +103,7 @@ def draw_sobol_samples( samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return bounds[0] + (bounds[1] - bounds[0]) * samples_raw + return unnormalize(samples_raw, bounds, update_constant_bounds=False) def draw_sobol_normal_samples( diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 01f34c0da4..5b60ec4ff1 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -66,7 +66,7 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: return bounds -def normalize(X: Tensor, bounds: Tensor) -> Tensor: +def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. NOTE: If the upper and lower bounds are identical for a dimension, that dimension @@ -89,11 +89,15 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X_normalized = normalize(X, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return (X - bounds[0]) / (bounds[1] - bounds[0]) -def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: +def unnormalize( + X: Tensor, bounds: Tensor, update_constant_bounds: bool = True +) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. NOTE: If the upper and lower bounds are identical for a dimension, that dimension @@ -116,7 +120,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X = unnormalize(X_normalized, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return X * (bounds[1] - bounds[0]) + bounds[0] diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 65cde7e183..e9145eb59f 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -47,7 +47,7 @@ transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples, manual_seed +from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize from botorch.utils.testing import ( _get_max_violation_of_bounds, _get_max_violation_of_constraints, @@ -221,144 +221,152 @@ def test_gen_batch_initial_conditions(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for ( + dtype, + nonnegative, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in product( + (torch.float, torch.double), + [True, False], + [None, 1234], + [None, 1], + [None, {0: 0.5}], + [True, False], + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False] - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, 
warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - }, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + }, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_topn(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) mock_acqf.maximize = True # Add maximize attribute - for dtype in (torch.float, torch.double): + for ( + dtype, + topn, + largest, + is_sorted, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in product( + [torch.float, torch.double], + [True, False], + [True, False, None], + [True, False], + [None, 1234], + [None, 1], + [None, {0: 0.5}], + [True, False], + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for ( - topn, - largest, 
- is_sorted, - seed, - init_batch_limit, - ffs, - sample_around_best, - ) in product( - [True, False], - [True, False, None], - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - options = { - "topn": topn, - "sorted": is_sorted, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - } - if largest is not None: - options["largest"] = largest - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options=options, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + options = { + "topn": topn, + "sorted": is_sorted, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + } + if largest is not None: + options["largest"] = largest + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options=options, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) def test_gen_batch_initial_conditions_highdim(self): d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) @@ -841,7 +849,9 @@ def generator(n: int, q: int, seed: int | None): dtype=bounds.dtype, device=self.device, ) - X_rnd = bounds[0] + (bounds[1] - bounds[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds, 
update_constant_bounds=False + ) X_rnd[..., -1] = 0.42 return X_rnd From e0202e2b4d1a1e6623b01955f44788aca5827344 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 09:56:04 -0500 Subject: [PATCH 12/20] doc: add docstring for the new update_constant_bounds argument --- botorch/utils/transforms.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 5b60ec4ff1..b354821cfb 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -69,14 +69,15 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] - bounds[0, i]`. Returns: A `... x d`-dim tensor of normalized data, given by @@ -100,14 +101,16 @@ def unnormalize( ) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of + the behavior of `normalize` when `update_constant_bounds=True`. Returns: A `... 
x d`-dim tensor of unnormalized data, given by From 44225d8a93a2fa31ad9193d547f695abed3ae728 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 14:01:49 -0500 Subject: [PATCH 13/20] Merge remote-tracking branch 'upstream/main' into vizier-output-transforms --- requirements.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 10596af227..61559fe624 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ -multipledispatch -scipy -mpmath>=0.19,<=1.3 -torch>=2.0.1 -pyro-ppl>=1.8.4 typing_extensions pyre_extensions gpytorch==1.13 linear_operator==0.5.3 +torch>=2.0.1 +pyro-ppl>=1.8.4 +scipy +multipledispatch From 3a87cc694b92706cd179bdfa38262a6571df2e19 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 18:55:26 -0500 Subject: [PATCH 14/20] wip: untransform still doesn't work --- botorch/models/transforms/outcome.py | 150 +++++++++++++++++++++++---- 1 file changed, 131 insertions(+), 19 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index f6a838fd9f..5927c7b7d6 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -892,6 +892,11 @@ def forward( """ _check_batched_output(Y, self._batch_shape) + if Yvar is not None: + raise NotImplementedError( + "InfeasibleTransform does not support transforming observation noise" + ) + if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") @@ -919,7 +924,7 @@ def forward( ) Y = torch.where(torch.isnan(Y), expanded_bad_value, Y) Y = torch.where(~torch.isnan(Y), Y + expanded_shift, Y) - # TODO: Handle Yvar + return Y, Yvar def untransform( @@ -942,6 +947,11 @@ def untransform( "forward() needs to be called before untransform() is called." 
) + if Yvar is not None: + raise NotImplementedError( + "InfeasibleTransform does not support untransforming observation noise" + ) + # Expand shift to match Y's shape expanded_shift = self._shift.unsqueeze(-2).expand( *Y.shape[:-2], Y.shape[-2], -1 @@ -992,6 +1002,11 @@ def forward( """ _check_batched_output(Y, self._batch_shape) + if Yvar is not None: + raise NotImplementedError( + "LogWarperTransform does not support transforming observation noise" + ) + if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") @@ -1015,7 +1030,6 @@ def forward( torch.log1p(norm_diff * (self.offset - 1)) / torch.log(self.offset) ) - # TODO: Handle Yvar return Y_transformed, Yvar def untransform( @@ -1036,6 +1050,11 @@ def untransform( if not self._is_trained: raise RuntimeError("forward() needs to be called before untransform()") + if Yvar is not None: + raise NotImplementedError( + "LogWarperTransform does not support untransforming observation noise" + ) + expanded_labels_min = self._labels_min.unsqueeze(-2).expand( *Y.shape[:-2], Y.shape[-2], -1 ) @@ -1069,9 +1088,9 @@ def __init__(self, batch_shape: torch.Size | None = None) -> None: super().__init__() self._batch_shape = batch_shape self._is_trained = False - self.register_buffer("_original_labels", torch.tensor([])) - self.register_buffer("_warped_labels", torch.tensor([])) - self.register_buffer("_original_label_median", torch.tensor(float("nan"))) + self._unique_labels = {} + self._warped_labels = {} + self.register_buffer("_original_label_medians", torch.tensor([])) def _get_std_above_median(self, unique_y: Tensor, y_median: Tensor) -> Tensor: # Estimate std of good half @@ -1101,6 +1120,11 @@ def forward( - The transformed outcome observations. - The transformed observation noise (if applicable). 
""" + if Yvar is not None: + raise NotImplementedError( + "HalfRankTransform does not support transforming observation noise" + ) + _check_batched_output(Y, self._batch_shape) if self.training: @@ -1110,35 +1134,41 @@ def forward( Y_transformed = Y.clone() # Compute median for each batch - Y_medians = torch.nanmedian(Y, dim=-2) + Y_medians = torch.nanmedian(Y, dim=-2).values + + self._original_label_medians.resize_( + torch.Size((*self._batch_shape, Y.shape[-1])) + ) for dim in range(Y.shape[-1]): - for batch_idx in product((range(n) for n in self._batch_shape)): - y_median = Y_medians[dim] + batch_indices = ( + product(*([m for m in range(n)] for n in self._batch_shape)) + if len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + ..., + ] + ) + for batch_idx in batch_indices: + y_median = Y_medians[*batch_idx, dim] y = Y[*batch_idx, :, dim] # Get finite values and their ranks for each batch is_finite_mask = ~torch.isnan(y) ranks = torch.zeros_like(y) - unique_y, unique_idx = torch.unique( - y[is_finite_mask], return_index=True - ) + unique_y = torch.unique(y[is_finite_mask]) for i, val in enumerate(unique_y): ranks[y == val] = i + 1 ranks = torch.where(is_finite_mask, ranks, len(unique_y) + 1) - # Transform values below median - below_median_mask = y < y_median - # Calculate rank quantiles dedup_median_index = torch.searchsorted(unique_y, y_median) denominator = dedup_median_index + 0.5 * ( unique_y[dedup_median_index] == y_median ) - rank_quantile = 0.5 * (ranks[below_median_mask] - 0.5) / denominator + rank_quantile = 0.5 * (ranks - 0.5) / denominator y_above_median_std = self._get_std_above_median(unique_y, y_median) @@ -1148,11 +1178,93 @@ def forward( * y_above_median_std * torch.sqrt(torch.tensor(2.0)) ) - Y_transformed[*batch_idx, below_median_mask, dim] = ( - rank_ppf + y_median + Y_transformed[*batch_idx, :, dim] = torch.where( + y < y_median, + rank_ppf + y_median, + Y_transformed[*batch_idx, :, dim], ) - # TODO: what do I need to save? + # save intermediate values for untransform + self._original_label_medians[*batch_idx, dim] = y_median + self._unique_labels[(*batch_idx, dim)] = unique_y + self._warped_labels[(*batch_idx, dim)] = unique_y self._is_trained = torch.tensor(True) - return Y_transformed, Yvar + + return Y_transformed, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). 
+ """ + if not self._is_trained: + raise RuntimeError("forward() needs to be called before untransform()") + + if Yvar is not None: + raise NotImplementedError( + "HalfRankTransform does not support untransforming observation noise" + ) + + Y_utf = Y.clone() + + for dim in range(Y.shape[-1]): + batch_indices = ( + product(*(range(n) for n in self._batch_shape)) + if len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + ..., + ] + ) + for batch_idx in batch_indices: + y = Y[*batch_idx, :, dim] + unique_labels = self._unique_labels[(*batch_idx, dim)] + warped_labels = self._warped_labels[(*batch_idx, dim)] + + # Process values below median + below_median = y < self._original_label_medians[*batch_idx, dim] + if below_median.any(): + # Find nearest warped values and interpolate + warped_idx = torch.searchsorted(warped_labels, y[below_median]) + + # Handle edge cases and interpolation + for i, (val, idx) in enumerate(zip(y[below_median], warped_idx)): + if idx == 0: + # Extrapolate below minimum + scale = (val - warped_labels[0]) / ( + warped_labels[-1] - warped_labels[0] + ) + Y_utf[below_median][i] = unique_labels[0] - scale * ( + unique_labels[-1] - unique_labels[0] + ) + else: + # Interpolate between points + lower_idx = idx - 1 + upper_idx = min(idx, len(warped_labels) - 1) + + original_gap = ( + unique_labels[upper_idx] - unique_labels[lower_idx] + ) + warped_gap = ( + warped_labels[upper_idx] - warped_labels[lower_idx] + ) + + if warped_gap > 0: + scale = (val - warped_labels[lower_idx]) / warped_gap + Y_utf[below_median][i] = ( + unique_labels[lower_idx] + scale * original_gap + ) + else: + Y_utf[below_median][i] = unique_labels[lower_idx] + + return Y_utf, Yvar From 926d9e2f177428b7515a5ec92cd57b1d605b5057 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 2 Dec 2024 22:05:22 -0500 Subject: [PATCH 15/20] fea: add half rank transform --- botorch/models/transforms/outcome.py | 139 ++++++++++++++++++++------- 1 file changed, 106 insertions(+), 33 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 5927c7b7d6..6afa8136fe 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -26,6 +26,8 @@ from collections import OrderedDict from itertools import product +import numpy as np + import torch from botorch.models.transforms.utils import ( norm_to_lognorm_mean, @@ -1150,13 +1152,19 @@ def forward( ) for batch_idx in batch_indices: y_median = Y_medians[*batch_idx, dim] - y = Y[*batch_idx, :, dim] + y = Y_transformed[*batch_idx, :, dim] # Get finite values and their ranks for each batch is_finite_mask = ~torch.isnan(y) ranks = torch.zeros_like(y) - unique_y = torch.unique(y[is_finite_mask]) + # TODO: this is annoying but torch.unique doesn't support + # returning indices + np_unique_y, np_unique_indices = np.unique( + y[is_finite_mask].numpy(), return_index=True + ) + unique_y = torch.from_numpy(np_unique_y) + unique_indices = torch.from_numpy(np_unique_indices) for i, val in enumerate(unique_y): ranks[y == val] = i + 1 @@ -1187,7 +1195,9 @@ def forward( # save intermediate values for untransform self._original_label_medians[*batch_idx, dim] = y_median self._unique_labels[(*batch_idx, dim)] = unique_y - self._warped_labels[(*batch_idx, dim)] = unique_y + self._warped_labels[(*batch_idx, dim)] = (rank_ppf + y_median)[ + is_finite_mask + ][unique_indices] self._is_trained = torch.tensor(True) @@ -1227,7 +1237,7 @@ def untransform( ] ) for batch_idx in batch_indices: - y = 
Y[*batch_idx, :, dim] + y = Y_utf[*batch_idx, :, dim].clone() unique_labels = self._unique_labels[(*batch_idx, dim)] warped_labels = self._warped_labels[(*batch_idx, dim)] @@ -1237,34 +1247,97 @@ def untransform( # Find nearest warped values and interpolate warped_idx = torch.searchsorted(warped_labels, y[below_median]) - # Handle edge cases and interpolation - for i, (val, idx) in enumerate(zip(y[below_median], warped_idx)): - if idx == 0: - # Extrapolate below minimum - scale = (val - warped_labels[0]) / ( - warped_labels[-1] - warped_labels[0] - ) - Y_utf[below_median][i] = unique_labels[0] - scale * ( - unique_labels[-1] - unique_labels[0] - ) - else: - # Interpolate between points - lower_idx = idx - 1 - upper_idx = min(idx, len(warped_labels) - 1) - - original_gap = ( - unique_labels[upper_idx] - unique_labels[lower_idx] - ) - warped_gap = ( - warped_labels[upper_idx] - warped_labels[lower_idx] - ) - - if warped_gap > 0: - scale = (val - warped_labels[lower_idx]) / warped_gap - Y_utf[below_median][i] = ( - unique_labels[lower_idx] + scale * original_gap - ) - else: - Y_utf[below_median][i] = unique_labels[lower_idx] + # Create indices for neighboring values + left_idx = torch.clamp(warped_idx - 1, min=0) + right_idx = torch.clamp(warped_idx + 1, max=len(warped_labels)) + + # Gather neighboring values + candidates = torch.stack( + [ + warped_labels[left_idx], + warped_labels[warped_idx], + warped_labels[right_idx], + ], + dim=-1, + ) + + best_idx = torch.argmin( + torch.abs(candidates - y[below_median].unsqueeze(-1)), dim=-1 + ) + lookup_mask = torch.isclose( + candidates[torch.arange(len(best_idx)), best_idx], + y[below_median], + ) + full_lookup_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + lookup_indices = below_median_indices[lookup_mask] + full_lookup_mask[lookup_indices] = True + full_lookup_values = torch.zeros_like(Y_utf[*batch_idx, :, dim]) + full_lookup_values[full_lookup_mask] = unique_labels[ + warped_idx[lookup_mask] + ] + Y_utf[*batch_idx, :, dim] = torch.where( + full_lookup_mask, full_lookup_values, Y_utf[*batch_idx, :, dim] + ) + + # if the value is below the warped minimum, we need to + # extrapolate outside the range + extrapolate_mask = y < warped_labels[0] + extrapolated_values = unique_labels[0] - ( + y[extrapolate_mask] - warped_labels[0] + ).abs() / (warped_labels[-1] - warped_labels[0]) * ( + unique_labels[-1] - unique_labels[0] + ) + full_extrapolated_values = torch.zeros_like( + Y_utf[*batch_idx, :, dim] + ) + full_extrapolated_values[extrapolate_mask] = extrapolated_values + Y_utf[*batch_idx, :, dim] = torch.where( + extrapolate_mask, + full_extrapolated_values, + Y_utf[*batch_idx, :, dim], + ) + + # otherwise, interpolate + neither_extrapolate_nor_lookup = ~( + (y[below_median] < warped_labels[0]) | lookup_mask + ) + y_neither_extrapolate_nor_lookup = y[below_median][ + neither_extrapolate_nor_lookup + ] + warped_idx_neither_extrapolate_nor_lookup = warped_idx[ + neither_extrapolate_nor_lookup + ] + + lower_idx = (warped_idx_neither_extrapolate_nor_lookup - 1,) + upper_idx = (warped_idx_neither_extrapolate_nor_lookup,) + + original_gap = unique_labels[upper_idx] - unique_labels[lower_idx] + warped_gap = warped_labels[upper_idx] - warped_labels[lower_idx] + + full_interpolated_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + interpolated_indices = below_median_indices[ + neither_extrapolate_nor_lookup + ] + full_interpolated_mask[interpolated_indices] = 
True + + full_interpolated_values = torch.zeros_like( + Y_utf[*batch_idx, :, dim] + ) + full_interpolated_values[full_interpolated_mask] = torch.where( + warped_gap > 0, + unique_labels[lower_idx] + + (y_neither_extrapolate_nor_lookup - warped_labels[lower_idx]) + / warped_gap + * original_gap, + unique_labels[lower_idx], + ) + + Y_utf[*batch_idx, :, dim] = torch.where( + full_interpolated_mask, + full_interpolated_values, + Y_utf[*batch_idx, :, dim], + ) return Y_utf, Yvar From 301fa5c47a7e4caf70c0eff83efb0fc179e7a612 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 3 Dec 2024 09:51:50 -0500 Subject: [PATCH 16/20] test: add tests for half-rank --- botorch/models/transforms/outcome.py | 15 ++- test/models/transforms/test_outcome.py | 136 +++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 4 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 6afa8136fe..56fc68697d 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -862,7 +862,10 @@ def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: class InfeasibleTransform(OutcomeTransform): - """Transforms infeasible (NaN) values to feasible values.""" + """Transforms infeasible (NaN) values to feasible values. + + Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 + """ def __init__(self, batch_shape: torch.Size | None = None) -> None: """Transforms infeasible (NaN) values to feasible values. @@ -968,6 +971,8 @@ class LogWarperTransform(OutcomeTransform): Note that this warping is performed on finite values of the array and NaNs are untouched. + + Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 """ def __init__( @@ -1078,6 +1083,8 @@ class HalfRankTransform(OutcomeTransform): This transform warps values below the median to follow a Gaussian distribution while leaving values above the median unchanged. NaN values are preserved. + + Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 """ def __init__(self, batch_shape: torch.Size | None = None) -> None: @@ -1088,7 +1095,7 @@ def __init__(self, batch_shape: torch.Size | None = None) -> None: will be transformed. 
""" super().__init__() - self._batch_shape = batch_shape + self._batch_shape = batch_shape if batch_shape is not None else torch.Size([]) self._is_trained = False self._unique_labels = {} self._warped_labels = {} @@ -1147,7 +1154,7 @@ def forward( product(*([m for m in range(n)] for n in self._batch_shape)) if len(self._batch_shape) > 0 else [ # this allows it to work with no batch dim - ..., + (...,), ] ) for batch_idx in batch_indices: @@ -1233,7 +1240,7 @@ def untransform( product(*(range(n) for n in self._batch_shape)) if len(self._batch_shape) > 0 else [ # this allows it to work with no batch dim - ..., + (...,), ] ) for batch_idx in batch_indices: diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index bcc8906ca9..695f2556b3 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -13,6 +13,7 @@ _nanmin, Bilog, ChainedOutcomeTransform, + HalfRankTransform, InfeasibleTransform, Log, LogWarperTransform, @@ -1077,3 +1078,138 @@ def test_log_warper_transform_empty_input(self): # Empty input should raise error with self.assertRaises(ValueError): transform.forward(torch.tensor([]).reshape(0, 1), None) + + +class TestHalfRankTransform(BotorchTestCase): + def test_init(self): + # Test initialization + transform = HalfRankTransform() + self.assertIsNone(transform._batch_shape) + self.assertFalse(transform._is_trained) + self.assertEqual(transform._unique_labels, {}) + self.assertEqual(transform._warped_labels, {}) + + # Test with batch shape + batch_shape = torch.Size([2, 3]) + transform = HalfRankTransform(batch_shape=batch_shape) + self.assertEqual(transform._batch_shape, batch_shape) + + def test_transform_simple_case(self): + # Test with simple 1D tensor + transform = HalfRankTransform() + Y = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Values above median should remain unchanged + self.assertTrue( + torch.allclose(Y_transformed[Y.squeeze() > 3.0], Y[Y.squeeze() > 3.0]) + ) + + # Check if transform is trained + self.assertTrue(transform._is_trained) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def test_transform_with_nans(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # NaN values should remain NaN + self.assertTrue(torch.isnan(Y_transformed[torch.isnan(Y)]).all()) + + # Non-NaN values above median should remain unchanged + valid_mask = ~torch.isnan(Y.squeeze()) + median = torch.nanmedian(Y) + self.assertTrue( + torch.allclose( + Y_transformed[valid_mask & (Y.squeeze() > median)], + Y[valid_mask & (Y.squeeze() > median)], + ) + ) + + def test_transform_batch(self): + batch_shape = torch.Size([2]) + transform = HalfRankTransform(batch_shape=batch_shape) + Y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape(2, 3, 1) + Y_transformed, _ = transform.forward(Y) + + # Shape should be preserved + self.assertEqual(Y_transformed.shape, Y.shape) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def test_transform_multi_output(self): + transform = HalfRankTransform() + Y = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]) + Y_transformed, _ = transform.forward(Y) + + # Each output dimension should be transformed 
independently + self.assertEqual(Y_transformed.shape, Y.shape) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def test_error_cases(self): + transform = HalfRankTransform() + + # Test all NaN case + Y = torch.tensor([[float("nan")], [float("nan")]]) + with self.assertRaisesRegex( + RuntimeError, "For at least one batch, all outcomes are NaN" + ): + transform.forward(Y) + + # Test untransform before training + Y = torch.tensor([[1.0], [2.0]]) + with self.assertRaisesRegex( + RuntimeError, "needs to be called before untransform" + ): + transform.untransform(Y) + + # Test with observation noise + Y = torch.tensor([[1.0], [2.0]]) + Yvar = torch.tensor([[0.1], [0.1]]) + with self.assertRaisesRegex( + NotImplementedError, + "HalfRankTransform does not support transforming observation noise", + ): + transform.forward(Y, Yvar) + + def test_batch_shape_mismatch(self): + batch_shape = torch.Size([2]) + transform = HalfRankTransform(batch_shape=batch_shape) + Y = torch.tensor([[1.0], [2.0], [3.0]]) # Wrong batch shape + with self.assertRaises(RuntimeError): + transform.forward(Y) + + def test_extrapolation(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Test extrapolation below minimum + Y_test = torch.tensor([0.0]).reshape(-1, 1) + Y_test_transformed, _ = transform.forward(Y_test) + Y_test_untransformed, _ = transform.untransform(Y_test_transformed) + + # The untransformed value should be close to but below the minimum + self.assertLess(Y_test_untransformed.item(), Y.min()) + + def test_interpolation(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, 3.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Test interpolation between values + Y_test = torch.tensor([2.0]).reshape(-1, 1) + Y_test_transformed, _ = transform.forward(Y_test) + Y_test_untransformed, _ = transform.untransform(Y_test_transformed) + + # The untransformed value should be close to the original + self.assertTrue(torch.allclose(Y_test_untransformed, Y_test, rtol=1e-4)) From 21c14b281dc858908b3156ce8db1bf42efc0fcbe Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 3 Dec 2024 11:19:52 -0500 Subject: [PATCH 17/20] test: reduce the number of tests run whilst ensuring coverage --- test/optim/test_initializers.py | 156 ++++++++++++++++---------------- 1 file changed, 78 insertions(+), 78 deletions(-) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index e9145eb59f..902ecfc449 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -131,31 +131,36 @@ def test_initialize_q_batch_nonneg(self): self.assertEqual(ics.dtype, X.dtype) def test_initialize_q_batch(self): - for dtype in (torch.float, torch.double): - for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []): - # basic test - X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) - acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) - self.assertEqual(ics_X.device, X.device) - self.assertEqual(ics_X.dtype, X.dtype) - self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) - self.assertEqual(ics_acq_vals.device, acq_vals.device) - self.assertEqual(ics_acq_vals.dtype, 
acq_vals.dtype) - # ensure nothing happens if we want all samples - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) - self.assertTrue(torch.equal(X, ics_X)) - self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) - # ensure raises correct warning - acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: - ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) - self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) - with self.assertRaises(RuntimeError): - initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + for dtype, batch_shape in ( + (torch.float, torch.Size()), + (torch.double, [3, 2]), + (torch.float, (2,)), + (torch.double, torch.Size([2, 3, 4])), + (torch.float, []), + ): + # basic test + X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch(X=X, acq_vals=acq_vals, n=10) def test_initialize_q_batch_topn(self): for dtype in (torch.float, torch.double): @@ -228,13 +233,10 @@ def test_gen_batch_initial_conditions(self): init_batch_limit, ffs, sample_around_best, - ) in product( - (torch.float, torch.double), - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], + ) in ( + (torch.float, True, None, None, None, True), + (torch.double, False, 1234, 1, {0: 0.5}, False), + (torch.double, True, 1234, None, {0: 0.5}, True), ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best @@ -303,15 +305,15 @@ def test_gen_batch_initial_conditions_topn(self): init_batch_limit, ffs, sample_around_best, - ) in product( - [torch.float, torch.double], - [True, False], - [True, False, None], - [True, False], - [None, 1234], - [None, 1], - [None, {0: 0.5}], - [True, False], + ) in ( + (torch.float, True, True, True, None, None, None, True), + (torch.double, False, False, False, 1234, 1, {0: 0.5}, False), + (torch.float, True, None, True, 1234, None, None, False), + (torch.double, False, True, False, None, 1, {0: 0.5}, True), + (torch.float, True, False, False, 1234, None, {0: 0.5}, True), + (torch.double, False, None, True, None, 1, None, False), + (torch.float, True, True, False, 1234, 1, {0: 0.5}, True), + (torch.double, False, False, True, None, None, None, False), ): bounds = 
bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best @@ -374,48 +376,46 @@ def test_gen_batch_initial_conditions_highdim(self): ffs_map = {i: random() for i in range(0, d, 2)} mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for dtype, nonnegative, seed, ffs, sample_around_best in ( + (torch.float, True, None, None, True), + (torch.double, False, 1234, ffs_map, False), + (torch.double, True, 1234, ffs_map, True), + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - - for nonnegative, seed, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, ffs_map], [True, False] - ): - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=MockAcquisitionFunction(), - bounds=bounds, - q=10, - num_restarts=1, - raw_samples=2, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "sample_around_best": sample_around_best, - }, - ) + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=MockAcquisitionFunction(), + bounds=bounds, + q=10, + num_restarts=1, + raw_samples=2, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "sample_around_best": sample_around_best, + }, + ) + self.assertTrue( + any(issubclass(w.category, SamplingWarning) for w in ws) + ) + expected_shape = torch.Size([1, 10, d]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 + ) + if ffs is not None: + for idx, val in ffs.items(): self.assertTrue( - any(issubclass(w.category, SamplingWarning) for w in ws) + torch.all(batch_initial_conditions[..., idx] == val) ) - expected_shape = torch.Size([1, 10, d]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 - ) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) def test_gen_batch_initial_conditions_warning(self) -> None: for dtype in (torch.float, torch.double): From c63c5716a8af1bef64d36d0ed6213c52f6b59be0 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 3 Dec 2024 17:32:21 -0500 Subject: [PATCH 18/20] fix: fix some review comments --- botorch/models/transforms/outcome.py | 120 +++++++++++++------------ test/models/transforms/test_outcome.py | 47 +++++----- 2 files changed, 90 insertions(+), 77 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 56fc68697d..c3fcaa8b76 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -18,6 +18,11 @@ International Conference on 
Artificial Intelligence and Statistics. PMLR, 2021, http://proceedings.mlr.press/v130/eriksson21a.html +.. [song2024vizier] + Song, Xingyou and others. The vizier gaussian process bandit algorithm + arXiv preprint arXiv:2408.11527. + https://arxiv.org/abs/2408.11527 + """ from __future__ import annotations @@ -833,6 +838,7 @@ def untransform_posterior(self, posterior: Posterior) -> TransformedPosterior: def _nanmax( tensor: Tensor, dim: int | None = None, keepdim: bool = False ) -> Tensor | tuple[Tensor, Tensor]: + """Compute the maximum of a tensor, ignoring NaNs.""" min_value = torch.finfo(tensor.dtype).min if dim is None: return tensor.nan_to_num(min_value).max() @@ -842,6 +848,7 @@ def _nanmax( def _nanmin( tensor: Tensor, dim: int | None = None, keepdim: bool = False ) -> Tensor | tuple[Tensor, Tensor]: + """Compute the minimum of a tensor, ignoring NaNs.""" max_value = torch.finfo(tensor.dtype).max if dim is None: return tensor.nan_to_num(max_value).min() @@ -853,7 +860,7 @@ def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: if Y.shape[:-2] != batch_shape: raise RuntimeError( f"Expected Y.shape[:-2] to be {batch_shape}, matching " - "the `batch_shape` argument to `Standardize`, but got " + "the `batch_shape` argument to the `OutcomeTransform`, but got " f"Y.shape[:-2]={Y.shape[:-2]}." ) @@ -864,19 +871,19 @@ def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: class InfeasibleTransform(OutcomeTransform): """Transforms infeasible (NaN) values to feasible values. - Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 + Inspired by output-space transformations in Vizier [song2024vizier]_. """ - def __init__(self, batch_shape: torch.Size | None = None) -> None: + def __init__(self, batch_shape: torch.Size = torch.Size()) -> None: """Transforms infeasible (NaN) values to feasible values. Args: batch_shape: The batch shape of the outcomes. 
""" super().__init__() - self._is_trained = False self.register_buffer("_shift", None) self.register_buffer("warped_bad_value", torch.tensor(float("nan"))) + self.register_buffer("_is_trained", torch.tensor(False)) self._batch_shape = batch_shape @@ -897,11 +904,6 @@ def forward( """ _check_batched_output(Y, self._batch_shape) - if Yvar is not None: - raise NotImplementedError( - "InfeasibleTransform does not support transforming observation noise" - ) - if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") @@ -911,7 +913,7 @@ def forward( num_feasible = Y.shape[-2] - torch.isnan(Y).sum(dim=-2) # Estimate the relative frequency of feasible points - p_feasible = (0.5 + num_feasible) / (1 + Y.numel()) + p_feasible = (0.5 + num_feasible) / (1 + Y.shape[-2]) self.warped_bad_value = warped_bad_value self._shift = -torch.nanmean(Y, dim=-2) * p_feasible - warped_bad_value * ( @@ -921,14 +923,12 @@ def forward( self._is_trained = torch.tensor(True) # Expand warped_bad_value to match Y's shape - expanded_bad_value = self.warped_bad_value.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) - expanded_shift = self._shift.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) - Y = torch.where(torch.isnan(Y), expanded_bad_value, Y) - Y = torch.where(~torch.isnan(Y), Y + expanded_shift, Y) + expanded_bad_value = self.warped_bad_value.unsqueeze(-2).expand_as(Y) + expanded_shift = self._shift.unsqueeze(-2).expand_as(Y) + Y = torch.where(torch.isnan(Y), expanded_bad_value, Y + expanded_shift) + + if Yvar is not None: + Yvar = torch.where(torch.isnan(Y), torch.tensor(0.0), Yvar) return Y, Yvar @@ -952,45 +952,60 @@ def untransform( "forward() needs to be called before untransform() is called." ) - if Yvar is not None: - raise NotImplementedError( - "InfeasibleTransform does not support untransforming observation noise" - ) - # Expand shift to match Y's shape - expanded_shift = self._shift.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) + expanded_shift = self._shift.unsqueeze(-2).expand_as(Y) Y -= expanded_shift - # TODO: Handle Yvar return Y, Yvar class LogWarperTransform(OutcomeTransform): - """Warps an array of labels to highlight the difference between good values. + r"""Warps an array of labels to highlight the difference between good values. - Note that this warping is performed on finite values of the array and NaNs are + NOTE that this warping is performed on finite values of the array and NaNs are untouched. - Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 + Inspired by output-space transformations in Vizier [song2024vizier]_. + + The log warping process consists of two transformations: + + 1. Normalization: + + .. math:: + + \hat{y} = \frac{y_{\max} - y}{y_{\max} - y_{\min}} + + 2. Log Warping: + + .. math:: + + \hat{y}_{\text{warped}} = 0.5 - \frac{\log(1 + (s - 1) \cdot \hat{y})}{\log(s)} + + Where: + - :math:`y` is the input value + - :math:`y_{\min}` is the minimum value in the dataset + - :math:`y_{\max}` is the maximum value in the dataset + - :math:`s` is a free parameter (default 1.5) + """ def __init__( - self, batch_shape: torch.Size | None = None, offset: float = 1.5 + self, batch_shape: torch.Size = torch.Size(), offset: float = 1.5 ) -> None: """Initialize transform. Args: - offset: Offset parameter for the log transformation. Must be > 0. + offset: Offset parameter for the log transformation. 
Larger values + of the offset parameter will lead to greater spreading of good + values. Must be > 1. """ super().__init__() if offset <= 0: raise ValueError("offset must be positive") - self._is_trained = False self._batch_shape = batch_shape self.register_buffer("offset", torch.tensor(offset)) self.register_buffer("_labels_min", torch.tensor(float("nan"))) self.register_buffer("_labels_max", torch.tensor(float("nan"))) + self.register_buffer("_is_trained", torch.tensor(False)) def forward( self, Y: Tensor, Yvar: Tensor | None = None @@ -1022,12 +1037,8 @@ def forward( self._labels_max = _nanmax(Y, dim=-2).values self._is_trained = torch.tensor(True) - expanded_labels_min = self._labels_min.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) - expanded_labels_max = self._labels_max.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) + expanded_labels_min = self._labels_min.unsqueeze(-2).expand_as(Y) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand_as(Y) # Calculate normalized difference norm_diff = (expanded_labels_max - Y) / ( @@ -1062,12 +1073,8 @@ def untransform( "LogWarperTransform does not support untransforming observation noise" ) - expanded_labels_min = self._labels_min.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) - expanded_labels_max = self._labels_max.unsqueeze(-2).expand( - *Y.shape[:-2], Y.shape[-2], -1 - ) + expanded_labels_min = self._labels_min.unsqueeze(-2).expand_as(Y) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand_as(Y) Y_untransformed = expanded_labels_max - ( (torch.exp(torch.log(self.offset) * (0.5 - Y)) - 1) @@ -1084,10 +1091,10 @@ class HalfRankTransform(OutcomeTransform): This transform warps values below the median to follow a Gaussian distribution while leaving values above the median unchanged. NaN values are preserved. - Inspired by output-space transformations in Vizier: https://arxiv.org/abs/2408.11527 + Inspired by output-space transformations in Vizier [song2024vizier]_. """ - def __init__(self, batch_shape: torch.Size | None = None) -> None: + def __init__(self, batch_shape: torch.Size = torch.Size()) -> None: """Initialize transform. Args: @@ -1095,11 +1102,11 @@ def __init__(self, batch_shape: torch.Size | None = None) -> None: will be transformed. 
""" super().__init__() - self._batch_shape = batch_shape if batch_shape is not None else torch.Size([]) - self._is_trained = False + self._batch_shape = batch_shape self._unique_labels = {} self._warped_labels = {} self.register_buffer("_original_label_medians", torch.tensor([])) + self.register_buffer("_is_trained", torch.tensor(False)) def _get_std_above_median(self, unique_y: Tensor, y_median: Tensor) -> Tensor: # Estimate std of good half @@ -1145,14 +1152,14 @@ def forward( # Compute median for each batch Y_medians = torch.nanmedian(Y, dim=-2).values - self._original_label_medians.resize_( - torch.Size((*self._batch_shape, Y.shape[-1])) + self._original_label_medians = torch.empty( + (*self._batch_shape, Y.shape[-1]), device=Y.device ) for dim in range(Y.shape[-1]): batch_indices = ( product(*([m for m in range(n)] for n in self._batch_shape)) - if len(self._batch_shape) > 0 + if self._batch_shape is not None and len(self._batch_shape) > 0 else [ # this allows it to work with no batch dim (...,), ] @@ -1170,8 +1177,8 @@ def forward( np_unique_y, np_unique_indices = np.unique( y[is_finite_mask].numpy(), return_index=True ) - unique_y = torch.from_numpy(np_unique_y) - unique_indices = torch.from_numpy(np_unique_indices) + unique_y = torch.from_numpy(np_unique_y).to(y.device) + unique_indices = torch.from_numpy(np_unique_indices).to(y.device) for i, val in enumerate(unique_y): ranks[y == val] = i + 1 @@ -1207,8 +1214,9 @@ def forward( ][unique_indices] self._is_trained = torch.tensor(True) + return Y_transformed, Yvar - return Y_transformed, Yvar + return Y, Yvar def untransform( self, Y: Tensor, Yvar: Tensor | None = None @@ -1238,7 +1246,7 @@ def untransform( for dim in range(Y.shape[-1]): batch_indices = ( product(*(range(n) for n in self._batch_shape)) - if len(self._batch_shape) > 0 + if self._batch_shape is not None and len(self._batch_shape) > 0 else [ # this allows it to work with no batch dim (...,), ] diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 695f2556b3..f98face90f 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -893,10 +893,10 @@ def test_infeasible_transform_init(self): """Test initialization of InfeasibleTransform.""" batch_shape = torch.Size([2, 3]) transform = InfeasibleTransform(batch_shape=batch_shape) - assert transform._batch_shape == batch_shape - assert not transform._is_trained - assert transform._shift is None - assert torch.isnan(transform.warped_bad_value) + self.assertEqual(transform._batch_shape, batch_shape) + self.assertFalse(transform._is_trained) + self.assertIsNone(transform._shift) + self.assertTrue(torch.isnan(transform.warped_bad_value)) def test_infeasible_transform_forward(self): """Test forward transformation with NaN values.""" @@ -913,19 +913,19 @@ def test_infeasible_transform_forward(self): Y_tf, _ = transform.forward(Y, None) # Check that transform is now trained - assert transform._is_trained - assert transform._shift is not None - assert not torch.isnan(transform.warped_bad_value).all() + self.assertTrue(transform._is_trained) + self.assertIsNotNone(transform._shift) + self.assertFalse(torch.isnan(transform.warped_bad_value).all()) # Check that NaN values are replaced with warped_bad_value - assert not torch.isnan(Y_tf).any() + self.assertFalse(torch.isnan(Y_tf).any()) # Test forward pass in eval mode transform.eval() Y_tf_eval, _ = transform.forward(Y_orig, None) # Check that NaN values are replaced consistently - assert not 
torch.isnan(Y_tf_eval).any() + self.assertFalse(torch.isnan(Y_tf_eval).any()) def test_infeasible_transform_untransform(self): """Test untransform functionality.""" @@ -948,10 +948,15 @@ def test_infeasible_transform_untransform(self): Y_untf, _ = transform.untransform(Y_tf, None) # Check that values are properly untransformed - assert torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4) + self.assertTrue(torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4)) # test the unwarped_bad_value - assert torch.allclose(transform.warped_bad_value[:, 0], Y_untf[..., 0, 0]) + self.assertTrue( + torch.allclose( + transform.warped_bad_value[:, 0] - transform._shift[:, 0], + Y_untf[..., 0, 0], + ) + ) def test_infeasible_transform_batch_shape_validation(self): """Test batch shape validation.""" @@ -1024,19 +1029,19 @@ def test_log_warper_transform_forward(self): labels_min = transform._labels_min.clone() labels_max = transform._labels_max.clone() - assert transform._is_trained - assert torch.isfinite(labels_min).all() - assert torch.isfinite(labels_max).all() - assert (torch.isnan(Y_tf) == torch.isnan(Y_orig)).all() + self.assertTrue(transform._is_trained) + self.assertTrue(torch.isfinite(labels_min).all()) + self.assertTrue(torch.isfinite(labels_max).all()) + self.assertTrue((torch.isnan(Y_tf) == torch.isnan(Y_orig)).all()) # Test forward pass in eval mode transform.eval() Y_tf_eval, _ = transform.forward(Y_tf, None) # Check that NaN values are replaced consistently - assert (torch.isnan(Y_tf_eval) == torch.isnan(Y_tf)).all() - assert torch.allclose(labels_min, transform._labels_min) - assert torch.allclose(labels_max, transform._labels_max) + self.assertTrue((torch.isnan(Y_tf_eval) == torch.isnan(Y_tf)).all()) + self.assertTrue(torch.allclose(labels_min, transform._labels_min)) + self.assertTrue(torch.allclose(labels_max, transform._labels_max)) def test_log_warper_transform_untransform(self): """Test untransform functionality.""" @@ -1058,10 +1063,10 @@ def test_log_warper_transform_untransform(self): Y_untf, _ = transform.untransform(Y_tf, None) # Check that values are properly untransformed - assert torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4) + self.assertTrue(torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4)) # test the nan values don't change - assert torch.isnan(Y_untf[..., 0, 0]).all() + self.assertTrue(torch.isnan(Y_untf[..., 0, 0]).all()) def test_log_warper_transform_batch_shape_validation(self): """Test batch shape validation.""" @@ -1084,7 +1089,7 @@ class TestHalfRankTransform(BotorchTestCase): def test_init(self): # Test initialization transform = HalfRankTransform() - self.assertIsNone(transform._batch_shape) + self.assertEqual(transform._batch_shape, torch.Size([])) self.assertFalse(transform._is_trained) self.assertEqual(transform._unique_labels, {}) self.assertEqual(transform._warped_labels, {}) From cbee6d18be22debb52624752de3af3067a0ac60a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 4 Dec 2024 15:01:45 -0500 Subject: [PATCH 19/20] fea: add forward when not in train --- botorch/models/transforms/outcome.py | 192 ++++++++++++++++++++----- test/models/transforms/test_outcome.py | 7 +- 2 files changed, 159 insertions(+), 40 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index c3fcaa8b76..3a250bc483 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -32,6 +32,7 @@ from itertools import product import numpy as np +import scipy.stats as stats import torch from 
botorch.models.transforms.utils import ( @@ -855,7 +856,7 @@ def _nanmin( return tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim) -def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: +def _check_batched_output(Y: Tensor, batch_shape: Tensor, m: int) -> None: """Utility for common output transform checks.""" if Y.shape[:-2] != batch_shape: raise RuntimeError( @@ -864,6 +865,9 @@ def _check_batched_output(Y: Tensor, batch_shape: Tensor) -> None: f"Y.shape[:-2]={Y.shape[:-2]}." ) + if Y.shape[-1] != m: + raise ValueError(f"Expected Y.shape[-1] to be {m}, but got {Y.shape[-1]}.") + if Y.shape[-2] < 1: raise ValueError(f"Can't transform with no observations. {Y.shape=}.") @@ -874,18 +878,18 @@ class InfeasibleTransform(OutcomeTransform): Inspired by output-space transformations in Vizier [song2024vizier]_. """ - def __init__(self, batch_shape: torch.Size = torch.Size()) -> None: + def __init__(self, m: int, batch_shape: torch.Size = torch.Size()) -> None: """Transforms infeasible (NaN) values to feasible values. Args: batch_shape: The batch shape of the outcomes. """ super().__init__() - self.register_buffer("_shift", None) - self.register_buffer("warped_bad_value", torch.tensor(float("nan"))) - self.register_buffer("_is_trained", torch.tensor(False)) - + self._m = m self._batch_shape = batch_shape + self.register_buffer("_shift", torch.zeros([*batch_shape, m])) + self.register_buffer("warped_bad_value", torch.zeros([*batch_shape, m])) + self.register_buffer("_is_trained", torch.tensor(False)) def forward( self, Y: Tensor, Yvar: Tensor | None = None @@ -989,11 +993,13 @@ class LogWarperTransform(OutcomeTransform): """ def __init__( - self, batch_shape: torch.Size = torch.Size(), offset: float = 1.5 + self, m: int, batch_shape: torch.Size = torch.Size(), offset: float = 1.5 ) -> None: """Initialize transform. Args: + m: The output dimension. + batch_shape: The batch_shape of the training targets. offset: Offset parameter for the log transformation. Larger values of the offset parameter will lead to greater spreading of good values. Must be > 1. @@ -1001,10 +1007,11 @@ def __init__( super().__init__() if offset <= 0: raise ValueError("offset must be positive") + self._m = m self._batch_shape = batch_shape self.register_buffer("offset", torch.tensor(offset)) - self.register_buffer("_labels_min", torch.tensor(float("nan"))) - self.register_buffer("_labels_max", torch.tensor(float("nan"))) + self.register_buffer("_labels_min", torch.zeros([*batch_shape, m])) + self.register_buffer("_labels_max", torch.zeros([*batch_shape, m])) self.register_buffer("_is_trained", torch.tensor(False)) def forward( @@ -1022,7 +1029,7 @@ def forward( - The transformed outcome observations. - The transformed observation noise (if applicable). """ - _check_batched_output(Y, self._batch_shape) + _check_batched_output(Y, self._batch_shape, self._m) if Yvar is not None: raise NotImplementedError( @@ -1094,19 +1101,23 @@ class HalfRankTransform(OutcomeTransform): Inspired by output-space transformations in Vizier [song2024vizier]_. """ - def __init__(self, batch_shape: torch.Size = torch.Size()) -> None: + def __init__(self, m: int, batch_shape: torch.Size = torch.Size()) -> None: """Initialize transform. Args: - outputs: Which of the outputs to transform. If omitted, all outputs - will be transformed. + m: The output dimension. + batch_shape: The batch_shape of the training targets. 
""" super().__init__() + self._m = m self._batch_shape = batch_shape + self.register_buffer("_original_label_medians", torch.zeros([*batch_shape, m])) + self.register_buffer("_is_trained", torch.tensor(False)) + + # TODO these are ragged tensors, we should use a better data structure such + # that they are saved to the state_dict self._unique_labels = {} self._warped_labels = {} - self.register_buffer("_original_label_medians", torch.tensor([])) - self.register_buffer("_is_trained", torch.tensor(False)) def _get_std_above_median(self, unique_y: Tensor, y_median: Tensor) -> Tensor: # Estimate std of good half @@ -1141,21 +1152,16 @@ def forward( "HalfRankTransform does not support transforming observation noise" ) - _check_batched_output(Y, self._batch_shape) + _check_batched_output(Y, self._batch_shape, self._m) + Y_transformed = Y.clone() if self.training: if torch.isnan(Y).all(dim=-2).any(): raise RuntimeError("For at least one batch, all outcomes are NaN") - Y_transformed = Y.clone() - # Compute median for each batch Y_medians = torch.nanmedian(Y, dim=-2).values - self._original_label_medians = torch.empty( - (*self._batch_shape, Y.shape[-1]), device=Y.device - ) - for dim in range(Y.shape[-1]): batch_indices = ( product(*([m for m in range(n)] for n in self._batch_shape)) @@ -1170,27 +1176,24 @@ def forward( # Get finite values and their ranks for each batch is_finite_mask = ~torch.isnan(y) - ranks = torch.zeros_like(y) # TODO: this is annoying but torch.unique doesn't support # returning indices np_unique_y, np_unique_indices = np.unique( y[is_finite_mask].numpy(), return_index=True ) + ranks = stats.rankdata(y.numpy(), method="dense") + unique_y = torch.from_numpy(np_unique_y).to(y.device) unique_indices = torch.from_numpy(np_unique_indices).to(y.device) - - for i, val in enumerate(unique_y): - ranks[y == val] = i + 1 - - ranks = torch.where(is_finite_mask, ranks, len(unique_y) + 1) + ranks = torch.from_numpy(ranks).to(y.device) # Calculate rank quantiles dedup_median_index = torch.searchsorted(unique_y, y_median) - denominator = dedup_median_index + 0.5 * ( + denominator = 2 * dedup_median_index + ( unique_y[dedup_median_index] == y_median ) - rank_quantile = 0.5 * (ranks - 0.5) / denominator + rank_quantile = (ranks - 0.5) / denominator y_above_median_std = self._get_std_above_median(unique_y, y_median) @@ -1209,14 +1212,129 @@ def forward( # save intermediate values for untransform self._original_label_medians[*batch_idx, dim] = y_median self._unique_labels[(*batch_idx, dim)] = unique_y - self._warped_labels[(*batch_idx, dim)] = (rank_ppf + y_median)[ - is_finite_mask - ][unique_indices] + self._warped_labels[(*batch_idx, dim)] = Y_transformed[ + *batch_idx, :, dim + ][is_finite_mask][unique_indices] self._is_trained = torch.tensor(True) return Y_transformed, Yvar - return Y, Yvar + for dim in range(Y.shape[-1]): + batch_indices = ( + product(*([m for m in range(n)] for n in self._batch_shape)) + if self._batch_shape is not None and len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + (...,), + ] + ) + for batch_idx in batch_indices: + y_median = self._original_label_medians[*batch_idx, dim] + y = Y[*batch_idx, :, dim] + warped_labels: torch.Tensor = self._warped_labels[(*batch_idx, dim)] + unique_labels: torch.Tensor = self._unique_labels[(*batch_idx, dim)] + + # Process values below median + below_median = y < self._original_label_medians[*batch_idx, dim] + if below_median.any(): + # Find nearest original values and perform lookup + original_idx = 
torch.searchsorted(unique_labels, y[below_median]) + + # Create indices for neighboring values + left_idx = torch.clamp(original_idx - 1, min=0) + right_idx = torch.clamp(original_idx + 1, max=len(unique_labels)) + + # Gather neighboring values + candidates = torch.stack( + [ + unique_labels[left_idx], + unique_labels[original_idx], + unique_labels[right_idx], + ], + dim=-1, + ) + + # Find nearest original values and perform lookup + best_idx = torch.argmin( + torch.abs(candidates - y[below_median].unsqueeze(-1)), dim=-1 + ) + + lookup_mask = torch.isclose( + candidates[torch.arange(len(best_idx)), best_idx], + y[below_median], + ) + full_lookup_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + lookup_indices = below_median_indices[lookup_mask] + full_lookup_mask[lookup_indices] = True + full_lookup_values = torch.zeros_like(Y[*batch_idx, :, dim]) + full_lookup_values[full_lookup_mask] = warped_labels[ + original_idx[lookup_mask] + ] + Y_transformed[*batch_idx, :, dim] = torch.where( + full_lookup_mask, + full_lookup_values, + Y_transformed[*batch_idx, :, dim], + ) + + # if the value is below the original minimum, we need to + # extrapolate outside the range + extrapolate_mask = y < unique_labels[0] + extrapolated_values = warped_labels[0] - ( + y[extrapolate_mask] - unique_labels[0] + ).abs() / (unique_labels.max() - unique_labels.min()) * ( + warped_labels.max() - warped_labels.min() + ) + full_extrapolated_values = torch.zeros_like(Y[*batch_idx, :, dim]) + full_extrapolated_values[extrapolate_mask] = extrapolated_values + Y_transformed[*batch_idx, :, dim] = torch.where( + extrapolate_mask, + full_extrapolated_values, + Y_transformed[*batch_idx, :, dim], + ) + + # otherwise, interpolate + neither_extrapolate_nor_lookup = ~( + (y[below_median] < unique_labels[0]) | lookup_mask + ) + y_neither_extrapolate_nor_lookup = y[below_median][ + neither_extrapolate_nor_lookup + ] + warped_idx_neither_extrapolate_nor_lookup = original_idx[ + neither_extrapolate_nor_lookup + ] + + lower_idx = (warped_idx_neither_extrapolate_nor_lookup - 1,) + upper_idx = (warped_idx_neither_extrapolate_nor_lookup,) + + original_gap = unique_labels[upper_idx] - unique_labels[lower_idx] + warped_gap = warped_labels[upper_idx] - warped_labels[lower_idx] + + full_interpolated_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + interpolated_indices = below_median_indices[ + neither_extrapolate_nor_lookup + ] + full_interpolated_mask[interpolated_indices] = True + + full_interpolated_values = torch.zeros_like( + Y_transformed[*batch_idx, :, dim] + ) + full_interpolated_values[full_interpolated_mask] = torch.where( + original_gap > 0, + warped_labels[lower_idx] + + (y_neither_extrapolate_nor_lookup - unique_labels[lower_idx]) + / original_gap + * warped_gap, + warped_labels[lower_idx], + ) + + Y_transformed[*batch_idx, :, dim] = torch.where( + full_interpolated_mask, + full_interpolated_values, + Y_transformed[*batch_idx, :, dim], + ) + + return Y_transformed, Yvar def untransform( self, Y: Tensor, Yvar: Tensor | None = None @@ -1259,7 +1377,7 @@ def untransform( # Process values below median below_median = y < self._original_label_medians[*batch_idx, dim] if below_median.any(): - # Find nearest warped values and interpolate + # Find nearest warped values and perform lookup warped_idx = torch.searchsorted(warped_labels, y[below_median]) # Create indices for neighboring values @@ -1300,8 +1418,8 @@ def untransform( 
extrapolate_mask = y < warped_labels[0] extrapolated_values = unique_labels[0] - ( y[extrapolate_mask] - warped_labels[0] - ).abs() / (warped_labels[-1] - warped_labels[0]) * ( - unique_labels[-1] - unique_labels[0] + ).abs() / (warped_labels.max() - warped_labels.min()) * ( + unique_labels.max() - unique_labels.min() ) full_extrapolated_values = torch.zeros_like( Y_utf[*batch_idx, :, dim] diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index f98face90f..714c85173b 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -942,14 +942,15 @@ def test_infeasible_transform_untransform(self): Y[..., 0, 0] = float("nan") transform.train() - Y_tf, _ = transform.forward(Y, None) + Y_tf, Yvar_tf = transform.forward(Y, Y + 2) + self.assertTrue(torch.allclose(Yvar_tf[:, 1:], Y[:, 1:] + 2)) # Test untransform - Y_untf, _ = transform.untransform(Y_tf, None) + Y_untf, Yvar_untf = transform.untransform(Y_tf, Yvar_tf) # Check that values are properly untransformed self.assertTrue(torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4)) - + self.assertTrue(torch.allclose(Yvar_untf[:, 1:], Yvar_tf[:, 1:], rtol=1e-4)) # test the unwarped_bad_value self.assertTrue( torch.allclose( From b8311aee9d1bf75dcea11ac6a8b4c61d1a4811fc Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 4 Dec 2024 15:05:06 -0500 Subject: [PATCH 20/20] fix: force=True to enforce device safety --- botorch/models/transforms/outcome.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 3a250bc483..c11eaf1a7e 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -1180,9 +1180,9 @@ def forward( # TODO: this is annoying but torch.unique doesn't support # returning indices np_unique_y, np_unique_indices = np.unique( - y[is_finite_mask].numpy(), return_index=True + y[is_finite_mask].numpy(force=True), return_index=True ) - ranks = stats.rankdata(y.numpy(), method="dense") + ranks = stats.rankdata(y.numpy(force=True), method="dense") unique_y = torch.from_numpy(np_unique_y).to(y.device) unique_indices = torch.from_numpy(np_unique_indices).to(y.device)
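A minimal usage sketch of the half-rank warp introduced in this series (illustrative only, not part of any patch above; it assumes the signatures as of the later patches, where the transform takes the output dimension `m`, and it mirrors the round-trip exercised in the new tests):

    import torch
    from botorch.models.transforms.outcome import HalfRankTransform

    # Five observations of a single outcome (m=1). Values below the median are
    # warped toward a Gaussian shape; values above the median are left unchanged.
    transform = HalfRankTransform(m=1)
    Y = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1)
    Y_warped, _ = transform.forward(Y)  # fits the warp in training mode
    Y_back, _ = transform.untransform(Y_warped)  # approximately recovers Y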