
[PIR-Auto-Parallel]refactor recompute pass in PIR mode #69681


Merged

merged 12 commits on Dec 5, 2024

Conversation

waliwali777
Contributor

@waliwali777 waliwali777 commented Nov 25, 2024

PR Category

Auto Parallel

PR Types

Performance

Description

Refactor the recompute pass based on PIR.
(Refer to the recompute implementation under the old IR: #38920)
PCard-88114


paddle-bot bot commented Nov 25, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@waliwali777 waliwali777 force-pushed the refactor_recomput_pass branch from d40ff3c to 2f7cdb3 on November 29, 2024 14:23
        return segment_num + 1, rc_op_num

    def run_test_cases(self):
        self.strategy._recompute.enable = False
Contributor

Why is strategy.recompute used in the pass but strategy._recompute used here?

Contributor Author

strategy.recompute (in the pass) is the recompute attribute of the class paddle.distributed.Strategy().
strategy._recompute (in the engine) is the _recompute attribute of the class paddle.distributed.fleet.auto.Strategy().
Both classes describe the recompute configuration, and one can be converted to the other in the to_static function.
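The relationship described above can be illustrated with a minimal sketch. This is hypothetical and heavily simplified (the class and field names below stand in for the two Strategy classes; they are not Paddle's real API): a user-facing config object is copied into the pass-facing config object during a to_static-style conversion.

```python
# Hypothetical sketch: two config classes that both describe recompute
# settings, plus a to_static-style conversion between them. Names are
# illustrative stand-ins, NOT the real Paddle classes.
from dataclasses import dataclass, field


@dataclass
class RecomputeConfig:
    # stands in for paddle.distributed.Strategy().recompute
    enable: bool = False
    checkpoints: list = field(default_factory=list)


@dataclass
class FleetRecomputeConfig:
    # stands in for paddle.distributed.fleet.auto.Strategy()._recompute
    enable: bool = False
    checkpoints: list = field(default_factory=list)


def convert(fleet_cfg: FleetRecomputeConfig) -> RecomputeConfig:
    # to_static-style step: copy the engine-side config into the
    # pass-side config so the pass only ever reads one class
    return RecomputeConfig(
        enable=fleet_cfg.enable,
        checkpoints=list(fleet_cfg.checkpoints),
    )


cfg = convert(FleetRecomputeConfig(enable=True, checkpoints=["seg0"]))
print(cfg.enable)  # True
```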

assert (
    base_segment_num < segment_num_1
    and segment_num_1 < segment_num_2
    and segment_num_2 < segment_num_3
Contributor

Please check the results more precisely, e.g.:
assert base_segment_num == XX
assert op_num == XX
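The review suggestion above can be sketched as a small helper. The numbers and the function name are illustrative placeholders, not values or code from this PR: the point is that exact-value assertions pinpoint regressions better than relative-ordering assertions.

```python
# Sketch of the review suggestion: compare against exact expected values
# instead of only asserting an ordering. All numbers are illustrative.
def check_recompute_stats(segment_num, rc_op_num,
                          expected_segments, expected_rc_ops):
    # exact-value checks fail loudly at the precise quantity that drifted
    assert segment_num == expected_segments, (
        f"segment_num {segment_num} != expected {expected_segments}"
    )
    assert rc_op_num == expected_rc_ops, (
        f"rc_op_num {rc_op_num} != expected {expected_rc_ops}"
    )


# passes: both measured values match the expectations
check_recompute_stats(4, 12, expected_segments=4, expected_rc_ops=12)
```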

Contributor Author

done

rc_end_id = len(block.ops)
for idx in range(rc_begin_id, rc_end_id):
    rc_op = block.ops[idx]
    rc_op.set_int_attr("fwd_recompute_id", _g_recompute_idx)
Contributor

Why use the prefix "fwd"?

Contributor Author

fwd_recompute_id marks the checkpointed ops in the forward pass that need to be recomputed.
bwd_recompute_id corresponds to fwd_recompute_id and is newly added to the recomputed ops in the backward pass, which facilitates debugging.
There are comments in the code that provide explanations.
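The id pairing described above can be sketched with a toy op model. This is a hypothetical simplification, not the real PIR API (the `Op` class and `recompute_segment` helper are invented for illustration): forward checkpoint ops are tagged with `fwd_recompute_id`, and the clones inserted into the backward get the matching `bwd_recompute_id`.

```python
# Minimal sketch (hypothetical Op model, NOT the real PIR API) of the
# fwd_recompute_id / bwd_recompute_id pairing: each forward checkpoint
# segment and its recomputed backward clones share one recompute index,
# which makes the pairs easy to find when debugging.
class Op:
    def __init__(self, name):
        self.name = name
        self.attrs = {}

    def set_int_attr(self, key, value):
        self.attrs[key] = value


def recompute_segment(fwd_ops, recompute_idx):
    # tag the forward checkpoint segment
    for op in fwd_ops:
        op.set_int_attr("fwd_recompute_id", recompute_idx)
    # clone the segment into the backward and tag with the matching id
    bwd_clones = [Op(op.name) for op in fwd_ops]
    for op in bwd_clones:
        op.set_int_attr("bwd_recompute_id", recompute_idx)
    return bwd_clones


fwd = [Op("matmul"), Op("relu")]
bwd = recompute_segment(fwd, recompute_idx=0)
print(bwd[0].attrs)  # {'bwd_recompute_id': 0}
```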

Contributor

@liym27 liym27 left a comment


LGTM
