Skip to content

Conversation

@Difers
Copy link
Contributor

@Difers Difers commented Aug 11, 2025

PR Category

Operator Mechanism

PR Types

New features

Description

Pcard-73145

添加paddle.narrow

  • torch下的参考实现:
    torch是直接调用其slice算子,slice是调用了as_strided返回视图,保证返回的tensor与输入tesnor共享内存 https://github.com/pytorch/pytorch/blob/ee1b0412b919dfb358d5a697b3be49621497fbc2/aten/src/ATen/native/TensorShape.cpp#L1665
  • paddle的slice现在基本会调用stride版的,也会共享内存,但尚未清晰slice在什么情况下会退化到原版返回新tensor,因此采用了as_strided进行实现。
  • as_strided与torch有部分差异,torch offset是元素数量,paddle是byte数,此外,其反向kernel的CheckStride会有除0错误,进行了修复
  • 在测试中加入了输入输出是否共享的检测,通过修改narrow后的某个元素值,查看输入是否同步变化进行检测
  • 此外,as_strided还剩有0-size问题,暂无人力修复,相关参考资料:maximum minimum support 0-size #71829
  • xpu中as_strided似乎还有问题

@paddle-bot
Copy link

paddle-bot bot commented Aug 11, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@yuanlehome yuanlehome self-assigned this Aug 11, 2025
@codecov-commenter
Copy link

codecov-commenter commented Aug 11, 2025

Codecov Report

❌ Patch coverage is 96.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@8b5c009). Learn more about missing BASE report.

Files with missing lines Patch % Lines
python/paddle/tensor/manipulation.py 96.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #74546   +/-   ##
==========================================
  Coverage           ?   96.00%           
==========================================
  Files              ?        1           
  Lines              ?       25           
  Branches           ?        0           
==========================================
  Hits               ?       24           
  Misses             ?        1           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Difers Difers force-pushed the add_narrow branch 3 times, most recently from 3f9c7b8 to f16d8d8 Compare August 15, 2025 06:30
def narrow(
input: Tensor,
dim: int,
start: Sequence[int | Tensor],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

类型注解这里,是 int | Tensor

[8, 6]
>>> # the stride is [6, 1].
"""
offset *= paddle.core.size_of_dtype(x.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里改动应该会不兼容,可以把这行改动放到 paddle.narrow 里面去

zhwesky2010
zhwesky2010 previously approved these changes Aug 19, 2025
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 requested a review from SigureMo August 19, 2025 11:45
SigureMo
SigureMo previously approved these changes Aug 19, 2025
XiaoguangHu01
XiaoguangHu01 previously approved these changes Aug 20, 2025
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

zhwesky2010
zhwesky2010 previously approved these changes Aug 21, 2025
SigureMo
SigureMo previously approved these changes Aug 21, 2025
@Difers Difers dismissed stale reviews from SigureMo and zhwesky2010 via 51e0a30 August 21, 2025 05:46
@Difers Difers force-pushed the add_narrow branch 2 times, most recently from 51e0a30 to 31f0973 Compare August 21, 2025 06:35
@Difers
Copy link
Contributor Author

Difers commented Aug 21, 2025

/re-run all-failed

1 similar comment
@Difers
Copy link
Contributor Author

Difers commented Aug 21, 2025

/re-run all-failed

@Difers
Copy link
Contributor Author

Difers commented Aug 22, 2025

/re-run all-failed

1 similar comment
@Difers
Copy link
Contributor Author

Difers commented Aug 22, 2025

/re-run all-failed

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for skipif

@zhwesky2010 zhwesky2010 merged commit 2ae2249 into PaddlePaddle:develop Aug 25, 2025
87 of 97 checks passed
@ooooo-create
Copy link
Contributor

@Difers as_strided 0-size 这个 pr 在修了 #74860

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants