Fix ernie ci auto trainer error #9758
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           develop    #9758     +/-   ##
===========================================
+ Coverage    51.12%   52.37%   +1.24%
===========================================
  Files          732      730       -2
  Lines       118947   115249    -3698
===========================================
- Hits         60814    60363     -451
+ Misses       58133    54886    -3247
paddlenlp/trainer/auto_trainer.py
Outdated
auto_dist_degree = {
    "tensor_parallel": training_args.tensor_parallel_degree > 1,
    "sequence_parallel": sequence_parallel,
`sequence_parallel` has already been moved into `training_args`, so use `"sequence_parallel": training_args.sequence_parallel,` directly. Is line 110 above still needed?
done
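The change agreed on above amounts to reading the flag straight from `training_args` when the dict is built. A minimal sketch, assuming a stand-in arguments object (a `SimpleNamespace`, not PaddleNLP's real training arguments class) and a hypothetical helper name `build_auto_dist_degree`:

```python
from types import SimpleNamespace


def build_auto_dist_degree(training_args):
    """Hypothetical helper: derive the parallelism switches from training_args only."""
    return {
        "tensor_parallel": training_args.tensor_parallel_degree > 1,
        # Read the flag directly instead of keeping a separate local variable.
        "sequence_parallel": training_args.sequence_parallel,
    }


# Stand-in for the real training arguments object (assumption for illustration).
args = SimpleNamespace(tensor_parallel_degree=2, sequence_parallel=True)
auto_dist_degree = build_auto_dist_degree(args)
```

With a single source for the flag, the dict and the model config cannot disagree about whether sequence parallelism is enabled.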
)
# NOTE(zhangwl): in pipeline mode, a param may already have been initialized and had its init_func deleted, yet still not report is_initialized
if not param._is_initialized() and param._init_func is not None:
    param.initialize()
If `param._init_func` is not None, should this use `param._init_func()` or `model._init_weights(Layer)`?
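The two-part guard in the snippet under review can be illustrated in isolation. The sketch below is only an assumption-laden model of the behaviour: `FakeParam` and `initialize_pending` are hypothetical stand-ins and do not reproduce Paddle's actual parameter class or answer which initializer should be called.

```python
class FakeParam:
    """Stand-in for a framework parameter (assumption; not Paddle's real class)."""

    def __init__(self, init_func=None, initialized=False):
        self._init_func = init_func
        self._initialized = initialized
        self.value = None

    def _is_initialized(self):
        return self._initialized

    def initialize(self):
        # Run the deferred init function once, then drop it.
        self.value = self._init_func()
        self._initialized = True
        self._init_func = None


def initialize_pending(params):
    """Initialize only params that are uninitialized AND still hold an init function."""
    touched = []
    for name, param in params.items():
        if not param._is_initialized() and param._init_func is not None:
            param.initialize()
            touched.append(name)
    return touched


params = {
    "lazy": FakeParam(init_func=lambda: 0.0),  # deferred init still pending
    "orphan": FakeParam(init_func=None),       # init function already removed; skipped
    "ready": FakeParam(initialized=True),      # already initialized; skipped
}
touched = initialize_pending(params)
```

The second half of the condition is what keeps the "orphan" case from failing: a parameter whose init function was already removed is left alone rather than re-initialized.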
paddlenlp/trainer/auto_trainer.py
Outdated
if (
    kwargs.get("args", None) is not None
    and kwargs["args"].use_intermediate_api
    and not parallelize.has_parallelized_model
Move the `not parallelize.has_parallelized_model` check inside this branch. The condition here should only decide whether to use the basic API or the intermediate API.
done
paddlenlp/trainer/auto_trainer.py
Outdated
    and kwargs["args"].use_intermediate_api
    and not parallelize.has_parallelized_model
):
    if self.auto_dist_config is not None:
We can check `parallelize.has_parallelized_model` here; if it is true, there must be an `auto_dist_config`.
done
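Taken together, the two review comments on this branch suggest a nested structure: the outer condition chooses between the basic and intermediate API, and the parallelization state is examined only inside the intermediate branch. A rough sketch under stated assumptions: `select_api` and its return strings are hypothetical, the arguments object is a stand-in, and the `auto_dist_config` requirement follows one literal reading of the comment rather than the merged code.

```python
from types import SimpleNamespace


def select_api(args, has_parallelized_model, auto_dist_config=None):
    """Hypothetical decision helper mirroring the suggested branch structure."""
    # Outer check: basic API vs. intermediate API only.
    if args is None or not args.use_intermediate_api:
        return "basic"
    # Inner check: has the model already been parallelized?
    if not has_parallelized_model:
        return "intermediate: parallelize now"
    # One reading of the review comment: an already-parallelized model
    # must come with an auto_dist_config.
    if auto_dist_config is None:
        raise ValueError("auto_dist_config is required for an already-parallelized model")
    return "intermediate: reuse existing parallelization"


intermediate_args = SimpleNamespace(use_intermediate_api=True)
basic_args = SimpleNamespace(use_intermediate_api=False)

first = select_api(basic_args, has_parallelized_model=False)
second = select_api(intermediate_args, has_parallelized_model=False)
third = select_api(intermediate_args, has_parallelized_model=True, auto_dist_config={})
```

Keeping the API choice and the parallelization state in separate checks makes each failure mode explicit, instead of silently falling through to the basic path when the model was already parallelized.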
@@ -495,7 +494,7 @@ def main():
config.recompute_granularity = model_args.recompute_granularity
config.virtual_pp_degree = model_args.virtual_pp_degree
config.sequence_parallel = training_args.sequence_parallel
config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce
The Qwen model doesn't have this problem, right?
LGTM
PR types
Bug fixes
PR changes
Others
Description
Fix the Ernie CI error.
Allow users to control when `param_init` runs and when the intermediate API is applied.