Skip to content

Conversation

@HydrogenSulfate
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate commented Jun 30, 2025

support running input_torch_dynamic.json with paddle backend(including CINN)

TODO list:

Summary by CodeRabbit

Summary by CodeRabbit

  • Bug Fixes

    • Resolved issues with tensor shape and indexing consistency, preventing assertion errors during model execution.
    • Improved handling of default tensor initialization to avoid JIT assertion issues.
  • Refactor

    • Standardized tensor dimension handling and broadcasting for improved clarity and maintainability.
    • Enhanced code readability with clearer indexing conventions and formatting.
    • Updated aggregation logic for safer and more efficient tensor operations.
  • New Features

    • Added an option to control graph index mapping behavior for greater flexibility in advanced use cases.
  • Tests

    • Introduced comprehensive tests validating descriptor model consistency with dynamic selection enabled.

@HydrogenSulfate HydrogenSulfate changed the base branch from master to devel June 30, 2025 07:13
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jun 30, 2025

📝 Walkthrough
## Walkthrough

This change standardizes tensor dimension handling and indexing conventions across several modules, primarily affecting how edge and angle indices are structured and used. It introduces explicit reshaping, modifies the shape and construction of index tensors, and clarifies broadcasting logic, without altering core algorithms or public interfaces.

## Changes

| File(s)                                    | Change Summary                                                                                                     |
|--------------------------------------------|-------------------------------------------------------------------------------------------------------------------|
| deepmd/pd/model/descriptor/repflow_layer.py | Standardized tensor dimension handling and reshaping; changed edge/angle index slicing from column to row-based; clarified broadcasting; minor formatting improvements. |
| deepmd/pd/model/descriptor/repflows.py      | Changed initialization of edge/angle index tensors to new shapes; updated owner argument to match new indexing; conditional node embedding extraction based on local mapping. |
| deepmd/pd/model/network/utils.py             | Modified `aggregate` to compute `bin_count` only when needed; changed output tensor initialization; added assertion; updated `get_graph_index` to stack indices along axis 0 and added `use_loc_mapping` parameter. |
| source/tests/pd/model/test_dynamic_sel.py   | Added test class validating descriptor consistency between dynamic selection enabled and disabled modes with multiple parameter combinations and precision settings. |

## Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant RepFlowLayer
    participant Utils

    User->>RepFlowLayer: forward(nlist, nlist_mask, ...)
    RepFlowLayer->>Utils: get_graph_index(nlist, nlist_mask, ..., use_loc_mapping)
    Utils-->>RepFlowLayer: edge_index [2, n_edge], angle_index [3, n_angle]
    RepFlowLayer->>RepFlowLayer: _cal_hg_dynamic(..., owner=edge_index[0], ...)
    RepFlowLayer->>Utils: aggregate(data, owners, average, num_owner)
    Utils-->>RepFlowLayer: aggregated tensor
    RepFlowLayer-->>User: output

Possibly related PRs

Suggested reviewers

  • caic99

</details>

<!-- walkthrough_end -->
<!-- internal state start -->


<!-- DwQgtGAEAqAWCWBnSTIEMB26CuAXA9mAOYCmGJATmriQCaQDG+Ats2bgFyQAOFk+AIwBWJBrngA3EsgEBPRvlqU0AgfFwA6NPEgQAfACgjoCEYDEZyAAUASpETZWaCrKNwSPbABsvkCiQBHbGlcSHFcLzpIACIAbQB1AEkrAF0eWi4Hbm58ClDabjQAZkhaWQw0ZngGe1g0bg8AM1z0yAE0BgBrMlpolGQHASrcGno5MNgPbERKSAAJWVoKfFIMAGVvRuoPTHpUeCUMcUb4KLRkDEcBWYAWAA4AJjuNGEnrO21mZAJ7bGzc0IUbAYDDwDBECZNeCRSAAAzB3DwAH0CBQGLAkWUKlUGBohIh8BhYZAAO7qWCQ6xoWi0GHtLo9AA0KAwDC82FoYIhAGFEgA5PmQAAUaAUzG40NmYMaVEQuCBYmw/gAlC93KVpAwKPBuOJCfYGgx4CdpBNqBNUGQ6qySGwjpBmgxpqb9WQVF4uaVypVqrV6h4sjk8g6Wrg3lZqbSPPTuhh6GQJPBlhg7ZpXh5bCy2RzTaL0aJOh65fxGn4SF5tvQIzTItWo54fGWgiFvnVQnUpG0SGQxdxIqMwvgHR1oepthbkI0SNQlSQ1W9Gt5fEpcNovMhdqUkE7EIh4PqGJguyHgfRzWGPLx8ImlPQAKo2AAyXFgI24iA4AHpP0RydgBBoTDMJ+SgkNwzCKOWXIgd24G0GAnTqJ+iI+J+9xPBoRj6MY4BQD0JY4AQxBkMoA5AamXC8PwwiiOIUgyPITBKFQqjqFoOjYSYUBwPsG5YGgeCEKspFROR7BcFQJK/E4LhtIxkEsWomjaLoYCGDhpgGKBcHIbQn4QUoXgwYgWo6qin7+NwjRePgJJIhWsiUBo3CyBwBjRB5BgWJAACCiTEeQVADg4MnyPgpbopgpCIG4byReCpr+Cc5BhGQBJ8Jydp7vq1q0p6m5gkoAAenpMBgUhHPuGDIGCkDMN44h9h4bBhoofH0BedXaFgsLNBQJLOLQxKLqyepYOFlKwjYYEAGI2SSj5oI5FDEmy5yIC8ADSJDyAZxrVNQVU1ay7JKG5BhQIkPVIoeXhIrARCYt6OKwsynX1V4jUeoeY0Eb1Fa4PdDzEpu/3UEidCkBDAhDal1W5DVyDxaQ9AyiwkDTJ65Yeu++xxiQJXgqS5JwrEGjk8yfKEiQKQg3GpMcJT1PMhwtODpABN9tU6hwhowKIM2JAAF4kEKYAAIzKnT9Cwnz1WCyLYsPFLjBoD4MgkH1Hj+IgdQSuCzJrdqjSyJ6AjLNSh5yp6MzMJg4gMBtWGQFdcL4Lq8DMEiUWREify0NsT3YtU0tux7XuQyQfvcAHNBBz6DCvZSOt656E1gjQFBsJy440PDFAa1rcKFDWc6FQTSIzJEYjEqg/uVuzhS7ugkBFqEzhEI47Ds7CKf+sSQpzkQGjMr30h66LsQYKWABUkAYDZDDMuTGgpFLyrMpunWwiXUYaOXRWV+WtGrWr67oP4BoeqEGdDh9X0eB65CTi0ntXlI9D+NSKijrgsiYRdF2WBt59QGhQWGLVYCKDem8Wg+AnSpjOPTDouBsBqz9A0ZAE1YSRyRAfUOsIfZR3wXDdKG5L7IyiGjZgcIhQYAhrQUgkAiqQGVgQuh3twQwhYUUFWPxYRCgeMw+eDDSAq1BkKEoLD6FEKlsyHWho6Lln/pAbkuQFGEk5Fw2QMCpgYEKF0NOpYLwzBZJyBgppEBklwPmScywaGIG+qVfA7JmDVWFDgxhxD8ZFViIzSAAAGWmyp2aOO5kTZYJJkBCk8VDA+sQgmwlVM7dUEhnDwHdB4WE9DI61wuEOGyCU+BAURAOc4lIQp/QXkgQGdtECdGJAIPA/R7AkFCPwqm5BiQkkmD1Z08ccRHy8Hk4c64SCb3pvw2ADwNC639AktmPSezymCAA7ibxgQGMQkTbBuCSAw2JHMhoLTLIVgsWMeQmyOjbIhNktusyJ7Eh+KQUIAiMACGZAveBnzyDwClmsmAaUWj3x1N9Q6hJjoSBcYmImVj1DogdMCMQR0PFoEriSJJF8PCUNRvYjmPgdR7mOsVT01iKSwjJhTSAnTxnUupmzH4nMnGvLlgLYIwtRYS1VPzBWnLJan3ViGYpFZtR/3QPTMqRK84MH/s7AAsmCFofU7YjE9G/ZYUhUzHWzEodANJnFOHalfdQ4giaFIhE/U0fUuwjFmF/AOahr6yoMFTRg1porsyYJfGyv4l4KCOMsXw1lbLMhaJQZYfBcpPwhM4DwGdli0GwOc+czVFD7R+iix00x+BYElTU7ued0pbiylVSAUauTMgPs48q7AjoTPoExbFIr1DOvMJYHyn1SIop+J1JQRtwXuImpzAEUQWiIgEN9DmlVxDSGdq6yh3whydT3EQCoqCdYEWHXkUdfBx2TtrTO5AJJKDNWpHOXyjY9onAzRCrFLJM4VF8D8SBigUDikiKmAdR7yT4GaWrTOnoX30GxNIZkhQqAtUoIgUNfB/DruAbITBmEPLRCwppbSzA9Ix30pBIySgTLal1LkCyYFg1ROcq5dynlvJ+QCiJegIU7ayQmgu2K2L3W5loEIaY7S3hgnUBkj0QsB1/VwSQ0GRC8E+KeUCguLJJqgMGsSIDpJelwn6ViBOQyRkzTVjMF4Vh/CJl/Y4nRbR8BhlJCelkAm1bwBFmeDc9guQwkLS0CaRysmxHFsyIotMXhUxJKPMT0mWn8fEHZhz6AnMi2WKQ9zpZPOkweMycWbMJNcO8cVEGMXKBLtkwRJLFKiipdpsyAQFmKRkoqZUDwtA/7HPKXcmpstXioDQFCg4TmABSiRoDRZmHkUtSAHDRnkO/A4NswLOFzrJ79lnPPIBVeiLGRVFFRCrUTTKaUjoANdsu2rdU2lQNoLojmRV5QoNLdgjAkFoa0Ahhd4aeLsl3f2bDTGRMd6RkiBoVc3Rvbmu9iVRAIzbtSTKpyMa6D9Swiro0PmMx7LwKRHbbIXJWuJFLLCfpi9Uf1H1kQEZ0AgS0te0oe7j3cAjPWvAVdUROT+DEF4eQ1C4S3Yp+92EABufgF5+pIFpTzZ0yBvulz+2gAHatCSPTQCD7pJNYRo8J61niSMOOtmxYSeULiHTzSFfPSCHNmDXH1UTAmhbS1KAaHGNOwC3iL3Qcrz01k0DDyMDNMEZ8zO1U6rdJ9Q4pqzXmotZaGgbpq3uo9TTL0zuwlsoFEGFBO6phaTih0L2QvFT8cyRJwoq60U9J1E4BdQhMFcRgEJ/Cs8EwWQPAvyKibF6TMWSJqo2tLbTSaTXZiCZTbtpVR2BvOoJ9mOB2rmdmQem6Cz9AHpV1qtCNVzq5ApJJd2V4qT2WAGusq7MH1voWhlR10G/XC673xsUEmugAC20Xszl+9mvbRAisf0OtbI76Bjv/Pu6dpwYourUzIaeQXToawSYa6Q4aGSfjkC4Aki5CdCfh4DQgbQuRuQoZeTtr+TCRBRRCMbOBhQRQa5GD7ZvCEJEBED+BEDbDDRIpjRnYlJ4AibYJqD0JMDAjU4Yx7hfa7y/asHsFHBg62T+pQ5VRqxcA8yoANAUDKqjoLzyDGjs6OBIij4rQtKwg0rEivylgdbKC/hEyoD+DNhyh0C8574C4zCjysE3S/qCEtIzDtKB6aEpr8B4ClLxZ8DhaCb2Yibg5z4DSyDIAi5KFeyqE5aUgl7FjbbVTDbVQ0DUh/TWECGaCeZ14uEIjNKxyiioBWyOGUi/q4DuFuYUAADkyA9WiG0YmsuQHg28B83sNISIWh0hA6LwPk/Eu4lAv0A0G4NIUQjK1Us4cISRthnBqAt2rymhXYRcnIN4pKamuhVA+h7uBgpBWSLySIlB9QGIJCI0yKhIm888JAUkFWLi04WA4+kGfAOOSOeOzu4IA8SgWwDUcIJOwQmKPReqt484UhM2Vx/qJ+Pet0ToAM122OMotWlcCAjQ1OEh2OuOKO9xRO6hbxJAScvUEGUcusxoYxSMLApSZwouvBc4M2CUQoASnyjQISc8XyicZh/OZIlhKAoQwR2SZ8qubwcGSoWAaS7IposSWWBMBCkmJCx6FCHGuK6MZUP0boY0G45qlIFYxYcuqAPwco1y+UCpooq+6AIOgShs7qRekwSYGCVqL2U8oitKDw6W9MFKMimWvmDKgesQKWIikcNpMssQJWIiRCbMCiheUgLOt+mB9+Xat6PasCL+M2cpm6H+26X+u6P+voB6/+zsM0dBpaK6a6Qx9cowXAsITx6AFBVB2wQoWRlEP2c40AsmoaJIgUH4PAlZGg1Z+cm8UgSxJAXApxvgAAvDAKTp8soaoVwAAPIeyEhqyxAZxpB9k0ohJqSNni4tnpS1w9QYZYZ6QGTlgwFtLwEUCIHIHrgUYDwZyUCPqtwrC+g4qfJDhZkziXyXFtJQZfE7plgQQfzKiALpmjSZl07ZmXy5l0D5mFkbFbHcA7E+J0JtwVlLk1nzxtz471IwV7zLm5CbxIjVJyiIWdDIW/aoUUCfJnwSFHDzl6AwB/CRCxDEnNlwXUX4VsyABJhHCCBW0psVQOBVvgTFBTUrhVWXBZhbUucDhYuShXBWigJdhbxTRa2fPERfesyAiQwPjujuCF2fgLrn2aiaReRU1FRU2fhWBvpbJmzLVAWeARuVAdubAXuQeeIEeS5APJGFEI+ZnOprcYiQThjsTJZk8QJJ9K8aTkklRqhqAWAEYASEqBYp+DQHKIgJAVuUZDFYDDHtUEMhRugdRlgXRrgQxo4ExoQW6lFLOgYD5EcVJElXVFfjCLCElQMqlVXMeaSOUk5R1EOLyQcLnGQQACKagUC6hdVWA+RFDEj4amREZ8AJXeUUgpU1AN6/RugTo37pipTFhrTNywjVlyg9UEb9WDVFBdXPTVBrDlirmTCip2Low46gi2pyg0VyjcjnBonipniMA8bowVXtCmIbUhAPUzBrAuYkAzSYnxDkh8htyY6hAOFBHcAtwVUJhJiEip7Va/JECwAVZ8BtxGpMC5BaLbAbTLV2y+4hBHatQyy1V5omGsiyDEg4F41+C7DoxsBHibjqlxiDQaiJgibFEvwZS9VmRKr64dxdxHBGrqAiRYLtl1QNSgqXh/FPnFIsCsGP5gg6qeiAXipEAwiu5EDQYYwxzjhyiyCRC606wHBoK+BeGRYiZAa60lI6zZRYC4G60gpgCRxHa7hu5a4cG628CiBICloNagbPX8Dhw+G/Ta1OwGAzRhodAUguWzBARK30FhDwFwjbVai7VDWrms0WLkLYpfx5k5oeDVYzWtLVy/SciICZJnj0zUxTVejByzXHz7FYALWRC0AvAABClWLIOduYl8ltQmUQy+bwVdbA9NcY6MMwSCn8pGgtyewtuA+N6oFVZU7ZS988jgLkF8VAgR7MdYrmc2OalIZUJwnc/g9ASgiYFiwdQI7iz+BG/NfAim4CjZzc+oBwtat0LI7hWRwo2N4CXuMVm8BANCgdutKNaNLQmN9aXUKlOt7e6oBRpSWCxibwcBQ4o1hGqIvdq4NoedvYsa9An2ty+if2IQXIWgnReQgOXgbI+AMwCuC2ho+0UQ/goJnYm4KgBI7INAg4kQVA+D7MaUQxlwbA2o39FN0q8g1wcB3Y9udWh1Td5dpam4t2GAYApdc1w276to7AX6ACK9xNTA7ZyAvtRoDtYAiALD16/DygNoeq3GcoWqwdwRooJwRUbDDNDi3YqMLQlkCaSa8AjqLaLwiQ7c64Q4Ktp0lilA3hwmv0Jh74ckZYtOC+TebwTN/E9MrNAcr9V9GSv03N0xNRFSogmiGoj941HMvJaCY0LhcNgxG6YY5obdnoWjzdv0vukZVTODcCpoExc+rlfu+JeADOfN1TtTpo1wsgFTHQFiuomSm94jB0vg4awKHcYIkdd+HaD+MZEZdWUZQUKK7+QYA43+E6yZf+xVUA3ISpvRt4+Zm1uA6dfVLze1B1jdx1wy8mcOv6aIJA0VLY8VuGQLWFM1aVDlBgkAugkAcqx2r6LV+ZDhd43AQo8O2lmh0LsL8LpNL5GQcI5NEK+aVN6L5YVJugZFWLGBoB6kXEU6X+Ohgk2VDcYkRwEkaAa+eVBBKTjaikbEKknEuEYowweCtAiASIRmpwx6D26pwY2EBg9LNIAA7AEiQAAJzqsBIBIPAPAACsdw9w4sASjQyrU4aAerNwar9wFiOr4s4sNwWwAAbI0AEqpIYPS3qw8I0EUAIEUMq8q7QHcMG2gE6yQAEsqwwLQAIOLI0Ka6qzcGgGgHcEUGgDq+q0UAEgwAIE6+6xpFAOROoGKxK1K8cXQBDPTPoEAA=== -->

<!-- internal state end -->
<!-- finishing_touch_checkbox_start -->

<details open="true">
<summary>✨ Finishing Touches</summary>

- [ ] <!-- {"checkboxId": "7962f53c-55bc-4827-bfbf-6a18da830691"} --> 📝 Generate Docstrings

</details>

<!-- finishing_touch_checkbox_end -->
<!-- tips_start -->

---

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

<details>
<summary>❤️ Share</summary>

- [X](https://twitter.com/intent/tweet?text=I%20just%20used%20%40coderabbitai%20for%20my%20code%20review%2C%20and%20it%27s%20fantastic%21%20It%27s%20free%20for%20OSS%20and%20offers%20a%20free%20trial%20for%20the%20proprietary%20code.%20Check%20it%20out%3A&url=https%3A//coderabbit.ai)
- [Mastodon](https://mastodon.social/share?text=I%20just%20used%20%40coderabbitai%20for%20my%20code%20review%2C%20and%20it%27s%20fantastic%21%20It%27s%20free%20for%20OSS%20and%20offers%20a%20free%20trial%20for%20the%20proprietary%20code.%20Check%20it%20out%3A%20https%3A%2F%2Fcoderabbit.ai)
- [Reddit](https://www.reddit.com/submit?title=Great%20tool%20for%20code%20review%20-%20CodeRabbit&text=I%20just%20used%20CodeRabbit%20for%20my%20code%20review%2C%20and%20it%27s%20fantastic%21%20It%27s%20free%20for%20OSS%20and%20offers%20a%20free%20trial%20for%20proprietary%20code.%20Check%20it%20out%3A%20https%3A//coderabbit.ai)
- [LinkedIn](https://www.linkedin.com/sharing/share-offsite/?url=https%3A%2F%2Fcoderabbit.ai&mini=true&title=Great%20tool%20for%20code%20review%20-%20CodeRabbit&summary=I%20just%20used%20CodeRabbit%20for%20my%20code%20review%2C%20and%20it%27s%20fantastic%21%20It%27s%20free%20for%20OSS%20and%20offers%20a%20free%20trial%20for%20proprietary%20code)

</details>

<details>
<summary>🪧 Tips</summary>

### Chat

There are 3 ways to chat with [CodeRabbit](https://coderabbit.ai?utm_source=oss&utm_medium=github&utm_campaign=deepmodeling/deepmd-kit&utm_content=4828):

- Review comments: Directly reply to a review comment made by CodeRabbit. Example:
  - `I pushed a fix in commit <commit_id>, please review it.`
  - `Explain this complex logic.`
  - `Open a follow-up GitHub issue for this discussion.`
- Files and specific lines of code (under the "Files changed" tab): Tag `@coderabbitai` in a new review comment at the desired location with your query. Examples:
  - `@coderabbitai explain this code block.`
  -	`@coderabbitai modularize this function.`
- PR comments: Tag `@coderabbitai` in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
  - `@coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.`
  - `@coderabbitai read src/utils.ts and explain its main purpose.`
  - `@coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.`
  - `@coderabbitai help me debug CodeRabbit configuration file.`

### Support

Need help? Create a ticket on our [support page](https://www.coderabbit.ai/contact-us/support) for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

### CodeRabbit Commands (Invoked using PR comments)

- `@coderabbitai pause` to pause the reviews on a PR.
- `@coderabbitai resume` to resume the paused reviews.
- `@coderabbitai review` to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
- `@coderabbitai full review` to do a full review from scratch and review all the files again.
- `@coderabbitai summary` to regenerate the summary of the PR.
- `@coderabbitai generate docstrings` to [generate docstrings](https://docs.coderabbit.ai/finishing-touches/docstrings) for this PR.
- `@coderabbitai generate sequence diagram` to generate a sequence diagram of the changes in this PR.
- `@coderabbitai resolve` resolve all the CodeRabbit review comments.
- `@coderabbitai configuration` to show the current CodeRabbit configuration for the repository.
- `@coderabbitai help` to get help.

### Other keywords and placeholders

- Add `@coderabbitai ignore` anywhere in the PR description to prevent this PR from being reviewed.
- Add `@coderabbitai summary` to generate the high-level summary at a specific location in the PR description.
- Add `@coderabbitai` anywhere in the PR title to generate the title automatically.

### CodeRabbit Configuration File (`.coderabbit.yaml`)

- You can programmatically configure CodeRabbit by adding a `.coderabbit.yaml` file to the root of your repository.
- Please see the [configuration documentation](https://docs.coderabbit.ai/guides/configure-coderabbit) for more information.
- If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: `# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json`

### Documentation and Community

- Visit our [Documentation](https://docs.coderabbit.ai) for detailed information on how to use CodeRabbit.
- Join our [Discord Community](http://discord.gg/coderabbit) to get help, request features, and share feedback.
- Follow us on [X/Twitter](https://twitter.com/coderabbitai) for updates and announcements.

</details>

<!-- tips_end -->

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pd/model/descriptor/repflow_layer.py (1)

752-752: Remove unused variable assignment.

The variable nall is assigned but never used in this scope. Consider removing this assignment to clean up the code.

-        nall = node_ebd_ext.shape[1]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f8f01cb and dd70e99.

📒 Files selected for processing (3)
  • deepmd/pd/model/descriptor/repflow_layer.py (9 hunks)
  • deepmd/pd/model/descriptor/repflows.py (2 hunks)
  • deepmd/pd/model/network/utils.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
deepmd/pd/model/descriptor/repflows.py (2)
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4160
File: deepmd/dpmodel/utils/env_mat.py:52-64
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Negative indices in `nlist` are properly handled by masking later in the computation, so they do not cause issues in indexing operations.
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4160
File: deepmd/dpmodel/utils/env_mat.py:52-64
Timestamp: 2024-09-24T01:59:37.973Z
Learning: Negative indices in `nlist` are properly handled by masking later in the computation, so they do not cause issues in indexing operations.
deepmd/pd/model/network/utils.py (2)
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4160
File: deepmd/dpmodel/utils/env_mat.py:52-64
Timestamp: 2024-09-24T01:59:37.973Z
Learning: Negative indices in `nlist` are properly handled by masking later in the computation, so they do not cause issues in indexing operations.
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4160
File: deepmd/dpmodel/utils/env_mat.py:52-64
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Negative indices in `nlist` are properly handled by masking later in the computation, so they do not cause issues in indexing operations.
🧬 Code Graph Analysis (1)
deepmd/pd/model/network/utils.py (1)
source/tests/consistent/descriptor/test_dpa3.py (1)
  • data (78-134)
🪛 Ruff (0.11.9)
deepmd/pd/model/descriptor/repflow_layer.py

752-752: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🪛 Flake8 (7.2.0)
deepmd/pd/model/descriptor/repflow_layer.py

[error] 752-752: local variable 'nall' is assigned to but never used

(F841)

⏰ Context from checks skipped due to timeout of 90000ms (21)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (13)
deepmd/pd/model/descriptor/repflows.py (2)

518-519: LGTM: Tensor shape initialization updated to match new indexing conventions.

The initialization of edge_index and angle_index has been correctly updated from shapes [1, 3] to [2, 1] and [3, 1] respectively, which aligns with the new tensor dimension conventions where these tensors are structured as [2, n_edge] and [3, n_angle].


570-570: LGTM: Indexing updated to match new tensor shape convention.

The change from edge_index[:, 0] (column-wise indexing) to edge_index[0] (row-wise indexing) is consistent with the new tensor shape where edge_index has dimensions [2, n_edge] instead of [n_edge, 2].

deepmd/pd/model/network/utils.py (5)

32-43: LGTM: Optimized bin_count computation for better performance.

This optimization computes bin_count only when needed (num_owner is None or averaging is requested), which can improve performance in cases where bincount computation is expensive and unnecessary.


46-50: LGTM: Improved tensor initialization and assertion for safety.

The output tensor initialization now consistently uses num_owner for the first dimension, and the assertion ensures bin_count is not None before division when averaging is requested, preventing potential runtime errors.


59-59: LGTM: Added use_loc_mapping parameter for flexible index calculation.

The new parameter provides control over how frame_shift is computed, allowing the function to work with different indexing schemes based on whether local mapping is used.


109-111: LGTM: Frame shift calculation adapted for different mapping modes.

The conditional logic correctly adjusts the frame shift calculation based on use_loc_mapping, using either nall or nloc as the multiplier, which maintains proper indexing behavior across different execution modes.


140-143: LGTM: Tensor stacking changes align with new indexing conventions.

The change from concatenation to stacking transforms the output tensor shapes from [n_edge, 2] and [n_angle, 3] to [2, n_edge] and [3, n_angle] respectively, which standardizes the indexing convention across the codebase.

deepmd/pd/model/descriptor/repflow_layer.py (6)

375-377: LGTM: Improved code clarity with explicit tensor operations.

The change from ellipsis-based indexing to explicit .unsqueeze() calls makes the tensor operations more explicit and easier to understand, improving code readability.


588-592: LGTM: Standardized tensor reshaping with list arguments.

Using list arguments for reshape() instead of tuple unpacking is more explicit and consistent with modern tensor operation practices.


703-704: LGTM: Documentation updated to reflect new tensor shapes.

The function signature documentation correctly reflects the new tensor shapes 2 x n_edge and 3 x n_angle, maintaining consistency with the implementation changes.


764-769: LGTM: Updated tensor indexing to match new shape conventions.

The change from column-wise indexing (e.g., edge_index[:, 0]) to row-wise indexing (e.g., edge_index[0]) correctly adapts to the new tensor shapes where edge_index has dimensions [2, n_edge] and angle_index has dimensions [3, n_angle].


900-900: LGTM: Consistent use of explicit reshaping.

The change to use list arguments in reshape() maintains consistency with the other tensor operations in the codebase.


1046-1049: LGTM: Explicit tensor operations improve readability.

The use of explicit .unsqueeze() calls instead of implicit broadcasting makes the tensor operations more transparent and easier to debug.

@codecov
Copy link

codecov bot commented Jun 30, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.71%. Comparing base (c151e04) to head (52f3b37).
⚠️ Report is 81 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/pd/model/network/utils.py 57.14% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4828      +/-   ##
==========================================
+ Coverage   84.57%   84.71%   +0.14%     
==========================================
  Files         699      699              
  Lines       68070    68074       +4     
  Branches     3541     3541              
==========================================
+ Hits        57567    57666      +99     
+ Misses       9369     9273      -96     
- Partials     1134     1135       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@HydrogenSulfate HydrogenSulfate changed the title [WIP] pd: support pd dynamic shape [WIP] pd: support dpa3 dynamic shape for pd backend Jun 30, 2025
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pd/model/descriptor/repflow_layer.py (1)

768-768: Remove unused variable assignment.

The variable nall is assigned but never used in the function. This creates unnecessary overhead and can be confusing.

Apply this diff to remove the unused variable:

-        nall = node_ebd_ext.shape[1]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dd70e99 and 52f3b37.

📒 Files selected for processing (4)
  • deepmd/pd/model/descriptor/repflow_layer.py (9 hunks)
  • deepmd/pd/model/descriptor/repflows.py (3 hunks)
  • deepmd/pd/model/network/utils.py (5 hunks)
  • source/tests/pd/model/test_dynamic_sel.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/pd/model/descriptor/repflows.py
  • deepmd/pd/model/network/utils.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
source/tests/pd/model/test_dynamic_sel.py (2)
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
🪛 Ruff (0.11.9)
deepmd/pd/model/descriptor/repflow_layer.py

768-768: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (30)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
🔇 Additional comments (12)
source/tests/pd/model/test_dynamic_sel.py (1)

1-156: Well-structured test for dynamic selection consistency validation.

The test comprehensively validates that enabling dynamic selection (use_dynamic_sel=True) produces numerically equivalent results to the standard implementation. The parameter combinations cover key configuration options, and the tolerance settings are appropriately configured for different precision types.

The test methodology is sound:

  1. Creates two identical descriptor instances with different dynamic selection settings
  2. Uses the same random seed for reproducibility
  3. Compares outputs using appropriate numerical tolerances
  4. Tests multiple precision types and parameter combinations
deepmd/pd/model/descriptor/repflow_layer.py (11)

391-391: Improved tensor broadcasting clarity.

The change from ellipsis-based indexing to explicit .unsqueeze() calls improves code readability and makes the broadcasting operation more explicit and clear.


605-608: Improved tensor reshaping with explicit list argument.

The change to use a list argument [nf * nloc, sub_node_update.shape[-1]] instead of implicit tuple unpacking improves code clarity and follows best practices for tensor reshaping operations.


687-690: Consistent tensor reshaping improvement.

Good consistency with the previous change, using explicit list arguments for reshape operations.


696-699: Consistent tensor reshaping pattern.

Maintains the same improved pattern of using explicit list arguments for tensor reshaping operations.


719-720: Updated parameter documentation to reflect new tensor shapes.

The documentation correctly reflects the change from column-wise indexing (n_edge x 2) to row-wise indexing (2 x n_edge) for edge_index and (3 x n_angle) for angle_index.


745-756: Comprehensive documentation update for new indexing convention.

The parameter documentation has been thoroughly updated to reflect the new tensor shapes and indexing patterns. This maintains consistency with the implementation changes.


767-767: Extract shape information for dynamic selection logic.

The unpacking of nb, nloc, nnei from nlist.shape provides the necessary dimensions for the dynamic selection logic that follows.


775-777: Proper n_edge computation for dynamic selection.

The logic correctly sets n_edge = None when dynamic selection is disabled and n_edge = h2.shape[0] when enabled. This aligns with the different data structures used in each mode.


780-785: Updated indexing to match new tensor shapes.

The changes from edge_index[:, 0] to edge_index[0] and similar updates for angle_index correctly reflect the new tensor shapes where the first dimension represents the index type rather than the second dimension.


916-916: Consistent reshape operation improvement.

The change to use a list argument in the reshape operation maintains consistency with the earlier improvements in the file.


1062-1065: Explicit tensor operations for switch function application.

The replacement of implicit broadcasting with explicit .unsqueeze() calls makes the tensor operations more readable and explicit, improving code maintainability.

@HydrogenSulfate HydrogenSulfate changed the title [WIP] pd: support dpa3 dynamic shape for pd backend pd: support dpa3 dynamic shape for pd backend Jul 9, 2025
@njzjz njzjz requested a review from caic99 July 10, 2025 13:40
@njzjz njzjz enabled auto-merge July 10, 2025 14:29
@njzjz njzjz added this pull request to the merge queue Jul 10, 2025
Merged via the queue into deepmodeling:devel with commit 1eefc8e Jul 10, 2025
60 checks passed
@HydrogenSulfate HydrogenSulfate deleted the support_pd_dynamic_shape branch July 23, 2025 14:07
This was referenced Aug 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants