Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support merge_traj for lammps #838

Merged
merged 12 commits into from
Aug 30, 2022
3 changes: 3 additions & 0 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def model_devi_lmp_args() -> List[Argument]:
doc_model_devi_perc_candi_v = 'See model_devi_adapt_trust_lo.'
doc_model_devi_f_avg_relative = 'Normalized the force model deviations by the RMS force magnitude along the trajectory. This key should not be used with use_relative.'
doc_model_devi_clean_traj = 'If type of model_devi_clean_traj is bool type then it denote whether to clean traj folders in MD since they are too large. If it is Int type, then the most recent n iterations of traj folders will be retained, others will be removed.'
doc_model_devi_merge_traj = 'If model_devi_merge_traj is set as True, only all.lammpstrj will be generated, instead of lots of small traj files.'
doc_model_devi_nopbc = 'Assume open boundary condition in MD simulations.'
doc_model_devi_activation_func = 'Set activation functions for models, length of the list should be the same as numb_models, and two elements in the list of string respectively assign activation functions to the embedding and fitting nets within each model. Backward compatibility: the orginal "list of String" format is still supported, where embedding and fitting nets of one model use the same activation function, and the length of the list should be the same as numb_models.'
doc_shuffle_poscar = 'Shuffle atoms of each frame before running simulations. The purpose is to sample the element occupation of alloys.'
Expand Down Expand Up @@ -194,6 +195,8 @@ def model_devi_lmp_args() -> List[Argument]:
doc=doc_model_devi_f_avg_relative),
Argument("model_devi_clean_traj", [
bool, int], optional=False, doc=doc_model_devi_clean_traj),
Argument("model_devi_merge_traj", [
bool], optional=False, doc=doc_model_devi_merge_traj),
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Argument("model_devi_nopbc", bool, optional=True, default=False,
doc=doc_model_devi_nopbc),
Argument("model_devi_activation_func", list, optional=True,
Expand Down
61 changes: 60 additions & 1 deletion dpgen/generator/lib/lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def make_lammps_input(ensemble,
ret+= "\n"
ret+= "thermo_style custom step temp pe ke etotal press vol lx ly lz xy xz yz\n"
ret+= "thermo ${THERMO_FREQ}\n"
ret+= "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n"
model_devi_merge_traj = jdata.get('model_devi_merge_traj', False)
if(model_devi_merge_traj is True):
ret+= "dump 1 all custom ${DUMP_FREQ} all.lammpstrj id type x y z fx fy fz\n"
else:
ret+= "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n"
ret+= "restart 10000 dpgen.restart\n"
ret+= "\n"
if pka_e is None :
Expand Down Expand Up @@ -167,6 +171,61 @@ def get_dumped_forces(
ret = np.array(ret)
return ret

def get_all_dumped_forces(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A UT for this function should be provided.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Chengqian-Zhang You can directly pull request to HuangJiameng:merge_traj branch, and then this PR will be updated.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wanghan-iapcm @AnguseZhang @njzjz The UT for this function has been provided.Please check.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no problems on this PR.

file_name):
with open(file_name) as fp:
lines = fp.read().split('\n')

ret = []
exist_natoms = False
exist_atoms = False

for idx,ii in enumerate(lines):

if 'ITEM: NUMBER OF ATOMS' in ii:
natoms = int(lines[idx+1])
exist_natoms = True

if 'ITEM: ATOMS' in ii:
keys = ii
keys = keys.replace('ITEM: ATOMS', '')
keys = keys.split()
idfx = keys.index('fx')
idfy = keys.index('fy')
idfz = keys.index('fz')
exist_atoms = True

single_traj = []
for jj in range(idx+1, idx+natoms+1):
words = lines[jj].split()
single_traj.append([ float(words[jj]) for jj in [idfx, idfy, idfz] ])
single_traj = np.array(single_traj)
ret.append(single_traj)

if exist_natoms is False:
raise RuntimeError('wrong dump file format, cannot find number of atoms', file_name)
if exist_atoms is False:
raise RuntimeError('wrong dump file format, cannot find dump keys', file_name)
return ret

def generate_single_traj(all_traj, traj_ind, single_traj):
with open(all_traj) as all_traj_fp:
lines = all_traj_fp.read().split('\n')
HuangJiameng marked this conversation as resolved.
Show resolved Hide resolved
single_traj_fp = open(single_traj, "w")

time_step = None
get_traj = False
for idx,ii in enumerate(lines):
if 'ITEM: TIMESTEP' in ii:
if(get_traj is True):
HuangJiameng marked this conversation as resolved.
Show resolved Hide resolved
break
time_step = int(lines[idx+1])
if(time_step == traj_ind):
single_traj_fp.write('ITEM: TIMESTEP\n')
single_traj_fp.write(str(time_step) + '\n')
get_traj = True
elif(get_traj is True):
single_traj_fp.write(ii + '\n')

if __name__ == '__main__':
ret = get_dumped_forces('40.lammpstrj')
Expand Down
32 changes: 23 additions & 9 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from dpgen.generator.lib.utils import record_iter
from dpgen.generator.lib.utils import log_task
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.lammps import make_lammps_input, get_dumped_forces
from dpgen.generator.lib.lammps import make_lammps_input, get_dumped_forces, get_all_dumped_forces, generate_single_traj
from dpgen.generator.lib.make_calypso import _make_model_devi_native_calypso,_make_model_devi_buffet
from dpgen.generator.lib.run_calypso import gen_structures,analysis,run_calypso_model_devi
from dpgen.generator.lib.parse_calypso import _parse_calypso_input,_parse_calypso_dis_mtx
Expand Down Expand Up @@ -1626,17 +1626,22 @@ def check_bad_box(conf_name,
raise RuntimeError('unknow key', key)
return is_bad


def _read_model_devi_file(
task_path : str,
model_devi_f_avg_relative : bool = False
model_devi_f_avg_relative : bool = False,
model_devi_merge_traj : bool = False
):
model_devi = np.loadtxt(os.path.join(task_path, 'model_devi.out'))
if model_devi_f_avg_relative :
trajs = glob.glob(os.path.join(task_path, 'traj', '*.lammpstrj'))
all_f = []
for ii in trajs:
all_f.append(get_dumped_forces(ii))
if(model_devi_merge_traj is True) :
all_traj = glob.glob(os.path.join(task_path, 'all.lammpstrj'))
all_f = get_all_dumped_forces(all_traj)
HuangJiameng marked this conversation as resolved.
Show resolved Hide resolved
else :
trajs = glob.glob(os.path.join(task_path, 'traj', '*.lammpstrj'))
all_f = []
for ii in trajs:
all_f.append(get_dumped_forces(ii))

all_f = np.array(all_f)
all_f = all_f.reshape([-1,3])
avg_f = np.sqrt(np.average(np.sum(np.square(all_f), axis = 1)))
Expand All @@ -1655,6 +1660,7 @@ def _select_by_model_devi_standard(
model_devi_engine : str,
model_devi_skip : int = 0,
model_devi_f_avg_relative : bool = False,
model_devi_merge_traj : bool = False,
detailed_report_make_fp : bool = True,
):
if model_devi_engine == 'calypso':
Expand All @@ -1675,7 +1681,7 @@ def _select_by_model_devi_standard(
for tt in modd_system_task :
with warnings.catch_warnings():
warnings.simplefilter("ignore")
all_conf = _read_model_devi_file(tt, model_devi_f_avg_relative)
all_conf = _read_model_devi_file(tt, model_devi_f_avg_relative, model_devi_merge_traj)

if all_conf.shape == (7,):
all_conf = all_conf.reshape(1,all_conf.shape[0])
Expand Down Expand Up @@ -1739,6 +1745,7 @@ def _select_by_model_devi_adaptive_trust_low(
perc_candi_v : float,
model_devi_skip : int = 0,
model_devi_f_avg_relative : bool = False,
model_devi_merge_traj : bool = False,
):
"""
modd_system_task model deviation tasks belonging to one system
Expand Down Expand Up @@ -1769,7 +1776,7 @@ def _select_by_model_devi_adaptive_trust_low(
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model_devi = np.loadtxt(os.path.join(tt, 'model_devi.out'))
model_devi = _read_model_devi_file(tt, model_devi_f_avg_relative)
model_devi = _read_model_devi_file(tt, model_devi_f_avg_relative, model_devi_merge_traj)
for ii in range(model_devi.shape[0]) :
if model_devi[ii][0] < model_devi_skip :
continue
Expand Down Expand Up @@ -1892,6 +1899,7 @@ def _make_fp_vasp_inner (modd_path,
cluster_cutoff = jdata.get('cluster_cutoff', None)
model_devi_adapt_trust_lo = jdata.get('model_devi_adapt_trust_lo', False)
model_devi_f_avg_relative = jdata.get('model_devi_f_avg_relative', False)
model_devi_merge_traj = jdata.get('model_devi_merge_traj', False)
# skip save *.out if detailed_report_make_fp is False, default is True
detailed_report_make_fp = jdata.get("detailed_report_make_fp", True)
# skip bad box criteria
Expand Down Expand Up @@ -1930,6 +1938,7 @@ def _trust_limitation_check(sys_idx, lim):
model_devi_engine,
model_devi_skip,
model_devi_f_avg_relative = model_devi_f_avg_relative,
model_devi_merge_traj = model_devi_merge_traj,
detailed_report_make_fp = detailed_report_make_fp,
)
else:
Expand All @@ -1944,6 +1953,7 @@ def _trust_limitation_check(sys_idx, lim):
v_trust_hi_sys, numb_candi_v, perc_candi_v,
model_devi_skip = model_devi_skip,
model_devi_f_avg_relative = model_devi_f_avg_relative,
model_devi_merge_traj = model_devi_merge_traj,
)
dlog.info("system {0:s} {1:9s} : f_trust_lo {2:6.3f} v_trust_lo {3:6.3f}".format(ss, 'adapted', f_trust_lo_ad, v_trust_lo_ad))
elif model_devi_engine == "amber":
Expand Down Expand Up @@ -2052,6 +2062,10 @@ def _trust_limitation_check(sys_idx, lim):
ss = os.path.basename(tt).split('.')[1]
conf_name = os.path.join(tt, "traj")
if model_devi_engine == "lammps":
if(model_devi_merge_traj is True):
all_traj = os.path.join(tt, 'all.lammpstrj')
single_traj = os.path.join(conf_name, str(ii) + '.lammpstrj')
generate_single_traj(all_traj, int(str(ii)), single_traj)
conf_name = os.path.join(conf_name, str(ii) + '.lammpstrj')
ffmt = 'lammps/dump'
elif model_devi_engine == "gromacs":
Expand Down