Skip to content

Commit

Permalink
Incorporate infeasible trials into Eagle designer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613371148
  • Loading branch information
belenkil authored and copybara-github committed Mar 7, 2024
1 parent ffcecf7 commit 37a6768
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def __init__(
self._config = config or FireflyAlgorithmConfig()
self._utils = EagleStrategyUtils(self._problem, self._config, self._rng)
self._firefly_pool = FireflyPool(
utils=self._utils, capacity=self._utils.compute_pool_capacity())
utils=self._utils, capacity=self._utils.compute_pool_capacity()
)

if initial_designer_factory is None:
initial_designer_factory = quasi_random.QuasiRandomDesigner.from_problem
Expand Down Expand Up @@ -287,6 +288,8 @@ def _mutate_fly(self, moving_fly: Firefly) -> None:
# Apply the pulls from 'other_fly' on the moving fly's parameters.
for param_config in self._problem.search_space.parameters:
pull_weight = pull_weights[param_config.type]
if other_fly.trial.infeasible:
pull_weight *= self._config.infeasible_force_factor
# Accentuate 'other_fly' pull using 'exploration_rate'.
if pull_weight > 0.5:
explore_pull_weight = (
Expand Down Expand Up @@ -357,6 +360,11 @@ def update(

def _update_one(self, trial: vz.Trial) -> None:
"""Update the pool using a single trial."""
if trial.infeasible and self._config.infeasible_force_factor > 0:
# Add the infeasible firefly to the pool.
infeasible_firefly_id = self._firefly_pool.generate_new_fly_id()
self._firefly_pool.create_or_update_fly(trial, infeasible_firefly_id)

parent_fly_id = int(trial.metadata.ns('eagle').get('parent_fly_id'))
parent_fly = self._firefly_pool.find_parent_fly(parent_fly_id)
if parent_fly is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy
from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy_utils
from vizier._src.algorithms.designers.eagle_strategy import testing

from absl.testing import absltest
from absl.testing import parameterized


EagleStrategyDesigner = eagle_strategy.EagleStrategyDesigner
FireflyAlgorithmConfig = eagle_strategy_utils.FireflyAlgorithmConfig


class EagleStrategyTest(parameterized.TestCase):
Expand Down Expand Up @@ -242,6 +244,90 @@ def test_suggest_update(self, batch_size):
tid += 1
eagle_designer.update(vza.CompletedTrials(completed), vza.ActiveTrials())

@parameterized.named_parameters(
dict(
testcase_name='Less suggestions than pool capacity',
num_feasible_suggestions=3,
num_infeasible_suggestions=2,
),
dict(
testcase_name='More suggestions than pool capacity',
num_feasible_suggestions=50,
num_infeasible_suggestions=5,
),
)
def test_infeasible_trials(
self, num_feasible_suggestions, num_infeasible_suggestions
):
"""Tests that Eagle works with infeasible trials."""
problem = vz.ProblemStatement()
problem.search_space.select_root().add_float_param(
'float1', 1e-2, 1e3, scale_type=vz.ScaleType.LOG
)
problem.search_space.select_root().add_float_param(
'float2', -2.0, 5.0, scale_type=vz.ScaleType.LINEAR
)
problem.search_space.select_root().add_int_param(
'int', min_value=0, max_value=10
)
problem.search_space.select_root().add_discrete_param(
'discrete', feasible_values=[0.0, 0.6]
)
problem.search_space.select_root().add_categorical_param(
'categorical', feasible_values=['a', 'b', 'c']
)
problem.metric_information.append(
vz.MetricInformation(goal=vz.ObjectiveMetricGoal.MINIMIZE, name='')
)
config = FireflyAlgorithmConfig(infeasible_force_factor=0.1)
eagle_designer = EagleStrategyDesigner(problem, config=config)

def _suggest_and_update(
eagle_designer: EagleStrategyDesigner, tid: int, infeasible: bool
):
suggestion = eagle_designer.suggest(count=1)[0]
completed = suggestion.to_trial(tid).complete(
vz.Measurement(metrics={'': np.random.uniform()}),
infeasibility_reason='infeasible' if infeasible else None,
)
eagle_designer.update(
vza.CompletedTrials([completed]), vza.ActiveTrials()
)

# Suggest trials and update designer for less than pool capacity.
tid = 1
for _ in range(num_feasible_suggestions):
_suggest_and_update(eagle_designer, tid, infeasible=False)
tid += 1

# Suggest another trial and return it as infeasible.
for _ in range(num_infeasible_suggestions):
_suggest_and_update(eagle_designer, tid, infeasible=True)
tid += 1

# Test that the pool size is not affected by infeasible trials..
self.assertEqual(
eagle_designer._firefly_pool.size,
min(eagle_designer._firefly_pool.capacity, num_feasible_suggestions),
)
# Test that the pool contains infeasible trials.
self.assertEqual(
eagle_designer._firefly_pool._infeasible_count,
num_infeasible_suggestions,
)
self.assertEqual(
sum([
1
for firefly in eagle_designer._firefly_pool._pool.values()
if firefly.trial.infeasible
]),
num_infeasible_suggestions,
)
# Suggest more trials while having infeasible trials in the pool.
for _ in range(3):
_suggest_and_update(eagle_designer, tid, infeasible=False)
tid += 1


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class FireflyAlgorithmConfig:
pool_size_factor: float = 1.2
# Exploration rate (value > 1.0 encourages more exploration)
explore_rate: float = 1.0
# The factor to apply on infeasible trial repel force.
infeasible_force_factor: float = 0.0


@attr.define
Expand All @@ -70,10 +72,10 @@ class Firefly:
suggested from the firefly.
trial: The best trial associated with the firefly.
"""
id_: int = attr.field(validator=attr.validators.instance_of(int))
perturbation: float = attr.field(validator=attr.validators.instance_of(float))
generation: int = attr.field(validator=attr.validators.instance_of(int))
trial: vz.Trial = attr.field(validator=attr.validators.instance_of(vz.Trial))
id_: int
perturbation: float
generation: int
trial: vz.Trial


@attr.define
Expand Down Expand Up @@ -337,7 +339,7 @@ def is_better_than(
trial1: vz.Trial,
trial2: vz.Trial,
) -> bool:
"""Checks whether the current trial is better than another trial.
"""Checks whether the 'trial1' is better than 'trial2'.
The comparison is based on the value of final measurement and whether it
goal is MAXIMIZATION or MINIMIZATON.
Expand All @@ -354,8 +356,10 @@ def is_better_than(
"""
if not trial1.is_completed or not trial2.is_completed:
return False
if trial1.infeasible or trial2.infeasible:
if trial1.infeasible and not trial2.infeasible:
return False
if not trial1.infeasible and trial2.infeasible:
return True
if trial1.final_measurement is None or trial2.final_measurement is None:
return False

Expand Down Expand Up @@ -404,31 +408,35 @@ class FireflyPool:
Attributes:
utils: Eagle Strategy utils class.
capacity: The maximum number of flies that the pool could store.
size: The current number of flies in the pool.
capacity: The maximum number of non-feasible fireflies in the pool.
size: The current number of non-feasible fireflies in the pool.
_pool: A dictionary of Firefly objects organized by firefly id.
_last_id: The last firefly id used to generate a suggestion. It's persistent
across calls to ensure we don't use the same fly repeatedly.
_max_fly_id: The maximum value of any fly id ever created. It's persistent
persistent accross calls to ensure unique ids even if trails were deleted.
_infeasible_count: The number of infeasible fireflies in the pool.
"""
utils: EagleStrategyUtils = attr.field(
validator=attr.validators.instance_of(EagleStrategyUtils))

capacity: int = attr.field(validator=attr.validators.instance_of(int))

_utils: EagleStrategyUtils
_capacity: int
_pool: Dict[int, Firefly] = attr.field(init=False, default=attr.Factory(dict))

_last_id: int = attr.field(init=False, default=0)

_max_fly_id: int = attr.field(init=False, default=0)
_infeasible_count: int = attr.field(init=False, default=0)

@property
def capacity(self) -> int:
return self._capacity

@property
def size(self) -> int:
return len(self._pool)
"""Returns the number of feasible fireflies in the pool."""
return len(self._pool) - self._infeasible_count

def remove_fly(self, fly: Firefly):
"""Removes a fly from the pool."""
if fly.trial.infeasible:
raise ValueError('Infeasible firefly should not be removed from pool.')
del self._pool[fly.id_]

def get_shuffled_flies(self, rng: np.random.Generator) -> list[Firefly]:
Expand All @@ -455,23 +463,24 @@ def get_next_moving_fly_copy(self) -> Firefly:
Returns:
A copy of the next moving fly.
"""
current_fly_id = self._last_id + 1
while current_fly_id != self._last_id:
if current_fly_id > self._max_fly_id:
curr_id = self._last_id + 1
while curr_id != self._last_id:
if curr_id > self._max_fly_id:
# Passed the maximum id. Start from the first one as ids are monotonic.
current_fly_id = next(iter(self._pool))
if current_fly_id in self._pool:
self._last_id = current_fly_id
return copy.deepcopy(self._pool[current_fly_id])
current_fly_id += 1
curr_id = next(iter(self._pool))
if curr_id in self._pool and not self._pool[curr_id].trial.infeasible:
self._last_id = curr_id
return copy.deepcopy(self._pool[curr_id])
curr_id += 1

return copy.deepcopy(self._pool[self._last_id])

def is_best_fly(self, fly: Firefly) -> bool:
"""Checks if the 'fly' has the best final measurement in the pool."""
for other_fly_id, other_fly in self._pool.items():
if other_fly_id != fly.id_ and self.utils.is_better_than(
other_fly.trial, fly.trial):
if other_fly_id != fly.id_ and self._utils.is_better_than(
other_fly.trial, fly.trial
):
return False
return True

Expand All @@ -497,8 +506,11 @@ def find_closest_parent(self, trial: vz.Trial) -> Firefly:

min_dist, closest_parent = float('inf'), next(iter(self._pool.values()))
for other_fly in self._pool.values():
curr_dist = self.utils.compute_cononical_distance(
other_fly.trial.parameters, trial.parameters)
if other_fly.trial.infeasible:
continue
curr_dist = self._utils.compute_cononical_distance(
other_fly.trial.parameters, trial.parameters
)
if curr_dist < min_dist:
min_dist = curr_dist
closest_parent = other_fly
Expand Down Expand Up @@ -530,12 +542,14 @@ def create_or_update_fly(self, trial: vz.Trial, parent_fly_id: int) -> None:
# Create a new Firefly in pool.
new_fly = Firefly(
id_=parent_fly_id,
perturbation=self.utils.config.perturbation,
perturbation=self._utils.config.perturbation,
generation=1,
trial=trial,
)
self._pool[parent_fly_id] = new_fly
if trial.infeasible:
self._infeasible_count += 1
else:
# Parent fly id already in pool. Update trial if there was improvement.
if self.utils.is_better_than(trial, self._pool[parent_fly_id].trial):
if self._utils.is_better_than(trial, self._pool[parent_fly_id].trial):
self._pool[parent_fly_id].trial = trial
Loading

0 comments on commit 37a6768

Please sign in to comment.