(fix) Make bias statistics complete for all elements #4495
Conversation
for more information, see https://pre-commit.ci
📝 Walkthrough

The pull request introduces a new private method, `_build_element_to_frames`, in `deepmd/pt/utils/dataset.py`.

Changes
Possibly related PRs
Suggested labels
Suggested reviewers
Actionable comments posted: 2
🧹 Nitpick comments (1)
deepmd/pt/utils/dataset.py (1)
`42-45`: **Consider parameterizing the frame limit.**

Currently, the limit of 10 frames per element is hardcoded in the method. Exposing it as a configurable parameter (e.g., `max_frames_per_element`) can improve flexibility.

```python
def _build_element_to_frames(self, max_frames_per_element=10):
    """Mapping element types to frame indexes"""
    element_to_frames = {element: [] for element in range(self._ntypes)}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- `deepmd/pt/utils/dataset.py` (1 hunks)
- `deepmd/pt/utils/stat.py` (1 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataset.py (1)
`46-53`: **Validate `atype` presence.**

Before referencing `frame_data["atype"]`, ensure it exists or handle any potential `KeyError`.
```diff
 for frame_idx in range(len(self)):
     frame_data = self._data_system.get_item_torch(frame_idx)
-    elements = frame_data["atype"]
+    elements = frame_data.get("atype")
+    if elements is None:
+        # Possibly skip or handle error
+        continue
```
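If silently skipping frames could hide data problems, a stricter variant (hypothetical, not part of the suggested diff) would fail fast instead:

```python
elements = frame_data.get("atype")
if elements is None:
    # Fail fast with a clear message rather than skipping the frame
    raise KeyError(f"Frame {frame_idx} has no 'atype' entry")
```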
deepmd/pt/utils/stat.py (2)
`86-88`: **Add a safeguard check for empty dictionaries.**

If `datasets[0]` has no elements in its mapping, you risk `all_elements` being empty or incorrectly used. Consider handling the scenario where the set is empty.

```python
all_elements = set()
if datasets and hasattr(datasets[0], "element_to_frames"):
    if datasets[0].element_to_frames:
        all_elements.update(datasets[0].element_to_frames.keys())
    else:
        log.warning("No elements found in the first dataset's mapping.")
```
`101-102`: **Confirm `nbatches` usage.**

The code checks `if len(lst[i]['atype']) >= nbatches` to stop appending frames. Make sure `atype` is the correct dimension to compare with `nbatches` and that partial batches or partial data appends won't break statistics.
✅ Verification successful
Let me gather more information about the usage of `nbatches` and the structure of `lst` to verify this concern.

Let me gather more context about the data structures and their usage to better understand the batching logic.
From the gathered context, I can now generate a final response:
**The `nbatches` comparison with `atype` length is correctly implemented**

The code is working as intended because:

- `nbatches` is a parameter that controls the batch count for collecting stats, as shown in the function signature `make_stat_input(datasets, dataloaders, nbatches)`
- The `lst` structure is documented as "a list of dicts, each of which contains data from a system"
- The comparison `len(lst[i]['atype']) >= nbatches` is used to ensure we don't collect more frames than the specified batch count
- The code handles partial data correctly by:
  - Breaking only the inner loop when the batch limit is reached
  - Continuing to process other missing elements and systems
  - Maintaining data structure consistency by appending frame data for all keys
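A stripped-down illustration of the stopping condition discussed above (names follow the PR; the surrounding loop is simplified):

```python
# lst[i]["atype"] gains one entry per collected frame, so its length
# counts the frames gathered so far for system i.
for frame_idx in frame_indices:
    if len(lst[i]["atype"]) >= nbatches:
        break  # reached the requested batch count for this system
    frame_data = dataset[frame_idx]
    for key in frame_data:
        lst[i].setdefault(key, []).append(frame_data[key])
```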
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:

```bash
#!/bin/bash
# Find the context around the code in question
rg -B 10 -A 10 "if len\(lst\[i\]\['atype'\]\) >= nbatches:" deepmd/pt/utils/stat.py

# Search for nbatches definition and usage
rg "nbatches" deepmd/pt/utils/stat.py

# Look for atype usage in the same file
rg "atype" deepmd/pt/utils/stat.py
```

Length of output: 2028
Script:

```bash
#!/bin/bash
# Look for the dataset and dataloader structure
ast-grep --pattern 'class $_Dataset {
  $$$
}'

# Find where lst is initialized and how it's structured
rg -B 5 -A 5 "lst = \[" deepmd/pt/utils/stat.py

# Look for the complete make_stat_input function definition
ast-grep --pattern 'def make_stat_input($_) {
  $$$
}'
```

Length of output: 475
deepmd/pt/utils/stat.py (Outdated)

```python
collected_elements = set()
for sys_stat in lst:
    if 'atype' in sys_stat:
        collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
    for i, dataset in enumerate(datasets):
        if hasattr(dataset, 'element_to_frames'):
            frame_indices = dataset.element_to_frames.get(missing_element, [])
            for frame_idx in frame_indices:
                if len(lst[i]['atype']) >= nbatches:
                    break
                frame_data = dataset[frame_idx]
                for key in frame_data:
                    if key not in lst[i]:
                        lst[i][key] = []
                    lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
    if 'atype' in sys_stat:
        collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))
```
🛠️ Refactor suggestion

**Avoid nested for loops for efficiency.**

The nested iteration (missing_element → dataset → frame_indices) can become expensive for large datasets. If performance is critical, consider grouping missing elements or caching repeated data accesses (like `dataset.get_item_torch` calls), as sketched below.
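One way to realize the caching idea: build a reverse index of (dataset, frame) pairs in a single pass, so each missing element is looked up without re-scanning every dataset (names follow the PR; the index structure is illustrative):

```python
# One pass over the datasets instead of one pass per missing element.
frames_by_element = {}
for i, dataset in enumerate(datasets):
    if not hasattr(dataset, "element_to_frames"):
        continue
    for element, frame_indices in dataset.element_to_frames.items():
        if element in missing_elements:
            frames_by_element.setdefault(element, []).extend(
                (i, idx) for idx in frame_indices
            )
```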
deepmd/pt/utils/stat.py (Outdated)

```python
for sys_stat in lst:
    for key in sys_stat:
        if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
```
**Ensure tensor concatenation only occurs when the list is non-empty.**

If `sys_stat[key]` is an empty list under certain conditions, this will fail with an `IndexError` when accessing `sys_stat[key][0]`. You might want to check for emptiness before concatenation.
```python
for sys_stat in lst:
    for key in sys_stat:
        if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
           and isinstance(sys_stat[key][0], torch.Tensor):
            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
for sys_stat in lst:
    for key in sys_stat:
        if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
           and isinstance(sys_stat[key][0], torch.Tensor):
            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
```
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/utils/stat.py (1)
`85-89`: **Consider collecting elements from all datasets**

The current implementation only collects elements from the first dataset, which might miss elements present in other datasets. Consider updating the code to collect elements from all datasets.

```diff
-    if datasets and hasattr(datasets[0], "element_to_frames"):
-        all_elements.update(datasets[0].element_to_frames.keys())
+    if datasets:
+        for dataset in datasets:
+            if hasattr(dataset, "element_to_frames"):
+                all_elements.update(dataset.element_to_frames.keys())
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- `deepmd/pt/utils/dataset.py` (1 hunks)
- `deepmd/pt/utils/stat.py` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/utils/dataset.py
🔇 Additional comments (2)
deepmd/pt/utils/stat.py (2)
`118-124`: 🛠️ Refactor suggestion

**Add error handling for tensor concatenation**

The current implementation might fail if `sys_stat[key]` is an empty list or if any tensor in the list has incompatible dimensions.
```diff
 for sys_stat in lst:
     for key in sys_stat:
-        if isinstance(sys_stat[key], list) and isinstance(
-            sys_stat[key][0], torch.Tensor
-        ):
-            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+        if isinstance(sys_stat[key], list):
+            if not sys_stat[key]:
+                continue
+            if all(isinstance(t, torch.Tensor) for t in sys_stat[key]):
+                try:
+                    sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+                except RuntimeError as e:
+                    log.error(f"Failed to concatenate tensors for key {key}: {e}")
+                    raise
```
Likely invalid or redundant comment.
`90-117`: 🛠️ Refactor suggestion

**Improve performance and robustness of missing elements handling**

The current implementation has several areas for improvement:

- The nested loops (missing_element → dataset → frame_indices) could be performance intensive
- The `collected_elements` set is unnecessarily cleared and recollected
- Missing error handling for frame data retrieval
Consider these improvements:
```diff
     collected_elements = set()
     for sys_stat in lst:
         if "atype" in sys_stat:
             collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
-    missing_elements = all_elements - collected_elements
-
-    for missing_element in missing_elements:
-        for i, dataset in enumerate(datasets):
-            if hasattr(dataset, "element_to_frames"):
-                frame_indices = dataset.element_to_frames.get(
-                    missing_element, []
-                )
-                for frame_idx in frame_indices:
-                    if len(lst[i]["atype"]) >= nbatches:
-                        break
-                    frame_data = dataset[frame_idx]
-                    for key in frame_data:
-                        if key not in lst[i]:
-                            lst[i][key] = []
-                        lst[i][key].append(frame_data[key])
-
-    collected_elements = set()
-    for sys_stat in lst:
-        if "atype" in sys_stat:
-            collected_elements.update(
-                np.unique(sys_stat["atype"].cpu().numpy())
-            )
+
+    missing_elements = all_elements - collected_elements
+    if missing_elements:
+        # Pre-process frame indices for missing elements
+        element_frames = {}
+        for missing_element in missing_elements:
+            element_frames[missing_element] = []
+            for i, dataset in enumerate(datasets):
+                if hasattr(dataset, "element_to_frames"):
+                    frames = dataset.element_to_frames.get(missing_element, [])
+                    element_frames[missing_element].extend((i, idx) for idx in frames)
+
+        # Process frame data in batches
+        for missing_element, frames in element_frames.items():
+            for dataset_idx, frame_idx in frames:
+                if len(lst[dataset_idx]["atype"]) >= nbatches:
+                    break
+                try:
+                    frame_data = datasets[dataset_idx][frame_idx]
+                    for key in frame_data:
+                        if key not in lst[dataset_idx]:
+                            lst[dataset_idx][key] = []
+                        lst[dataset_idx][key].append(frame_data[key])
+                except Exception as e:
+                    log.warning(f"Failed to retrieve frame {frame_idx} from dataset {dataset_idx}: {e}")
```
Likely invalid or redundant comment.
Summary by CodeRabbit

- **New Features**
- **Bug Fixes**