Skip to content

Commit

Permalink
assign type_map for all systems (#1033)
Browse files Browse the repository at this point in the history
Fix deepmodeling#1029.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Nov 16, 2022
1 parent 8c24997 commit c7e69cc
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion dpgen/generator/lib/run_calypso.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def gen_main(iter_index, jdata, mdata, caly_run_opt_list, gen_idx):
def analysis(iter_index, jdata, calypso_model_devi_path):
# Analysis

ms = dpdata.MultiSystems()
ms = dpdata.MultiSystems(type_map=jdata['type_map'])

cwd = os.getcwd()
iter_name = make_iter_name(iter_index)
Expand Down
4 changes: 2 additions & 2 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3660,9 +3660,9 @@ def post_fp_amber_diff(iter_index, jdata):
for ss in system_index :
sys_output = glob.glob(os.path.join(work_path, "task.%s.*"%ss))
sys_output.sort()
all_sys=dpdata.MultiSystems()
all_sys=dpdata.MultiSystems(type_map=jdata['type_map'])
for oo in sys_output :
sys=dpdata.MultiSystems().from_deepmd_npy(os.path.join(oo, 'dataset'))
sys=dpdata.MultiSystems(type_map=jdata['type_map']).from_deepmd_npy(os.path.join(oo, 'dataset'))
all_sys.append(sys)
sys_data_path = os.path.join(work_path, 'data.%s'%ss)
all_sys.to_deepmd_raw(sys_data_path)
Expand Down
18 changes: 10 additions & 8 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def get_multi_system(path, jdata):
system = get_system_cls(jdata)
system_paths = expand_sys_str(path)
systems = dpdata.MultiSystems(
*[system(s, fmt=('deepmd/npy' if "#" not in s else 'deepmd/hdf5')) for s in system_paths])
*[system(s, fmt=('deepmd/npy' if "#" not in s else 'deepmd/hdf5')) for s in system_paths],
type_map=jdata['type_map'],
)
return systems


Expand Down Expand Up @@ -115,7 +117,7 @@ def init_pick(iter_index, jdata, mdata):


def _init_dump_selected_frames(systems, labels, selc_idx, sys_data_path, jdata):
selc_systems = dpdata.MultiSystems()
selc_systems = dpdata.MultiSystems(type_map=jdata['type_map'])
for j in selc_idx:
sys_name, sys_id = labels[j]
selc_systems.append(systems[sys_name][sys_id])
Expand Down Expand Up @@ -214,12 +216,12 @@ def post_model_devi(iter_index, jdata, mdata):
f_trust_lo = jdata['model_devi_f_trust_lo']
f_trust_hi = jdata['model_devi_f_trust_hi']

sys_accurate = dpdata.MultiSystems()
sys_candinate = dpdata.MultiSystems()
sys_failed = dpdata.MultiSystems()
type_map = jdata.get("type_map", [])
sys_accurate = dpdata.MultiSystems(type_map=type_map)
sys_candinate = dpdata.MultiSystems(type_map=type_map)
sys_failed = dpdata.MultiSystems(type_map=type_map)

labeled = jdata.get("labeled", False)
type_map = jdata.get("type_map", [])
sys_entire = dpdata.MultiSystems(type_map = type_map).from_deepmd_npy(os.path.join(work_path, rest_data_name + ".old"), labeled=labeled)

detail_file_name = detail_file_name_prefix
Expand Down Expand Up @@ -270,7 +272,7 @@ def post_model_devi(iter_index, jdata, mdata):
(counter['candidate'], len(pick_idx), float(len(pick_idx))/counter['candidate']*100., len(rest_idx), float(len(rest_idx))/counter['candidate']*100.))

# dump the picked candinate data
picked_systems = dpdata.MultiSystems()
picked_systems = dpdata.MultiSystems(type_map = type_map)
for j in pick_idx:
sys_name, sys_id = labels[j]
picked_systems.append(sys_candinate[sys_name][sys_id])
Expand All @@ -280,7 +282,7 @@ def post_model_devi(iter_index, jdata, mdata):


# dump the rest data (not picked candinate data and failed data)
rest_systems = dpdata.MultiSystems()
rest_systems = dpdata.MultiSystems(type_map = type_map)
for j in rest_idx:
sys_name, sys_id = labels[j]
rest_systems.append(sys_candinate[sys_name][sys_id])
Expand Down
3 changes: 2 additions & 1 deletion tests/generator/test_post_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,9 @@ def setUp(self):
self.system_1 = list(ms.systems.values())[0]
with open (param_amber_file, 'r') as fp :
jdata = json.load (fp)
jdata['type_map'] = self.system_1.get_atom_names()
post_fp(0, jdata)
self.system_2 = list(dpdata.MultiSystems().from_deepmd_raw('iter.000000/02.fp/data.000').systems.values())[0]
self.system_2 = list(dpdata.MultiSystems(type_map = jdata['type_map']).from_deepmd_raw('iter.000000/02.fp/data.000').systems.values())[0]


if __name__ == '__main__':
Expand Down

0 comments on commit c7e69cc

Please sign in to comment.