Skip to content

Commit f4339da

Browse files
authored
Merge pull request #997 from rhayes777/feature/parallel_initializer
feature/parallel initializer
2 parents cea43bd + d296a74 commit f4339da

File tree

8 files changed

+236
-177
lines changed

8 files changed

+236
-177
lines changed

autofit/non_linear/initializer.py

+47-22
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import os
44
import random
55
from abc import ABC, abstractmethod
6-
from typing import Dict, Tuple, List
6+
from typing import Dict, Tuple, List, Optional
77

88
import numpy as np
99

1010
from autofit import exc
11+
from autofit.non_linear.paths.abstract import AbstractPaths
1112
from autofit.mapper.prior.abstract import Prior
1213
from autofit.mapper.prior_model.abstract import AbstractPriorModel
14+
from autofit.non_linear.parallel import SneakyPool
1315

1416
logger = logging.getLogger(__name__)
1517

@@ -23,13 +25,28 @@ class AbstractInitializer(ABC):
2325
def _generate_unit_parameter_list(self, model):
2426
pass
2527

28+
@staticmethod
29+
def figure_of_metric(args) -> Optional[float]:
30+
fitness, parameter_list = args
31+
try:
32+
figure_of_merit = fitness(parameters=parameter_list)
33+
34+
if np.isnan(figure_of_merit) or figure_of_merit < -1e98:
35+
return None
36+
37+
return figure_of_merit
38+
except exc.FitException:
39+
return None
40+
2641
def samples_from_model(
2742
self,
2843
total_points: int,
2944
model: AbstractPriorModel,
3045
fitness,
46+
paths: AbstractPaths,
3147
use_prior_medians: bool = False,
3248
test_mode_samples: bool = True,
49+
n_cores: int = 1,
3350
):
3451
"""
3552
Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform
@@ -55,31 +72,39 @@ def samples_from_model(
5572
parameter_lists = []
5673
figures_of_merit_list = []
5774

58-
point_index = 0
59-
60-
while point_index < total_points:
61-
if not use_prior_medians:
62-
unit_parameter_list = self._generate_unit_parameter_list(model)
63-
else:
64-
unit_parameter_list = [0.5] * model.prior_count
65-
66-
parameter_list = model.vector_from_unit_vector(
67-
unit_vector=unit_parameter_list
68-
)
75+
sneaky_pool = SneakyPool(n_cores, fitness, paths)
6976

70-
try:
71-
figure_of_merit = fitness(parameters=parameter_list)
77+
while len(figures_of_merit_list) < total_points:
78+
remaining_points = total_points - len(figures_of_merit_list)
79+
batch_size = min(remaining_points, n_cores)
80+
parameter_lists_ = []
81+
unit_parameter_lists_ = []
7282

73-
if np.isnan(figure_of_merit) or figure_of_merit < -1e98:
74-
raise exc.FitException
83+
for _ in range(batch_size):
84+
if not use_prior_medians:
85+
unit_parameter_list = self._generate_unit_parameter_list(model)
86+
else:
87+
unit_parameter_list = [0.5] * model.prior_count
7588

76-
unit_parameter_lists.append(unit_parameter_list)
77-
parameter_lists.append(parameter_list)
78-
figures_of_merit_list.append(figure_of_merit)
79-
point_index += 1
80-
except exc.FitException:
81-
pass
89+
parameter_list = model.vector_from_unit_vector(
90+
unit_vector=unit_parameter_list
91+
)
8292

93+
parameter_lists_.append(parameter_list)
94+
unit_parameter_lists_.append(unit_parameter_list)
95+
96+
for figure_of_merit, unit_parameter_list, parameter_list in zip(
97+
sneaky_pool.map(
98+
self.figure_of_metric,
99+
[(fitness, parameter_list) for parameter_list in parameter_lists_],
100+
),
101+
unit_parameter_lists_,
102+
parameter_lists_,
103+
):
104+
if figure_of_merit is not None:
105+
unit_parameter_lists.append(unit_parameter_list)
106+
parameter_lists.append(parameter_list)
107+
figures_of_merit_list.append(figure_of_merit)
83108

84109
if total_points > 1 and np.allclose(
85110
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]

autofit/non_linear/search/mcmc/emcee/search.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _fit(self, model: AbstractPriorModel, analysis):
105105
analysis=analysis,
106106
paths=self.paths,
107107
fom_is_log_likelihood=False,
108-
resample_figure_of_merit=-np.inf
108+
resample_figure_of_merit=-np.inf,
109109
)
110110

111111
pool = self.make_sneaky_pool(fitness)
@@ -147,6 +147,8 @@ def _fit(self, model: AbstractPriorModel, analysis):
147147
total_points=search_internal.nwalkers,
148148
model=model,
149149
fitness=fitness,
150+
paths=self.paths,
151+
n_cores=self.number_of_cores,
150152
)
151153

152154
state = np.zeros(shape=(search_internal.nwalkers, model.prior_count))
@@ -184,17 +186,19 @@ def _fit(self, model: AbstractPriorModel, analysis):
184186
samples = self.samples_from(model=model, search_internal=search_internal)
185187

186188
if self.auto_correlation_settings.check_for_convergence:
187-
if search_internal.iteration > self.auto_correlation_settings.check_size:
189+
if (
190+
search_internal.iteration
191+
> self.auto_correlation_settings.check_size
192+
):
188193
if samples.converged:
189194
iterations_remaining = 0
190195

191196
if iterations_remaining > 0:
192-
193197
self.perform_update(
194198
model=model,
195199
analysis=analysis,
196200
search_internal=search_internal,
197-
during_analysis=True
201+
during_analysis=True,
198202
)
199203

200204
return search_internal
@@ -214,7 +218,6 @@ def output_search_internal(self, search_internal):
214218
pass
215219

216220
def samples_info_from(self, search_internal=None):
217-
218221
search_internal = search_internal or self.backend
219222

220223
auto_correlations = self.auto_correlations_from(search_internal=search_internal)
@@ -225,7 +228,7 @@ def samples_info_from(self, search_internal=None):
225228
"change_threshold": auto_correlations.change_threshold,
226229
"total_walkers": len(search_internal.get_chain()[0, :, 0]),
227230
"total_steps": len(search_internal.get_log_prob()),
228-
"time": self.timer.time if self.timer else None
231+
"time": self.timer.time if self.timer else None,
229232
}
230233

231234
def samples_via_internal_from(self, model, search_internal=None):
@@ -247,14 +250,14 @@ def samples_via_internal_from(self, model, search_internal=None):
247250
search_internal = search_internal or self.backend
248251

249252
if os.environ.get("PYAUTOFIT_TEST_MODE") == "1":
250-
251253
samples_after_burn_in = search_internal.get_chain(
252-
discard=5, thin=5, flat=True
253-
)
254+
discard=5, thin=5, flat=True
255+
)
254256

255257
else:
256-
257-
auto_correlations = self.auto_correlations_from(search_internal=search_internal)
258+
auto_correlations = self.auto_correlations_from(
259+
search_internal=search_internal
260+
)
258261

259262
discard = int(3.0 * np.max(auto_correlations.times))
260263
thin = int(np.max(auto_correlations.times) / 2.0)
@@ -292,11 +295,12 @@ def samples_via_internal_from(self, model, search_internal=None):
292295
sample_list=sample_list,
293296
samples_info=self.samples_info_from(search_internal=search_internal),
294297
auto_correlation_settings=self.auto_correlation_settings,
295-
auto_correlations=self.auto_correlations_from(search_internal=search_internal),
298+
auto_correlations=self.auto_correlations_from(
299+
search_internal=search_internal
300+
),
296301
)
297302

298303
def auto_correlations_from(self, search_internal=None):
299-
300304
search_internal = search_internal or self.backend
301305

302306
times = search_internal.get_autocorr_time(tol=0)

0 commit comments

Comments (0)