Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pt): fix type annotations for dummy compress op; improve docs #4342

Merged
merged 2 commits into from
Nov 12, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 12, 2024

Summary by CodeRabbit

  • New Features

    • Added type annotations to several functions for improved clarity and type safety.
    • Updated documentation to include installation requirements for the PyTorch backend when compressing models.
  • Documentation

    • New section on installation requirements added to the compress.md document.
  • Bug Fixes

    • No bug fixes were introduced in this release.
  • Refactor

    • Minor refactoring for better code readability without changing existing functionalities.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Copy link
Contributor

coderabbitai bot commented Nov 12, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

This pull request primarily introduces type annotations to several functions within the deepmd/pt/model/descriptor module, enhancing type safety and clarity regarding the expected input and output types. The functions tabulate_fusion_se_a, tabulate_fusion_se_atten, tabulate_fusion_se_r, and tabulate_fusion_se_t have been updated with specific types for their parameters and return values. Additionally, the DescrptBlockSeAtten class's constructor has been modified to remove an unnecessary parameter. The overall functionality and logic of the related classes remain unchanged.

Changes

File Change Summary
deepmd/pt/model/descriptor/se_a.py - Added type annotations to tabulate_fusion_se_a: parameters as torch.Tensor and int, return type as list[torch.Tensor].
deepmd/pt/model/descriptor/se_atten.py - Added type annotations to tabulate_fusion_se_atten: parameters as torch.Tensor, int, and bool, return type as list[torch.Tensor].
- Removed type parameter from DescrptBlockSeAtten constructor.
deepmd/pt/model/descriptor/se_r.py - Added type annotations to tabulate_fusion_se_r: parameters as torch.Tensor and int, return type as list[torch.Tensor].
deepmd/pt/model/descriptor/se_t.py - Added type annotations to tabulate_fusion_se_t: parameters as torch.Tensor and int, return type as list[torch.Tensor].
doc/freeze/compress.md - Added a new section "Requirements of installation" for PyTorch backend, detailing the installation of a customized OP library during model freezing.

Possibly related PRs

  • Add 4 pt descriptor compression #4227: This PR adds compression functionality to the DescrptSeA class, which is directly related to the changes made in the main PR that introduces type annotations to the tabulate_fusion_se_a function in the same class.
  • docs: add documentation for installation requirements of DPA-2 #4178: This PR includes updates to the DescrptBlockSeAtten class, which also involves changes to the tabulate_fusion_se_atten function, similar to the main PR's focus on type annotations in tabulate_fusion_se_a.
  • fix(tf): fix compress suffix in DescrptDPA1Compat #4243: This PR modifies the enable_compression method in the DescrptSeAtten class, which is relevant as it relates to the compression functionality introduced in the main PR's context.
  • feat(jax/array-api): se_e3 #4286: This PR introduces a new descriptor DescrptSeT, which is relevant as it extends the functionality of descriptors in a similar manner to the changes made in the main PR.
  • feat(jax/array-api): se_t_tebd #4288: This PR adds new classes that extend existing descriptor functionality, which aligns with the main PR's focus on enhancing descriptor capabilities through type annotations.

Suggested reviewers

  • wanghan-iapcm
  • iProzd
  • njzjz

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (5)
doc/freeze/compress.md (1)

132-132: Consider adding a complete example command.

To make the instructions more actionable, consider adding a complete example command showing how to set the environment variable:

 The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md).
+
+```bash
+export DP_ENABLE_PYTORCH=1
+pip install .
+```
deepmd/pt/model/descriptor/se_a.py (1)

Line range hint 82-85: Consider enhancing the error message and documentation.

The error message could be more specific about which operation is missing and provide a direct link to the model compression documentation. Additionally, the LAMMPS compatibility note could be more prominent.

Consider updating the error message like this:

- "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
+ "The tabulate_fusion_se_a operation is not available. This operation requires the custom PyTorch operations library, which is not built. "
+ "Please refer to https://docs.deepmodeling.org/projects/deepmd/en/master/model-compression.html for details on model compression. "
deepmd/pt/model/descriptor/se_t.py (1)

76-80: Improve type annotations and documentation

While adding type annotations is good for type safety, consider these improvements:

  1. Use more descriptive parameter names instead of generic argumentN
  2. Add dimension information to tensor type hints using # shape: ... comments
  3. Add a docstring explaining the parameters and return value
-    argument0: torch.Tensor,
-    argument1: torch.Tensor,
-    argument2: torch.Tensor,
-    argument3: torch.Tensor,
-    argument4: int,
+    table_data: torch.Tensor,  # shape: [table_size, out_dim]
+    table_info: torch.Tensor,  # shape: [6]
+    env_deriv: torch.Tensor,  # shape: [batch_size, 1]
+    env: torch.Tensor,  # shape: [batch_size, n_types_i, n_types_j]
+    output_dim: int,

Also consider adding a docstring:

"""Fuse and tabulate the environment matrix for the SE(3) descriptor.

Parameters
----------
table_data : torch.Tensor
    The tabulated data for interpolation
table_info : torch.Tensor
    Configuration info containing [lower, upper, upper_ext, stride1, stride2, check_freq]
env_deriv : torch.Tensor
    Environment derivatives
env : torch.Tensor
    Environment matrix
output_dim : int
    Output dimension of the network

Returns
-------
list[torch.Tensor]
    List containing the fused descriptor values
"""
deepmd/pt/model/descriptor/se_atten.py (2)

55-61: Improve parameter names and add docstring.

While the type annotations are correct, the function would benefit from:

  1. More descriptive parameter names that indicate their purpose
  2. A docstring explaining the parameters and return value

Consider renaming parameters and adding documentation:

 def tabulate_fusion_se_atten(
-    argument0: torch.Tensor,
-    argument1: torch.Tensor,
-    argument2: torch.Tensor,
-    argument3: torch.Tensor,
-    argument4: torch.Tensor,
-    argument5: int,
-    argument6: bool,
+    compress_data: torch.Tensor,
+    compress_info: torch.Tensor,
+    input_data: torch.Tensor,
+    radial_data: torch.Tensor,
+    gate_data: torch.Tensor,
+    filter_neuron: int,
+    is_sorted: bool,
 ) -> list[torch.Tensor]:
+    """Fallback implementation for the custom tabulate_fusion_se_atten operation.
+    
+    Parameters
+    ----------
+    compress_data : torch.Tensor
+        Compressed network data
+    compress_info : torch.Tensor
+        Compression configuration information
+    input_data : torch.Tensor
+        Input tensor for the network
+    radial_data : torch.Tensor
+        Radial information tensor
+    gate_data : torch.Tensor
+        Gating tensor
+    filter_neuron : int
+        Number of filter neurons
+    is_sorted : bool
+        Whether the input data is sorted
+        
+    Returns
+    -------
+    list[torch.Tensor]
+        List of output tensors from the fusion operation
+    """

Line range hint 87-156: Remove deprecated type parameter.

The type parameter is immediately deleted after being passed to the constructor, indicating it's deprecated. For better code clarity:

  1. Remove the parameter from the constructor signature
  2. Remove its documentation from the docstring

Apply this change:

     def __init__(
         self,
         rcut: float,
         rcut_smth: float,
         sel: Union[list[int], int],
         ntypes: int,
         neuron: list = [25, 50, 100],
         axis_neuron: int = 16,
         tebd_dim: int = 8,
         tebd_input_mode: str = "concat",
         set_davg_zero: bool = True,
         attn: int = 128,
         attn_layer: int = 2,
         attn_dotr: bool = True,
         attn_mask: bool = False,
         activation_function="tanh",
         precision: str = "float64",
         resnet_dt: bool = False,
         scaling_factor=1.0,
         normalize=True,
         temperature=None,
         smooth: bool = True,
         type_one_side: bool = False,
         exclude_types: list[tuple[int, int]] = [],
         env_protection: float = 0.0,
         trainable_ln: bool = True,
         ln_eps: Optional[float] = 1e-5,
-        seed: Optional[Union[int, list[int]]] = None,
-        type: Optional[str] = None,
+        seed: Optional[Union[int, list[int]]] = None
     ):

Also remove the parameter from the docstring.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between c4a973a and 55de77c.

📒 Files selected for processing (5)
  • deepmd/pt/model/descriptor/se_a.py (1 hunks)
  • deepmd/pt/model/descriptor/se_atten.py (1 hunks)
  • deepmd/pt/model/descriptor/se_r.py (1 hunks)
  • deepmd/pt/model/descriptor/se_t.py (1 hunks)
  • doc/freeze/compress.md (1 hunks)
🔇 Additional comments (1)
deepmd/pt/model/descriptor/se_a.py (1)

76-80: LGTM! Type annotations are accurate and helpful.

The type annotations correctly specify torch.Tensor for tensor arguments and int for the scalar argument, improving code clarity and type safety.

doc/freeze/compress.md Show resolved Hide resolved
deepmd/pt/model/descriptor/se_r.py Show resolved Hide resolved
Copy link

codecov bot commented Nov 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.35%. Comparing base (4793125) to head (48670a6).
Report is 1 commits behind head on devel.

Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4342      +/-   ##
==========================================
- Coverage   84.35%   84.35%   -0.01%     
==========================================
  Files         593      593              
  Lines       55899    55899              
  Branches     3388     3388              
==========================================
- Hits        47154    47153       -1     
  Misses       7635     7635              
- Partials     1110     1111       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 12, 2024
Merged via the queue into deepmodeling:devel with commit 4a9ed88 Nov 12, 2024
60 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants