(fix) Make bias statistics complete for all elements #4495

Closed · wants to merge 3 commits
12 changes: 12 additions & 0 deletions deepmd/pt/utils/dataset.py
@@ -40,6 +40,18 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):
"""Mapping element types to frame indexes"""
element_to_frames = {element: [] for element in range(self._ntypes)}
for frame_idx in range(len(self)):
frame_data = self._data_system.get_item_torch(frame_idx)

elements = frame_data["atype"]
for element in set(elements):
if len(element_to_frames[element]) < 10:
element_to_frames[element].append(frame_idx)
return element_to_frames

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
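The helper added above caps the mapping at 10 frames per element, so even rare element types keep a handful of representative frames without scanning every frame in the system. A minimal standalone sketch of the same logic, using a plain list of frames and a hypothetical `build_element_to_frames` function in place of the dataset class:

```python
import numpy as np

def build_element_to_frames(frames, ntypes, max_frames_per_element=10):
    # For every element type, record up to `max_frames_per_element`
    # indices of frames in which that type occurs.
    element_to_frames = {element: [] for element in range(ntypes)}
    for frame_idx, frame_data in enumerate(frames):
        for element in set(frame_data["atype"].tolist()):
            if len(element_to_frames[element]) < max_frames_per_element:
                element_to_frames[element].append(frame_idx)
    return element_to_frames

# Two toy frames: frame 0 contains types 0 and 1, frame 1 only type 0.
frames = [
    {"atype": np.array([0, 1, 0])},
    {"atype": np.array([0, 0, 0])},
]
print(build_element_to_frames(frames, ntypes=3))
# {0: [0, 1], 1: [0], 2: []}
```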
40 changes: 40 additions & 0 deletions deepmd/pt/utils/stat.py
@@ -82,6 +82,46 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)

all_elements = set()
if datasets and hasattr(datasets[0], "element_to_frames"):
all_elements.update(datasets[0].element_to_frames.keys())

collected_elements = set()
for sys_stat in lst:
if "atype" in sys_stat:
collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
for i, dataset in enumerate(datasets):
if hasattr(dataset, "element_to_frames"):
frame_indices = dataset.element_to_frames.get(
missing_element, []
)
for frame_idx in frame_indices:
if len(lst[i]["atype"]) >= nbatches:
break
                    frame_data = dataset[frame_idx]
                    for key in frame_data:
                        if key not in lst[i]:
                            lst[i][key] = []
                        elif isinstance(lst[i][key], torch.Tensor):
                            # batches for this key were already concatenated into a
                            # tensor above; wrap it back in a list before appending
                            lst[i][key] = [lst[i][key]]
                        lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
if "atype" in sys_stat:
collected_elements.update(
np.unique(sys_stat["atype"].cpu().numpy())
)

for sys_stat in lst:
for key in sys_stat:
if isinstance(sys_stat[key], list) and isinstance(
sys_stat[key][0], torch.Tensor
):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)

return lst


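The block added to `make_stat_input` first determines which element types never appeared in the sampled batches, then backfills frames for them via `element_to_frames`. A minimal sketch of that coverage check on toy data, assuming each per-system entry carries a 2-D `atype` tensor as in the loop above:

```python
import numpy as np
import torch

# Toy stand-in for `lst`: one system whose sampled batches only
# contain element types 0 and 1.
lst = [{"atype": torch.tensor([[0, 0, 1], [1, 0, 0]])}]
all_elements = {0, 1, 2}  # types the datasets are declared to contain

collected_elements = set()
for sys_stat in lst:
    if "atype" in sys_stat:
        collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))

missing_elements = all_elements - collected_elements
print(missing_elements)  # {2}: frames containing type 2 must be appended
```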