-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #386 from UCL-CCS/grid_search
Grid search
- Loading branch information
Showing
54 changed files
with
2,261 additions
and
269 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import os | ||
#from string import Template | ||
# from string import Template | ||
from jinja2 import Template | ||
import logging | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
"""A grid sampler | ||
Useful for e.g. hyperparameter search. The "vary" dict contains the values | ||
that must be considered per (hyper)parameter, for instance: | ||
vary = {"x1": [0.0, 0.5, 0.1], | ||
"x2 = [1, 3], | ||
"x3" = [True, False]} | ||
The sampler will create a tensor grid using all specified 1D parameter | ||
values. | ||
""" | ||
|
||
__author__ = "Wouter Edeling" | ||
__copyright__ = """ | ||
Copyright 2018 Robin A. Richardson, David W. Wright | ||
This file is part of EasyVVUQ | ||
EasyVVUQ is free software: you can redistribute it and/or modify | ||
it under the terms of the Lesser GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
EasyVVUQ is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
Lesser GNU General Public License for more details. | ||
You should have received a copy of the Lesser GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
""" | ||
__license__ = "LGPL" | ||
|
||
from itertools import product | ||
import numpy as np | ||
from .base import BaseSamplingElement # , Vary | ||
|
||
|
||
class Grid_Sampler(BaseSamplingElement, sampler_name="grid_sampler"): | ||
|
||
def __init__(self, vary, count=0): | ||
""" | ||
Initialize the grid sampler. | ||
Parameters | ||
---------- | ||
vary : dict, or list of dicts | ||
A dictionary containing all 1D values for each parameter. For instance | ||
vary = {"x1": [0.0, 0.5. 1.0], "x2": [True, False]}. This will | ||
create a 2D tensor product of all (x1, x2) parameter combinations. | ||
If a list of vary dicts is specified, each vary dict will be treated | ||
independently to generate points. These dicts do not have to contain | ||
the same parameters. The tensor product points are stored in the | ||
'points' list, with one tensor product per vary dict. | ||
count : int, optional | ||
Internal counter used to count the number of samples that have | ||
been executed. The default is 0. | ||
Returns | ||
------- | ||
None. | ||
""" | ||
# allways add vary to list, even if only a single dict is specified | ||
if not isinstance(vary, list): | ||
vary = [vary] | ||
|
||
self.vary = vary | ||
self.count = count | ||
self.points = [] | ||
|
||
# make sure all parameters are stored in a list or array, even | ||
# if they have only a single value | ||
for _vary in vary: | ||
for param in _vary.keys(): | ||
if not isinstance(_vary[param], list) and not isinstance(_vary[param], np.ndarray): | ||
vary[param] = [vary[param]] | ||
|
||
# use dtype=object to allow for multiple different type (float, boolean etc) | ||
self.points.append(np.array(list(product(*list(_vary.values()))), dtype=object)) | ||
|
||
# the cumulative sizes of all ensembles generated by the vary dicts | ||
self.cumul_sizes = np.cumsum([points.shape[0] for points in self.points]) | ||
# add a zero to the beginning (necessary in __next__ subroutine) | ||
self.cumul_sizes = np.insert(self.cumul_sizes, 0, 0) | ||
|
||
def is_finite(self): | ||
return True | ||
|
||
def n_samples(self): | ||
"""Returns the number of samples in this sampler. | ||
""" | ||
# return self.points.shape[0] | ||
return self.cumul_sizes[-1] | ||
|
||
def get_param_names(self): | ||
""" | ||
Get the names of all parameters that were varied. | ||
Returns | ||
------- | ||
param_names : list | ||
List of parameter names. | ||
""" | ||
param_names = [] | ||
for _vary in self.vary: | ||
for name in _vary.keys(): | ||
if not name in param_names: | ||
param_names.append(name) | ||
return param_names | ||
|
||
def __next__(self): | ||
""" | ||
Return the next sample from the input distributions. | ||
Raises | ||
------ | ||
StopIteration | ||
Stop iteration when count >= n_samples. | ||
Returns | ||
------- | ||
run_dict : dict | ||
A dictionary with the random input samples, e.g. | ||
{'x1': 0.5, 'x2': False}. | ||
""" | ||
if self.count < self.n_samples(): | ||
vary_idx = np.where(self.count < self.cumul_sizes[1:])[0][0] | ||
run_dict = {} | ||
i_par = 0 | ||
for param_name in self.vary[vary_idx].keys(): | ||
sample_idx = self.count - self.cumul_sizes[vary_idx] | ||
run_dict[param_name] = self.points[vary_idx][sample_idx][i_par] | ||
i_par += 1 | ||
self.count += 1 | ||
return run_dict | ||
else: | ||
raise StopIteration |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.