
(fix) Make bias statistics complete for all elements #4495

Closed
wants to merge 3 commits

Conversation

SumGuo-88 (Collaborator) commented Dec 23, 2024

Summary by CodeRabbit

  • New Features

    • Introduced a method for mapping element types to frame indexes, enhancing data organization.
    • Enhanced statistics handling by adding logic to manage missing elements in datasets.
  • Bug Fixes

    • Improved the robustness of the statistics collection process, ensuring accurate data retrieval for missing elements.

coderabbitai bot (Contributor) commented Dec 23, 2024

📝 Walkthrough

The pull request introduces a new private method _build_element_to_frames in the DeepmdDataSetForLoader class within the deepmd/pt/utils/dataset.py file. This method creates a mapping of element types to their corresponding frame indexes, with a constraint of storing up to 10 frame indexes per element type. Simultaneously, the make_stat_input function in deepmd/pt/utils/stat.py is enhanced to handle missing elements in datasets more robustly, improving data collection and statistics processing.
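
For orientation, here is a minimal sketch of what such a method could look like. It is an illustration of the described behavior, not the PR's exact code: the attribute names _ntypes and _data_system.get_item_torch and the "atype" key are taken from the review excerpts below, the hardcoded limit of 10 comes from this walkthrough, and numpy is assumed to be imported as np in the module.

    def _build_element_to_frames(self):
        """Map each element type to the indexes of up to 10 frames containing it."""
        element_to_frames = {element: [] for element in range(self._ntypes)}
        for frame_idx in range(len(self)):
            frame_data = self._data_system.get_item_torch(frame_idx)
            elements = frame_data["atype"]
            for element in np.unique(elements):
                element = int(element)
                if len(element_to_frames[element]) < 10:
                    element_to_frames[element].append(frame_idx)
        return element_to_frames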

Changes

File Change Summary
deepmd/pt/utils/dataset.py Added private method _build_element_to_frames to create element type to frame index mapping.
deepmd/pt/utils/stat.py Enhanced make_stat_input function to handle missing elements and improve statistics collection.

Possibly related PRs

Suggested labels

Python

Suggested reviewers

  • njzjz
  • CaRoLZhangxy
  • wanghan-iapcm

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
deepmd/pt/utils/dataset.py (1)

42-45: Consider parameterizing the frame limit.
Currently, the limit of 10 frames per element is hardcoded in the method. Exposing it as a configurable parameter (e.g., “max_frames_per_element”) can improve flexibility.

 def _build_element_to_frames(self, max_frames_per_element=10):
     """Mapping element types to frame indexes"""
     element_to_frames = {element: [] for element in range(self._ntypes)} 
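
If the limit were exposed this way, a call site could tune it without touching the method body. The snippet below is hypothetical usage only: max_frames_per_element is the parameter proposed in this comment, the element_to_frames attribute name follows the stat.py code reviewed below, and the value 50 is purely illustrative.

    # in DeepmdDataSetForLoader.__init__, keep the default of 10 frames per element:
    self.element_to_frames = self._build_element_to_frames()
    # or, hypothetically, retain more candidate frames per element for rare elements:
    self.element_to_frames = self._build_element_to_frames(max_frames_per_element=50)
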
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cfe17a3 and 863ba9b.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataset.py (1 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
🔇 Additional comments (3)
deepmd/pt/utils/dataset.py (1)

46-53: Validate “atype” presence.
Before referencing frame_data["atype"], ensure it exists or handle any potential KeyError.

 for frame_idx in range(len(self)):
     frame_data = self._data_system.get_item_torch(frame_idx)
-    elements = frame_data["atype"]  
+    elements = frame_data.get("atype")
+    if elements is None:
+        # Possibly skip or handle error
+        continue
deepmd/pt/utils/stat.py (2)

86-88: Add a safeguard check for empty dictionaries.
If datasets[0] has no elements in its mapping, you risk all_elements being empty or incorrectly used. Consider handling the scenario where the set is empty.

 all_elements = set()
 if datasets and hasattr(datasets[0], 'element_to_frames'):
     if datasets[0].element_to_frames:
         all_elements.update(datasets[0].element_to_frames.keys())
     else:
         log.warning("No elements found in the first dataset’s mapping.")

101-102: Confirm nbatches usage.
The code checks if len(lst[i]['atype']) >= nbatches to stop appending frames. Make sure “atype” is the correct dimension to compare with nbatches and that partial batches or partial data appends won’t break statistics.

✅ Verification successful

Let me gather more information about the usage of nbatches and the structure of lst to verify this concern.


Let me gather more context about the data structures and their usage to better understand the batching logic.


From the gathered context, I can now generate a final response:

The nbatches comparison with atype length is correctly implemented

The code is working as intended because:

  • nbatches is a parameter that controls the batch count for collecting stats, as shown in the function signature make_stat_input(datasets, dataloaders, nbatches)
  • The lst structure is documented as "a list of dicts, each of which contains data from a system"
  • The comparison len(lst[i]['atype']) >= nbatches is used to ensure we don't collect more frames than the specified batch count
  • The code handles partial data correctly by:
    • Breaking only the inner loop when the batch limit is reached
    • Continuing to process other missing elements and systems
    • Maintaining the data structure consistency by appending frame data for all keys
🏁 Scripts executed

The following scripts were executed for the analysis:

Error: No shell scripts found


Script:

#!/bin/bash
# Find the context around the code in question
rg -B 10 -A 10 "if len\(lst\[i\]\['atype'\]\) >= nbatches:" deepmd/pt/utils/stat.py

# Search for nbatches definition and usage
rg "nbatches" deepmd/pt/utils/stat.py

# Look for atype usage in the same file
rg "atype" deepmd/pt/utils/stat.py

Length of output: 2028


Script:

#!/bin/bash
# Look for the dataset and dataloader structure
ast-grep --pattern 'class $_Dataset {
  $$$
}'

# Find where lst is initialized and how it's structured
rg -B 5 -A 5 "lst = \[" deepmd/pt/utils/stat.py

# Look for the complete make_stat_input function definition
ast-grep --pattern 'def make_stat_input($_) {
  $$$
}'

Length of output: 475

Comment on lines 90 to 113
    collected_elements = set()
    for sys_stat in lst:
        if 'atype' in sys_stat:
            collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))
            missing_elements = all_elements - collected_elements

            for missing_element in missing_elements:
                for i, dataset in enumerate(datasets):
                    if hasattr(dataset, 'element_to_frames'):
                        frame_indices = dataset.element_to_frames.get(missing_element, [])
                        for frame_idx in frame_indices:
                            if len(lst[i]['atype']) >= nbatches:
                                break
                            frame_data = dataset[frame_idx]
                            for key in frame_data:
                                if key not in lst[i]:
                                    lst[i][key] = []
                                lst[i][key].append(frame_data[key])

            collected_elements = set()
            for sys_stat in lst:
                if 'atype' in sys_stat:
                    collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy()))


🛠️ Refactor suggestion

Avoid nested for loops for efficiency.
The nested iteration (missing_element → dataset → frame_indices) can become expensive for large datasets. If performance is critical, consider grouping missing elements or caching repeated data accesses (like dataset.get_item_torch calls).
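
As a rough illustration of that suggestion (not the PR's code), the lookups could be grouped per missing element up front and the retrieved frames cached. The names datasets, lst, nbatches, and element_to_frames come from the hunk above; frames_to_fetch and frame_cache are hypothetical helper names.

    # Group work per missing element: element -> [(dataset_idx, frame_idx), ...]
    frames_to_fetch = {
        elem: [
            (i, idx)
            for i, dataset in enumerate(datasets)
            if hasattr(dataset, "element_to_frames")
            for idx in dataset.element_to_frames.get(elem, [])
        ]
        for elem in missing_elements
    }

    # Cache retrieved frames so repeated (dataset, frame) pairs are fetched only once.
    frame_cache = {}
    for elem, locations in frames_to_fetch.items():
        for i, frame_idx in locations:
            if len(lst[i]["atype"]) >= nbatches:
                break
            if (i, frame_idx) not in frame_cache:
                frame_cache[(i, frame_idx)] = datasets[i][frame_idx]
            frame_data = frame_cache[(i, frame_idx)]
            for key in frame_data:
                lst[i].setdefault(key, []).append(frame_data[key])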

Comment on lines 114 to 118
    for sys_stat in lst:
        for key in sys_stat:
            if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
                sys_stat[key] = torch.cat(sys_stat[key], dim=0)


⚠️ Potential issue

Ensure tensor concatenation only occurs when the list is non-empty.
If sys_stat[key] is an empty list under certain conditions, this will fail. You might want to check for emptiness before concatenation.

 for sys_stat in lst:
     for key in sys_stat:
         if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
                 and isinstance(sys_stat[key][0], torch.Tensor):
             sys_stat[key] = torch.cat(sys_stat[key], dim=0)

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/utils/stat.py (1)

85-89: Consider collecting elements from all datasets

The current implementation only collects elements from the first dataset, which might miss elements present in other datasets. Consider updating the code to collect elements from all datasets.

-    if datasets and hasattr(datasets[0], "element_to_frames"):
-        all_elements.update(datasets[0].element_to_frames.keys())
+    if datasets:
+        for dataset in datasets:
+            if hasattr(dataset, "element_to_frames"):
+                all_elements.update(dataset.element_to_frames.keys())
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 863ba9b and 52e8c34.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataset.py (1 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/utils/dataset.py
🔇 Additional comments (2)
deepmd/pt/utils/stat.py (2)

118-124: 🛠️ Refactor suggestion

Add error handling for tensor concatenation

The current implementation might fail if sys_stat[key] is an empty list or if any tensor in the list has incompatible dimensions.

     for sys_stat in lst:
         for key in sys_stat:
-            if isinstance(sys_stat[key], list) and isinstance(
-                sys_stat[key][0], torch.Tensor
-            ):
-                sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+            if isinstance(sys_stat[key], list):
+                if not sys_stat[key]:
+                    continue
+                if all(isinstance(t, torch.Tensor) for t in sys_stat[key]):
+                    try:
+                        sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+                    except RuntimeError as e:
+                        log.error(f"Failed to concatenate tensors for key {key}: {e}")
+                        raise

Likely invalid or redundant comment.


90-117: 🛠️ Refactor suggestion

Improve performance and robustness of missing elements handling

The current implementation has several areas for improvement:

  1. The nested loops (missing_element → dataset → frame_indices) could be performance intensive
  2. The collected_elements set is unnecessarily cleared and recollected
  3. Missing error handling for frame data retrieval

Consider these improvements:

     collected_elements = set()
     for sys_stat in lst:
         if "atype" in sys_stat:
             collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
-            missing_elements = all_elements - collected_elements
-
-            for missing_element in missing_elements:
-                for i, dataset in enumerate(datasets):
-                    if hasattr(dataset, "element_to_frames"):
-                        frame_indices = dataset.element_to_frames.get(
-                            missing_element, []
-                        )
-                        for frame_idx in frame_indices:
-                            if len(lst[i]["atype"]) >= nbatches:
-                                break
-                            frame_data = dataset[frame_idx]
-                            for key in frame_data:
-                                if key not in lst[i]:
-                                    lst[i][key] = []
-                                lst[i][key].append(frame_data[key])
-
-            collected_elements = set()
-            for sys_stat in lst:
-                if "atype" in sys_stat:
-                    collected_elements.update(
-                        np.unique(sys_stat["atype"].cpu().numpy())
-                    )
+
+    missing_elements = all_elements - collected_elements
+    if missing_elements:
+        # Pre-process frame indices for missing elements
+        element_frames = {}
+        for missing_element in missing_elements:
+            element_frames[missing_element] = []
+            for i, dataset in enumerate(datasets):
+                if hasattr(dataset, "element_to_frames"):
+                    frames = dataset.element_to_frames.get(missing_element, [])
+                    element_frames[missing_element].extend((i, idx) for idx in frames)
+
+        # Process frame data in batches
+        for missing_element, frames in element_frames.items():
+            for dataset_idx, frame_idx in frames:
+                if len(lst[dataset_idx]["atype"]) >= nbatches:
+                    break
+                try:
+                    frame_data = datasets[dataset_idx][frame_idx]
+                    for key in frame_data:
+                        if key not in lst[dataset_idx]:
+                            lst[dataset_idx][key] = []
+                        lst[dataset_idx][key].append(frame_data[key])
+                except Exception as e:
+                    log.warning(f"Failed to retrieve frame {frame_idx} from dataset {dataset_idx}: {e}")

Likely invalid or redundant comment.

SumGuo-88 closed this on Dec 23, 2024
SumGuo-88 deleted the devel branch on December 23, 2024, 12:24