[Semi-Auto] Add reduction spmd rule #54991
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open source project!
@@ -24,6 +25,7 @@ namespace auto_parallel {

// matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
Register with each op_name specifically: reduce_sum, sum, reduce_max, reduce_min, and mean.
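The registration suggested above (one rule class bound to several concrete op names) can be sketched with a minimal registry. This is an illustrative Python sketch, not Paddle's actual REGISTER_SPMD_RULE internals; all names here are hypothetical.

```python
# Minimal sketch of an SPMD-rule registry: one rule class can be
# registered under several concrete op names, as the review suggests.
# Names are illustrative, not Paddle's real internals.

_SPMD_RULE_REGISTRY = {}

def register_spmd_rule(op_name, rule_cls):
    # Refuse duplicate registrations so op names stay unambiguous.
    if op_name in _SPMD_RULE_REGISTRY:
        raise ValueError(f"SPMD rule for '{op_name}' already registered")
    _SPMD_RULE_REGISTRY[op_name] = rule_cls

def get_spmd_rule(op_name):
    return _SPMD_RULE_REGISTRY[op_name]

class ReductionSPMDRule:
    """Placeholder for the reduction SPMD rule class."""

# One rule class, several op names.
for name in ("reduce_sum", "sum", "reduce_max", "reduce_min", "mean"):
    register_spmd_rule(name, ReductionSPMDRule)
```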
Done
// step1: Build Einsum Notation
bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
// bool keep_dim = false;
remove it
Done
bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
// bool keep_dim = false;
std::vector<int64_t> reduce_dims =
    ExtractAttr<std::vector<int64_t>>("dim", attrs);
Where does "dim" come from? Should it be "axis"?
Phi API:
- op : max
  args : (Tensor x, IntArray axis={}, bool keepdim=false)
- op : mean
  args : (Tensor x, IntArray axis={}, bool keepdim=false)
- op : sum
  args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false)
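A side note on the `axis` attribute listed above: the `IntArray axis={}` default conventionally means "reduce over all dimensions", and negative axes count from the end. A small Python sketch of that normalization, with a hypothetical helper name:

```python
def normalize_reduce_axes(axes, ndim):
    """Normalize a reduce op's 'axis' attribute (illustrative sketch).

    An empty list (the IntArray axis={} default) is taken to mean
    "reduce over all dimensions"; negative axes count from the end.
    """
    if not axes:
        return list(range(ndim))
    return sorted(a + ndim if a < 0 else a for a in axes)
```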
Modified to "axis" now; "dim" is the attribute name in static mode.
std::vector<TensorDistAttr> new_input_dist_attrs;
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_dist_attr);
// step2.3: update the input dist_attr if reshard is needed. When
The replicate logic for the reduce axis is wrong.
If the op is linear (e.g. sum, all, mean) and the reduce dim is sharded, there is no need to reshard the reduce axis as replicated; we just need to mark that axis of the output tensor as "Partial" on the sharded mesh dim.
If the op is non-linear (e.g. variance), the replicate logic is needed.
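The rule described in this comment can be sketched as follows. This is an illustrative Python sketch, not the PR's C++ code; `dims_mapping` follows the convention that each tensor axis maps to a mesh dim, with -1 meaning replicated, and the helper name is hypothetical.

```python
def plan_reduce_axis(dims_mapping, reduce_axes, linear):
    """Sketch of the review's rule for sharded reduce axes.

    Linear ops (sum, mean, all, ...) keep the input sharding and mark
    the output Partial on the mesh dims that shard a reduce axis;
    non-linear ops (e.g. variance) must first reshard those axes to
    replicated instead.
    Returns (possibly resharded input dims_mapping, partial mesh dims).
    """
    input_mapping = list(dims_mapping)
    partial_on = []
    for axis in reduce_axes:
        mesh_dim = dims_mapping[axis]
        if mesh_dim == -1:
            continue  # axis already replicated, nothing to do
        if linear:
            partial_on.append(mesh_dim)   # output is Partial on mesh_dim
        else:
            input_mapping[axis] = -1      # reshard input to replicated
    return input_mapping, partial_on
```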
Done
bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
// bool keep_dim = false;
std::vector<int64_t> reduce_dims =
    ExtractAttr<std::vector<int64_t>>("dim", attrs);
There should be other attributes for reduce:
"reduce_type": sum, max, min, mean
"linearity": true/false
added a "linearity" attribute.
// step2.4: handle partial
// Step2.4.1 Output Partial
std::vector<int64_t> partial_on_dims =
We should use this logic to infer the partial dim:
if an axis is missing in the output tensor and that axis is sharded in the input tensor,
the output tensor becomes Partial on that mesh dim.
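The suggested inference can be sketched like this. It is an illustrative Python version, not the PR's C++ code; the helper name is hypothetical, and -1 in a dims_mapping means a replicated axis.

```python
def infer_reduction_output(input_dims_mapping, reduce_axes, keep_dim):
    """Sketch of step 2.4.1: build the output dims_mapping and the mesh
    dims the output is Partial on. An axis that is reduced away (missing
    in the output) and sharded in the input makes the output Partial on
    that mesh dim.
    """
    reduce_set = set(reduce_axes)
    out_mapping, partial_on = [], []
    for axis, mesh_dim in enumerate(input_dims_mapping):
        if axis in reduce_set:
            if mesh_dim != -1:
                partial_on.append(mesh_dim)  # sharded axis vanishes -> Partial
            if keep_dim:
                out_mapping.append(-1)       # kept axis has size 1, replicated
        else:
            out_mapping.append(mesh_dim)     # non-reduced axes keep sharding
    return out_mapping, partial_on
```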
Done
Force-pushed from 4bd765c to b04a262 (Compare)
// step2.4: handle partial
// Step2.4.1 Output Partial
// If the op is a linear op, i.e. `linearity` is true, the output's
A non-linear op requires its input to be non-partial, but it could generate partial output.
Done
LGTM
* add reduction spmd rule for auto parallel
* fix the logic of handling partial
* fix code style
* fix the partial handling
PR types
New features
PR changes
Others
Description
Pcard-70448