Skip to content

Commit 8b4b583

Browse files
authored
Merge pull request #188 from python-adaptive/cycle_strategy
BalancingLearner: add a "cycle" strategy, sampling the learners one by one
2 parents e73df74 + d25f3d1 commit 8b4b583

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Iterable
55
from contextlib import suppress
66
from functools import partial
7+
import itertools
78
from operator import itemgetter
89

910
import numpy as np
@@ -50,10 +51,11 @@ class BalancingLearner(BaseLearner):
5051
function : callable
5152
A function that calls the functions of the underlying learners.
5253
Its signature is ``function(learner_index, point)``.
53-
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54+
strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
5455
The points that the `BalancingLearner` choses can be either based on:
5556
the best 'loss_improvements', the smallest total 'loss' of the
56-
child learners, or the number of points per learner, using 'npoints'.
57+
child learners, the number of points per learner, using 'npoints',
58+
or by cycling through the learners one by one using 'cycle'.
5759
One can dynamically change the strategy while the simulation is
5860
running by changing the ``learner.strategy`` attribute.
5961
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9092

9193
@property
9294
def strategy(self):
93-
"""Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94-
The points that the `BalancingLearner` choses can be either based on:
95-
the best 'loss_improvements', the smallest total 'loss' of the
96-
child learners, or the number of points per learner, using 'npoints'.
95+
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96+
'cycle'. The points that the `BalancingLearner` choses can be either
97+
based on: the best 'loss_improvements', the smallest total 'loss' of
98+
the child learners, the number of points per learner, using 'npoints',
99+
or by going through all learners one by one using 'cycle'.
97100
One can dynamically change the strategy while the simulation is
98101
running by changing the ``learner.strategy`` attribute."""
99102
return self._strategy
@@ -107,10 +110,13 @@ def strategy(self, strategy):
107110
self._ask_and_tell = self._ask_and_tell_based_on_loss
108111
elif strategy == "npoints":
109112
self._ask_and_tell = self._ask_and_tell_based_on_npoints
113+
elif strategy == "cycle":
114+
self._ask_and_tell = self._ask_and_tell_based_on_cycle
115+
self._cycle = itertools.cycle(range(len(self.learners)))
110116
else:
111117
raise ValueError(
112-
'Only strategy="loss_improvements", strategy="loss", or'
113-
' strategy="npoints" is implemented.'
118+
'Only strategy="loss_improvements", strategy="loss",'
119+
' strategy="npoints", or strategy="cycle" is implemented.'
114120
)
115121

116122
def _ask_and_tell_based_on_loss_improvements(self, n):
@@ -173,6 +179,17 @@ def _ask_and_tell_based_on_npoints(self, n):
173179
points, loss_improvements = map(list, zip(*selected))
174180
return points, loss_improvements
175181

182+
def _ask_and_tell_based_on_cycle(self, n):
183+
points, loss_improvements = [], []
184+
for _ in range(n):
185+
index = next(self._cycle)
186+
point, loss_improvement = self.learners[index].ask(n=1)
187+
points.append((index, point[0]))
188+
loss_improvements.append(loss_improvement[0])
189+
self.tell_pending((index, point[0]))
190+
191+
return points, loss_improvements
192+
176193
def ask(self, n, tell_pending=True):
177194
"""Chose points for learners."""
178195
if n == 0:

adaptive/tests/test_balancing_learner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from adaptive.runner import simple
77

88

9+
strategies = ["loss", "loss_improvements", "npoints", "cycle"]
10+
11+
912
def test_balancing_learner_loss_cache():
1013
learner = Learner1D(lambda x: x, bounds=(-1, 1))
1114
learner.tell(-1, -1)
@@ -26,7 +29,7 @@ def test_balancing_learner_loss_cache():
2629
assert bl.loss(real=True) == real_loss
2730

2831

29-
@pytest.mark.parametrize("strategy", ["loss", "loss_improvements", "npoints"])
32+
@pytest.mark.parametrize("strategy", strategies)
3033
def test_distribute_first_points_over_learners(strategy):
3134
for initial_points in [0, 3]:
3235
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
@@ -41,7 +44,7 @@ def test_distribute_first_points_over_learners(strategy):
4144
assert len(set(i_learner)) == len(learners)
4245

4346

44-
@pytest.mark.parametrize("strategy", ["loss", "loss_improvements", "npoints"])
47+
@pytest.mark.parametrize("strategy", strategies)
4548
def test_ask_0(strategy):
4649
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
4750
learner = BalancingLearner(learners, strategy=strategy)
@@ -55,6 +58,7 @@ def test_ask_0(strategy):
5558
("loss", lambda l: l.loss() < 0.1),
5659
("loss_improvements", lambda l: l.loss() < 0.1),
5760
("npoints", lambda bl: all(l.npoints > 10 for l in bl.learners)),
61+
("cycle", lambda l: l.loss() < 0.1),
5862
],
5963
)
6064
def test_strategies(strategy, goal):

0 commit comments

Comments
 (0)