Skip to content

Commit

Permalink
Add normalize part
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Sep 25, 2023
1 parent 36a90e1 commit 1c72709
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
13 changes: 8 additions & 5 deletions deepmd_pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,11 @@ def single_preprocess(self, batch, sid):
return batch
else:
batch['clean_type'] = clean_type
batch['clean_coord'] = clean_coord
if self.pbc:
_clean_coord = normalize_coord(clean_coord, region, nloc)
else:
_clean_coord = clean_coord.clone()
batch['clean_coord'] = _clean_coord
# add noise
for i in range(self.max_fail_num):
mask_num = 0
Expand Down Expand Up @@ -587,13 +591,13 @@ def single_preprocess(self, batch, sid):
)
else:
NotImplementedError(f"Unknown noise type {self.noise_type}!")
noised_coord = clean_coord.clone().detach()
noised_coord = _clean_coord.clone().detach()
noised_coord[coord_mask] += noise_on_coord
batch['coord_mask'] = torch.tensor(coord_mask,
dtype=torch.bool,
device=env.PREPROCESS_DEVICE)
else:
noised_coord = clean_coord
noised_coord = _clean_coord
batch['coord_mask'] = torch.tensor(np.zeros_like(coord_mask, dtype=np.bool),
dtype=torch.bool,
device=env.PREPROCESS_DEVICE)
Expand All @@ -611,7 +615,7 @@ def single_preprocess(self, batch, sid):
dtype=torch.bool,
device=env.PREPROCESS_DEVICE)
if self.pbc:
_coord = normalize_coord(noised_coord, region, nloc)
_coord = region.move_noised_coord_all_in_box(noised_coord, _clean_coord)
else:
_coord = noised_coord.clone()
batch['coord'] = _coord
Expand All @@ -626,7 +630,6 @@ def single_preprocess(self, batch, sid):
RuntimeError(f"Add noise times beyond max tries {self.max_fail_num}!")
continue
batch['atype'] = masked_type
batch['coord'] = noised_coord
batch['nlist'] = nlist
batch['nlist_loc'] = nlist_loc
batch['nlist_type'] = nlist_type
Expand Down
11 changes: 10 additions & 1 deletion deepmd_pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@ def get_face_distance(self):
"""Return face distinces to each surface of YZ, ZX, XY."""
return torch.stack([self._h2yz, self._h2zx, self._h2xy])


def move_noised_coord_all_in_box(self, noised_coord, clean_coord):
"""Ensure all noised atoms are inside region"""
tmp_coord_noised = noised_coord.clone()
tmp_coord_clean = clean_coord.clone()
inter_coord_noised = self.phys2inter(tmp_coord_noised)
inter_coord_clean = self.phys2inter(tmp_coord_clean)
inter_coord_noised_update = torch.where(abs(inter_coord_noised-0.50)<0.50, inter_coord_noised, 2*inter_coord_clean-inter_coord_noised)
coord_noised_update = self.inter2phys(inter_coord_noised_update)
return coord_noised_update

def normalize_coord(coord, region: Region3D, nloc: int):
"""Move outer atoms into region by mirror.
Expand Down

0 comments on commit 1c72709

Please sign in to comment.