@HydrogenSulfate HydrogenSulfate commented Jun 25, 2025

PR Category

Operator Mechanism

PR Types

New features

Description

Pcard-75624

Add dynamic-shape support to the static-graph composite operator for masked_select_grad.
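
For readers unfamiliar with the rule being decomposed, the gradient of `masked_select` with respect to `x` can be sketched in NumPy as below. This is an illustration only, not Paddle's composite implementation: the helper names `masked_select` and `masked_select_grad` are hypothetical, and the sketch assumes `x` and `mask` broadcast to a common shape that is only known at run time.

```python
import numpy as np


def masked_select(x, mask):
    """Forward: broadcast x and mask together, keep entries where mask is True."""
    xb, mb = np.broadcast_arrays(x, mask)
    return xb[mb]


def masked_select_grad(x_shape, mask, dy):
    """Backward sketch: scatter dy into the broadcast shape, then sum-reduce to x_shape."""
    out_shape = np.broadcast_shapes(x_shape, mask.shape)
    mb = np.broadcast_to(mask, out_shape)
    dx_full = np.zeros(out_shape, dtype=dy.dtype)
    dx_full[mb] = dy  # row-major order, same order as the forward selection
    # Sum over the leading axes that broadcasting prepended to x.
    extra = len(out_shape) - len(x_shape)
    dx = dx_full.sum(axis=tuple(range(extra)))
    # Sum over axes where x had size 1 but the broadcast output did not.
    stretched = tuple(
        i for i, n in enumerate(x_shape) if n == 1 and dx.shape[i] != 1
    )
    if stretched:
        dx = dx.sum(axis=stretched, keepdims=True)
    return dx


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 2))
mask = rng.integers(0, 2, size=(5, 1, 3, 1)).astype(bool)
y = masked_select(x, mask)
dy = rng.standard_normal(y.shape)
dx = masked_select_grad(x.shape, mask, dy)
```

The length of `y` depends on how many mask entries are True, so it cannot be known from the input shapes alone. That is why the output (and `dy`) has a data-dependent size and the gradient rule has to work with dynamic shapes.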

Accuracy verification code, compared against torch:

import paddle
from paddle.static import InputSpec
import torch
import numpy as np

# Enable composite (prim) operator decomposition so the composite rule under test is used
paddle.framework.core._set_prim_all_enabled(True)


for i in range(50):
    # paddle
    x_pd = paddle.randn([1024, 13, 14])
    x_pd.stop_gradient = False
    mask_pd = paddle.randint(low=0, high=2, shape=[7, 1, 13, 1]).astype("bool")
    with paddle.no_grad():
        t = paddle.masked_select(x_pd, mask_pd)
    dy_pd = paddle.randn(t.shape)
    dy_pd.stop_gradient = False
    ddx_pd = paddle.randn(x_pd.shape)
    ddx_pd.stop_gradient = False

    def g(x_pd_, mask_pd_, dy_pd, ddx_pd):
        y_pd = paddle.masked_select(x_pd_, mask_pd_)
        dx_pd, = paddle.grad(y_pd, x_pd_, dy_pd, create_graph=True)
        ddy_pd, = paddle.grad(dx_pd, dy_pd, ddx_pd, create_graph=True)
        return y_pd, dx_pd, ddy_pd

    g = paddle.jit.to_static(
        g,
        full_graph=True,
        input_spec=[
            InputSpec(shape=[-1, -1, -1], name="x"),
            InputSpec(shape=[-1, -1, -1, -1], name="mask"),
            InputSpec(shape=[-1], name="dy_pd"),
            InputSpec(shape=[-1, -1, -1], name="ddx_pd"),
        ],
    )
    y_pd, dx_pd, ddy_pd = g(x_pd, mask_pd, dy_pd, ddx_pd)
    # y_pd, dx_pd = g(x_pd, mask_pd, dy_pd)

    # torch
    x_pt = torch.from_dlpack(x_pd.detach()).requires_grad_(True)
    mask_pt = torch.from_dlpack(mask_pd)
    y_pt = torch.masked_select(x_pt, mask_pt)

    dy_pt = torch.from_dlpack(dy_pd.detach()).requires_grad_(True)
    dx_pt, = torch.autograd.grad(y_pt, x_pt, dy_pt, create_graph=True)

    ddx_pt = torch.from_dlpack(ddx_pd.detach()).requires_grad_(True)
    ddy_pt, = torch.autograd.grad(dx_pt, dy_pt, ddx_pt, create_graph=True)

    np.testing.assert_allclose(y_pd.numpy(), y_pt.detach().cpu().numpy(), rtol=1e-6, atol=1e-6)
    np.testing.assert_allclose(dx_pd.numpy(), dx_pt.detach().cpu().numpy(), rtol=1e-6, atol=1e-6)
    np.testing.assert_allclose(ddy_pd.numpy(), ddy_pt.detach().cpu().numpy(), rtol=1e-6, atol=1e-6)
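
The second-order check above relies on `dx` being linear in `dy`, so differentiating `dx` with respect to `dy` against `ddx` should reduce to a masked select of `ddx`. The NumPy sketch below illustrates that relation on a small example. It is not taken from the PR: the shapes and the helper `dx_of` are assumptions chosen for illustration, and the mask is assumed to add exactly one leading axis.

```python
import numpy as np

rng = np.random.default_rng(0)
x_shape = (4, 3, 2)
mask = rng.integers(0, 2, size=(5, 1, 3, 1)).astype(bool)
# Mask broadcast to the common shape of x and mask.
full = np.broadcast_to(mask, np.broadcast_shapes(x_shape, mask.shape))
n_sel = int(full.sum())


def dx_of(dy):
    """First-order gradient dx as a (linear) function of dy."""
    buf = np.zeros(full.shape)
    buf[full] = dy
    return buf.sum(axis=0)  # reduce the one leading axis added by the mask


ddx = rng.standard_normal(x_shape)

# Closed form: the double gradient is a masked select of the broadcast ddx.
ddy = np.broadcast_to(ddx, full.shape)[full]

# Reference from the definition: ddy[k] = <ddx, d(dx)/d(dy[k])>,
# evaluated one basis vector of dy at a time.
basis = np.eye(n_sel)
ddy_ref = np.array([np.sum(ddx * dx_of(basis[k])) for k in range(n_sel)])
```

Here `ddy` and `ddy_ref` should agree up to floating-point error, which mirrors what the `ddy_pd` versus `ddy_pt` comparison checks between the two frameworks.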

@HydrogenSulfate HydrogenSulfate merged commit 012a9f0 into PaddlePaddle:develop Jun 30, 2025
119 of 125 checks passed
@HydrogenSulfate HydrogenSulfate deleted the add_masked_select_dyshape_comp branch June 30, 2025 11:49
github-merge-queue bot pushed a commit to deepmodeling/deepmd-kit that referenced this pull request Jul 10, 2025
Support running `input_torch_dynamic.json` with the Paddle backend (including CINN).


TODO list:

- [x] PaddlePaddle/Paddle#73601
- [x] PaddlePaddle/Paddle#73622
- [x] PaddlePaddle/Paddle#73737
- [x] PaddlePaddle/Paddle#73747
- [x] PaddlePaddle/Paddle#73809
- [x] PaddlePaddle/Paddle#73761


## Summary by CodeRabbit

* **Bug Fixes**
  * Resolved issues with tensor shape and indexing consistency, preventing assertion errors during model execution.
  * Improved handling of default tensor initialization to avoid JIT assertion issues.

* **Refactor**
  * Standardized tensor dimension handling and broadcasting for improved clarity and maintainability.
  * Enhanced code readability with clearer indexing conventions and formatting.
  * Updated aggregation logic for safer and more efficient tensor operations.

* **New Features**
  * Added an option to control graph index mapping behavior for greater flexibility in advanced use cases.

* **Tests**
  * Introduced comprehensive tests validating descriptor model consistency with dynamic selection enabled.