4
4
from collections .abc import Iterable
5
5
from contextlib import suppress
6
6
from functools import partial
7
+ import itertools
7
8
from operator import itemgetter
8
9
9
10
import numpy as np
@@ -50,10 +51,11 @@ class BalancingLearner(BaseLearner):
50
51
function : callable
51
52
A function that calls the functions of the underlying learners.
52
53
Its signature is ``function(learner_index, point)``.
53
- strategy : 'loss_improvements' (default), 'loss', or 'npoints'
54
+ strategy : 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'.
54
55
The points that the `BalancingLearner` choses can be either based on:
55
56
the best 'loss_improvements', the smallest total 'loss' of the
56
- child learners, or the number of points per learner, using 'npoints'.
57
+ child learners, the number of points per learner, using 'npoints',
58
+ or by cycling through the learners one by one using 'cycle'.
57
59
One can dynamically change the strategy while the simulation is
58
60
running by changing the ``learner.strategy`` attribute.
59
61
@@ -90,10 +92,11 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
90
92
91
93
@property
92
94
def strategy (self ):
93
- """Can be either 'loss_improvements' (default), 'loss', or 'npoints'
94
- The points that the `BalancingLearner` choses can be either based on:
95
- the best 'loss_improvements', the smallest total 'loss' of the
96
- child learners, or the number of points per learner, using 'npoints'.
95
+ """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
96
+ 'cycle'. The points that the `BalancingLearner` choses can be either
97
+ based on: the best 'loss_improvements', the smallest total 'loss' of
98
+ the child learners, the number of points per learner, using 'npoints',
99
+ or by going through all learners one by one using 'cycle'.
97
100
One can dynamically change the strategy while the simulation is
98
101
running by changing the ``learner.strategy`` attribute."""
99
102
return self ._strategy
@@ -107,10 +110,13 @@ def strategy(self, strategy):
107
110
self ._ask_and_tell = self ._ask_and_tell_based_on_loss
108
111
elif strategy == "npoints" :
109
112
self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
113
+ elif strategy == "cycle" :
114
+ self ._ask_and_tell = self ._ask_and_tell_based_on_cycle
115
+ self ._cycle = itertools .cycle (range (len (self .learners )))
110
116
else :
111
117
raise ValueError (
112
- 'Only strategy="loss_improvements", strategy="loss", or '
113
- ' strategy="npoints" is implemented.'
118
+ 'Only strategy="loss_improvements", strategy="loss",'
119
+ ' strategy="npoints", or strategy="cycle" is implemented.'
114
120
)
115
121
116
122
def _ask_and_tell_based_on_loss_improvements (self , n ):
@@ -173,6 +179,17 @@ def _ask_and_tell_based_on_npoints(self, n):
173
179
points , loss_improvements = map (list , zip (* selected ))
174
180
return points , loss_improvements
175
181
182
+ def _ask_and_tell_based_on_cycle (self , n ):
183
+ points , loss_improvements = [], []
184
+ for _ in range (n ):
185
+ index = next (self ._cycle )
186
+ point , loss_improvement = self .learners [index ].ask (n = 1 )
187
+ points .append ((index , point [0 ]))
188
+ loss_improvements .append (loss_improvement [0 ])
189
+ self .tell_pending ((index , point [0 ]))
190
+
191
+ return points , loss_improvements
192
+
176
193
def ask (self , n , tell_pending = True ):
177
194
"""Chose points for learners."""
178
195
if n == 0 :
0 commit comments