Skip to content

Commit

Permalink
Update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
stevetorr committed Oct 14, 2019
1 parent 81f9ea5 commit f9d03cc
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions flare/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,16 @@ def force_list_to_np(forces: list):
return forces_np

def train(self, output=None, opt_param_override: dict = None):
"""Train Gaussian Process model on training data. Tunes the \
"""
Train Gaussian Process model on training data. Tunes the \
hyperparameters to maximize the likelihood, then computes L and alpha \
(related to the covariance matrix of the training set)."""
(related to the covariance matrix of the training set).
:param output: Output to write to
:param opt_param_override: Dictionary of parameters to override
instace's optimzation parameters.
:return:
"""
x_0 = self.hyps

args = (self.training_data, self.training_labels_np,
Expand All @@ -141,8 +147,8 @@ def train(self, output=None, opt_param_override: dict = None):

grad_tol = opt_param_temp.get('grad_tol', 1e-4)
x_tol = opt_param_temp.get('x_tol', 1e-5)
line_steps = opt_param_temp.get('maxls', 20)
max_iter = opt_param_temp.get('maxiter', self.maxiter)
line_steps = opt_param_temp.get('max_ls', 20)
max_iter = opt_param_temp.get('max_iter', self.maxiter)
algo = opt_param_temp.get('algorithm', self.algo)
disp = opt_param_temp.get('disp', False)

Expand All @@ -164,7 +170,6 @@ def train(self, output=None, opt_param_override: dict = None):
print("Warning! Algorithm for L-BFGS-B failed. Changing to "
"BFGS for remainder of run.")
self.opt_params['algorithm'] = 'BFGS'
algo = 'BFGS'

if opt_param_temp.get('custom_bounds', None) is not None:
res = minimize(get_neg_like_grad, x_0, args,
Expand Down

0 comments on commit f9d03cc

Please sign in to comment.