Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added & unit tested store_dft_output #123

Merged
merged 2 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions flare/otf.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import sys
import numpy as np
import datetime
import time
from typing import List
import copy
import multiprocessing as mp
import subprocess
from shutil import copyfile
from typing import List, Tuple, Union
from datetime import datetime

import flare.predict as predict
from flare import struc, gp, env, md
from flare.dft_interface import dft_software
from flare.output import Output
import flare.predict as predict
from flare.util import is_std_in_bound


class OTF:
"""Trains a Gaussian process force field on the fly.
"""Trains a Gaussian process force field on the fly during
molecular dynamics.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly more succinct: "...on the fly during molecular dynamics."

Args:
dft_input (str): Input file.
Expand Down Expand Up @@ -59,8 +62,18 @@ class OTF:
calculations. Defaults to None.
mpi (str, optional): Determines how mpi is called. Defaults to
"srun".
dft_kwargs ([type], optional): Additional DFT arguments. Defaults
to None.
dft_kwargs ([type], optional): Additional arguments which are
passed when DFT is called; keyword arguments vary based on the
program (e.g. ESPRESSO vs. VASP). Defaults to None.
store_dft_output (Tuple[Union[str,List[str]],str], optional):
After DFT calculations are called, copy the file or files
specified in the first element of the tuple to a directory
specified as the second element of the tuple.
Useful when DFT calculations are expensive and want to be kept
for later use. The first element of the tuple can either be a
single file name, or a list of several. Copied files will be
prepended with the date and time with the format
'Year.Month.Day:Hour:Minute:Second:'.
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the first element of the tuple a list of strings? Or just a string? It looks like you've set it up to work with both, but it might be helpful to explicitly state that in the docstring. (You make this very clear in the constructor by giving the type as "Tuple[Union[str,List[str]],str]". Consider doing the same here, so that it shows up on RTD.)

def __init__(self, dft_input: str, dt: float, number_of_steps: int,
gp: gp.GaussianProcess, dft_loc: str,
Expand All @@ -72,7 +85,8 @@ def __init__(self, dft_input: str, dt: float, number_of_steps: int,
rescale_steps: List[int] = [], rescale_temps: List[int] = [],
dft_softwarename: str = "qe",
no_cpus: int = 1, npool: int = None, mpi: str = "srun",
dft_kwargs=None):
dft_kwargs=None,
store_dft_output: Tuple[Union[str,List[str]],str] = None):

self.dft_input = dft_input
self.dt = dt
Expand Down Expand Up @@ -140,9 +154,16 @@ def __init__(self, dft_input: str, dt: float, number_of_steps: int,
self.mpi = mpi

self.dft_kwargs = dft_kwargs
self.store_dft_output = store_dft_output

def run(self):
"""Performs an on-the-fly training run."""
"""
Performs an on-the-fly training run.

If OTF has store_dft_output set, then the specified DFT files will
be copied with the current date and time prepended in the format
'Year.Month.Day:Hour:Minute:Second:'.
"""

self.output.write_header(self.gp.cutoffs, self.gp.kernel_name,
self.gp.hyps, self.gp.algo,
Expand Down Expand Up @@ -212,6 +233,20 @@ def run(self):
if (self.dft_count-1) < self.freeze_hyps:
self.train_gp()

# Store DFT outputs in another folder if desired
# specified in self.store_dft_output
if self.store_dft_output is not None:
dest = self.store_dft_output[1]
target_files = self.store_dft_output[0]
now = datetime.now()
dt_string = now.strftime("%Y.%m.%d:%H:%M:%S:")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a good way of labeling the output files. We should write the labeling convention in the docstring for the "run" method.

if isinstance(target_files, str):
to_copy = [target_files]
else:
to_copy = target_files
for file in to_copy:
copyfile(file, dest+'/'+dt_string+file)

# write gp forces
if counter >= self.skip and not self.dft_step:
self.update_temperature(new_pos)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_OTF_qe.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ def test_otf_h2():
otf = OTF(qe_input, dt, number_of_steps, gp, dft_loc,
std_tolerance_factor, init_atoms=[0],
calculate_energy=True, max_atoms_added=1,
output_name='h2_otf_qe')
output_name='h2_otf_qe',
store_dft_output=(['pwscf.out', 'pwscf.in'], '.'))

otf.run()
os.system('mkdir test_outputs')
os.system('mv h2_otf_qe* test_outputs')
cleanup_espresso_run()
cleanup_espresso_run("{*pwscf.out,*pwscf.in}")

@pytest.mark.skipif(not os.environ.get('PWSCF_COMMAND',
False), reason='PWSCF_COMMAND not found '
Expand Down
5 changes: 3 additions & 2 deletions tests/test_OTF_qe_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def test_otf_h2():
calculate_energy=True, max_atoms_added=1,
no_cpus=2, par=True,
mpi="mpi",
output_name='h2_otf_qe_par')
output_name='h2_otf_qe_par',
store_dft_output=('pwscf.out', '.'))

otf.run()
os.system('mkdir test_outputs')
os.system('mv h2_otf_qe_par* test_outputs')
cleanup_espresso_run()
cleanup_espresso_run("*pwscf.out")

@pytest.mark.skipif(not os.environ.get('PWSCF_COMMAND',
False), reason='PWSCF_COMMAND not found '
Expand Down