Skip to content

Commit

Permalink
Merge pull request #136 from mir-group/ase_interface
Browse files Browse the repository at this point in the history
Ase interface - close #133
  • Loading branch information
nw13slx authored Feb 13, 2020
2 parents 4685126 + 6615cfa commit 6f0dd93
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 44 deletions.
22 changes: 12 additions & 10 deletions flare/ase/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,32 @@ def __init__(self, gp_model, mgp_model=None, par=False, use_mapping=False):
self.par = par
self.results = {}

def get_property(self, atoms, property_name):
if property_name not in self.results.keys():
def get_property(self, name, atoms=None, allow_calculation=True):
if name not in self.results.keys():
if not allow_calculation:
return None
self.calculate(atoms)
return self.results[property_name]
return self.results[name]


def get_potential_energy(self, atoms=None, force_consistent=False):
if self.use_mapping:
print('MGP energy mapping not implemented, temporarily set to 0')
return self.get_property(atoms, 'energy')


def get_forces(self, atoms):
return self.get_property(atoms, 'forces')
return self.get_property('energy', atoms)
def get_forces(self, atoms):
return self.get_property('forces', atoms)


def get_stress(self, atoms):
if not self.use_mapping:
raise NotImplementedError("Stress is only supported in MGP")
return self.get_property(atoms, 'stress')
return self.get_property('stress', atoms)


def get_uncertainties(self, atoms):
return self.get_property(atoms, 'stds')
return self.get_property('stds', atoms)


def calculate(self, atoms):
Expand Down
2 changes: 1 addition & 1 deletion flare/ase/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def write_header_info(self):
self.logfile.write('\nhyperparameter optimization algorithm: ' +
gp_model.algo)
self.logfile.write('\nuncertainty tolerance: {} times noise'.format(
str(self.dyn.std_tolerance)))
str(self.dyn.std_tolerance_factor)))
self.logfile.write('\ntimestep (ps): {}'.format(self.dyn.dt/1000))
self.logfile.write('\nnumber of frames: {}'.format(0))
self.logfile.write('\nnumber of atoms: {}'.format(
Expand Down
46 changes: 21 additions & 25 deletions flare/ase/otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'''
import os
import sys
import inspect
from copy import deepcopy

from flare.struc import Structure
Expand Down Expand Up @@ -60,21 +61,14 @@ def __init__(self,
use_mapping: bool=False, non_mapping_steps: list=[],
l_bound: float=None, two_d: bool=False):

self.dft_calc = dft_calc
# get all arguments as attributes
arg_dict = inspect.getargvalues(inspect.currentframe())[3]
del arg_dict['self']
self.__dict__.update(arg_dict)

if dft_count is None:
self.dft_count = 0
else:
self.dft_count = dft_count
self.std_tolerance = std_tolerance_factor
self.noa = len(self.atoms.positions)
self.max_atoms_added = max_atoms_added
self.freeze_hyps = freeze_hyps

# params for mapped force field
self.use_mapping = use_mapping
self.non_mapping_steps = non_mapping_steps
self.l_bound = l_bound
self.two_d = two_d

# initialize local energies
if calculate_energy:
Expand All @@ -85,11 +79,6 @@ def __init__(self,
# initialize otf
if init_atoms is None:
self.init_atoms = [int(n) for n in range(self.noa)]
else:
self.init_atoms = init_atoms

# restart mode
self.restart_from = restart_from

def otf_run(self, steps, rescale_temp=[], rescale_steps=[]):
"""
Expand All @@ -109,6 +98,12 @@ def otf_run(self, steps, rescale_temp=[], rescale_steps=[]):
rescale_steps = [100, 200]
"""

# observers
for i, obs in enumerate(self.observers):
if obs[0].__class__.__name__ == "OTFLogger":
self.logger_ind = i
break

# restart from previous OTF training
if self.restart_from is not None:
self.restart()
Expand All @@ -132,9 +127,10 @@ def otf_run(self, steps, rescale_temp=[], rescale_steps=[]):
# train calculator
for atom in self.init_atoms:
# the observers[0][0] is the logger
self.observers[0][0].add_atom_info(atom, self.stds[atom])
self.observers[self.logger_ind][0].add_atom_info(atom,
self.stds[atom])
self.train()
self.observers[0][0].write_wall_time()
self.observers[self.logger_ind][0].write_wall_time()

if self.md_engine == 'NPT':
if not self.initialized:
Expand Down Expand Up @@ -173,7 +169,7 @@ def otf_run(self, steps, rescale_temp=[], rescale_steps=[]):
curr_struc.stds = np.copy(self.stds)
noise = self.atoms.calc.gp_model.hyps[-1]
self.std_in_bound, self.target_atoms = is_std_in_bound(\
noise, self.std_tolerance, curr_struc, self.max_atoms_added)
noise, self.std_tolerance_factor, curr_struc, self.max_atoms_added)

print('std in bound:', self.std_in_bound, self.target_atoms)
#self.is_std_in_bound([])
Expand All @@ -187,7 +183,7 @@ def otf_run(self, steps, rescale_temp=[], rescale_steps=[]):
print('updating gp')
self.update_GP(dft_forces)

self.observers[0][0].run_complete()
self.observers[self.logger_ind][0].run_complete()


def call_DFT(self):
Expand Down Expand Up @@ -229,22 +225,22 @@ def update_GP(self, dft_forces):

# write added atom to the log file,
# refer to ase.optimize.optimize.Dynamics
self.observers[0][0].add_atom_info(target_atom,
self.observers[self.logger_ind][0].add_atom_info(target_atom,
self.stds[target_atom])

#self.is_std_in_bound(atom_list)
atom_count += 1

self.train()
self.observers[0][0].added_atoms_dat.write('\n')
self.observers[0][0].write_wall_time()
self.observers[self.logger_ind][0].added_atoms_dat.write('\n')
self.observers[self.logger_ind][0].write_wall_time()

def train(self, output=None, skip=False):
calc = self.atoms.calc
if (self.dft_count-1) < self.freeze_hyps:
#TODO: add other args to train()
calc.gp_model.train(output=output)
self.observers[0][0].write_hyps(calc.gp_model.hyp_labels,
self.observers[self.logger_ind][0].write_hyps(calc.gp_model.hyp_labels,
calc.gp_model.hyps, calc.gp_model.likelihood,
calc.gp_model.likelihood_gradient)
else:
Expand Down
8 changes: 4 additions & 4 deletions flare/ase/otf_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(self, atoms, timestep=None, trajectory=None, dt=None,
**kwargs):

VelocityVerlet.__init__(self, atoms, timestep, trajectory, dt=dt)

OTF.__init__(self, **kwargs)

self.md_engine = 'VelocityVerlet'
Expand Down Expand Up @@ -120,9 +119,10 @@ class OTF_Langevin(Langevin, OTF):
"""

def __init__(self, atoms, timestep=None, temperature=None, friction=None,
trajectory=None, **kwargs):
fixcm=True, trajectory=None, **kwargs):

Langevin.__init__(self, atoms, timestep, temperature, friction)
Langevin.__init__(self, atoms, timestep, temperature, friction,
fixcm, trajectory)

OTF.__init__(self, **kwargs)

Expand Down Expand Up @@ -190,7 +190,7 @@ def otf_md(md_engine: str, atoms, md_params: dict, otf_params: dict):

elif md_engine == 'Langevin':
return OTF_Langevin(atoms, timestep, md['temperature'],
md['friction'], trajectory, **otf_params)
md['friction'], md['fixcm'], trajectory, **otf_params)

else:
raise NotImplementedError(md_engine+' is not implemented')
Expand Down
10 changes: 6 additions & 4 deletions tests/test_ase_setup/test_otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def otf_md_test(md_engine):

# ----------- create otf object -----------
# set up OTF MD engine
md_params = {'timestep': 1 * units.fs, 'trajectory': None, 'dt': None,
md_params = {'timestep': 1 * units.fs, 'trajectory': 'otf_md.traj', 'dt': None,
'externalstress': 0, 'ttime': 25, 'pfactor': 3375,
'mask': None, 'temperature': 500, 'taut': 1, 'taup': 1,
'pressure': 0, 'compressibility': 0, 'fixcm': 1,
Expand All @@ -50,10 +50,12 @@ def otf_md_test(md_engine):

test_otf = otf_md(md_engine, super_cell, md_params, otf_params)

print(test_otf.observers)

# set up logger
test_otf.attach(OTFLogger(test_otf, super_cell,
logfile=md_engine+'.log', mode="w", data_in_logfile=True),
interval=1)
otf_logger = OTFLogger(test_otf, super_cell,
logfile=md_engine+'.log', mode="w", data_in_logfile=True)
test_otf.attach(otf_logger, interval=1)

# run otf
number_of_steps = 3
Expand Down

0 comments on commit 6f0dd93

Please sign in to comment.