[AutoParallel] Recompute Pass #38920
Conversation
Thanks for your contribution!
def init(self):
    if paddle.is_compiled_with_cuda():
        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
    self.rtol = 1e-5
Recompute should be able to achieve bit-wise precision alignment, right?
It can mostly achieve bit-wise alignment, but in an occasional few steps the results start to diverge at the 6th decimal place, with an error around 1e-6.
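A minimal sketch of the kind of tolerance check this implies (the loss arrays and their values are hypothetical, not from the PR; only rtol=1e-5 and the 1e-6 drift come from the thread):

import numpy as np

# Hypothetical per-step losses from the baseline run and the recompute run.
# With FLAGS_cudnn_deterministic enabled they agree bit-wise on most steps;
# rtol=1e-5 absorbs the occasional ~1e-6 drift mentioned above.
loss_base = np.array([0.693147, 0.687512, 0.682040])
loss_recompute = np.array([0.693147, 0.687513, 0.682040])
assert np.allclose(loss_recompute, loss_base, rtol=1e-5)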
 optimizer = paddle.fluid.optimizer.AdamOptimizer(
     learning_rate=0.00001,
     beta1=0.9,
     beta2=0.999,
     epsilon=1e-08,
-    grad_clip=clip)
+    grad_clip=None)
Why isn't clip supported?
It is supported now.
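For reference, a hedged sketch of attaching a global-norm clip back to this optimizer (the clip_norm value is illustrative, not from the PR):

import paddle

# Clip gradients by global norm before the Adam update; clip_norm=1.0
# is an illustrative value, not taken from the test.
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
    learning_rate=0.00001,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    grad_clip=clip)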
LGTM
@@ -190,7 +193,7 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
     # serial forward pass
     self._apply_pre_optimization_passed(completed_main_program,
Rename _apply_pre_optimization_passed to _apply_pre_optimization_passes and _apply_post_optimization_passed to _apply_post_optimization_passes.
done.
@@ -26,6 +26,9 @@
 from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
 from .process_group import new_process_group, ProcessGroup, _g_process_group_map

+# NOTE: If op in SPECIAL_OPS, it will not be resharded.
+SPECIAL_OPS = ['check_finite_and_unscale', 'update_loss_scaling']
Global variables should use the _g_xxxx naming convention. Please rename SPECIAL_OPS to _g_special_ops.
done.
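A minimal sketch of the module after the rename (the _is_special_op helper is hypothetical, added only to illustrate how the reshard pass might consult the list):

# Module-level globals in auto_parallel follow the _g_xxx convention.
# NOTE: If an op is in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']

def _is_special_op(op):
    # Hypothetical helper: the reshard pass would skip ops whose type
    # appears in _g_special_ops.
    return op.type in _g_special_ops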
        self._ops = ops
        self.var_op_deps = {}

    def build_stats(self):
Is build_stats meant to be build_state?
This function inherits from ProgramStats in backward.py.
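A simplified sketch of the kind of bookkeeping build_stats performs in ProgramStats (the var_as_input_ops / var_as_output_ops key names mirror backward.py; the exact body here is illustrative, not copied from it):

    def build_stats(self):
        # For each op, record which variables it reads and writes, so
        # recompute segments can later be cut along these dependencies.
        for i, op in enumerate(self._ops):
            for name in op.desc.input_arg_names():
                deps = self.var_op_deps.setdefault(
                    name, {"var_as_input_ops": [], "var_as_output_ops": []})
                deps["var_as_input_ops"].append(i)
            for name in op.desc.output_arg_names():
                deps = self.var_op_deps.setdefault(
                    name, {"var_as_input_ops": [], "var_as_output_ops": []})
                deps["var_as_output_ops"].append(i)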
        return segments

    def modify_forward_desc_for_recompute(self, dist_context):
Please add more comments. For example, what is the purpose of modify_forward_desc_for_recompute?
done.
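For context, a hedged docstring sketch of what this method is typically responsible for in Paddle's recompute machinery (the wording is an assumption, not taken from the PR):

    def modify_forward_desc_for_recompute(self, dist_context):
        """Prepare the forward program so it can be replayed deterministically.

        For ops with randomness inside a recompute segment (e.g. dropout),
        insert a seed op ahead of them so the original forward pass and the
        recomputed forward pass draw the same random numbers.
        """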
LGTM.
PR types: New features
PR changes: Others
Describe: