add nonzero in ops_infer_shape_in_runtime #69027

Merged
merged 8 commits into from
Oct 31, 2024

@fxfxfxfxfxfxfxfx fxfxfxfxfxfxfxfx commented Oct 29, 2024

PR Category

Auto Parallel

PR Types

Bug fixes

Description

In a distributed scenario, the output shape of paddle.nonzero is incorrect: its first dimension is reported as -1. The shape of the nonzero output is only known after the computation has run, but the framework did not include nonzero in ops_infer_shape_in_runtime. This PR adds it to that list. The following code reproduces the problem.

import paddle
import paddle.distributed as dist

x = paddle.to_tensor([1, 0, 1, 1])
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# Replicate x on every rank of the mesh.
x = dist.shard_tensor(x, mesh, [dist.Replicate()])
out = paddle.nonzero(x)
# Before this fix, the first dimension printed here is -1.
print(out.shape)
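
To illustrate why the shape of nonzero cannot be fixed ahead of time, here is a minimal pure-Python sketch. It does not use Paddle, and the helper names static_nonzero_shape and runtime_nonzero_shape are hypothetical, introduced only for this illustration. It assumes a 1-D input, for which the result has one row per nonzero element and one column.

```python
def static_nonzero_shape(input_shape):
    """Shape available before execution (hypothetical helper).

    The number of rows depends on the data, so it is recorded as -1.
    The number of columns equals the rank of the input.
    """
    return [-1, len(input_shape)]


def runtime_nonzero_shape(values):
    """Shape available only after execution, for a 1-D Python list."""
    indices = [[i] for i, v in enumerate(values) if v != 0]
    return [len(indices), 1], indices


values = [1, 0, 1, 1]
static_shape = static_nonzero_shape([len(values)])
runtime_shape, indices = runtime_nonzero_shape(values)
print(static_shape)   # row count unknown before the data is inspected
print(runtime_shape)  # row count known once the data has been inspected
```

The -1 in the reproduction above corresponds to the first case: the statically inferred shape is reported without being refreshed after the computation.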

Also added a unit test for MoE layers with shared experts, covering the case where multiple experts are held on a single device.

paddle-bot bot commented Oct 29, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Oct 29, 2024

@pkuzyc pkuzyc left a comment

Please update the PR description; the current description does not match the contents of the PR.

num_or_sections=self.config.num_experts
// self.config.num_devices,
axis=0,
)[j]
The split can be moved outside the loop; there is no need to split on every iteration.
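
A minimal sketch of the suggested change, in pure Python rather than Paddle. The helper split_even is a hypothetical stand-in for an even split along axis 0, and the sizes, the weights list, and the loop bounds are assumptions made for illustration; the actual loop in the PR is not shown on this page.

```python
def split_even(items, sections):
    """Hypothetical stand-in for an even split along axis 0."""
    size = len(items) // sections
    return [items[k * size:(k + 1) * size] for k in range(sections)]


# Assumed sizes, chosen only for illustration.
num_experts, num_devices = 8, 2
sections = num_experts // num_devices
weights = list(range(16))  # stand-in for the tensor being split

# Before: the split is recomputed on every iteration, then indexed with j.
per_iteration = [split_even(weights, sections)[j] for j in range(sections)]

# After: split once outside the loop and only index inside it.
chunks = split_even(weights, sections)
hoisted = [chunks[j] for j in range(sections)]
```

Both versions produce the same pieces; the second performs a single split instead of one per iteration.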

Thank you for the code review. I have made the changes as suggested.

@pkuzyc pkuzyc left a comment

LGTM

@luotao1 luotao1 merged commit a36943a into PaddlePaddle:develop Oct 31, 2024
28 checks passed
luotao1 commented Oct 31, 2024

hi, @fxfxfxfxfxfxfxfx

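  • Thank you very much for your contribution to PaddlePaddle. We run an organization called PFCC, which continues to contribute to PaddlePaddle through regular technical sharing sessions and developer-led tasks. See the profile page at https://github.com/luotao1 for details.
  • If you are interested in PFCC, please send an email to ext_paddle_oss@baidu.com and we will invite you to join.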

4 participants