【Hackathon 5th No.50】 Add the SPMD sharding inference rule for slice to Paddle #57866
Changes from all commits
paddle/phi/infermeta/spmd_rules/slice.cc
@@ -0,0 +1,176 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/slice.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
                        const std::vector<int64_t>& axes,
                        const std::vector<int>& starts,
                        const std::vector<int>& ends,
                        const std::vector<int64_t>& infer_flags,
                        const std::vector<int64_t>& decrease_axis) {
Review comment: Please add comments to the code.

Reply: @pkuzyc The comments have mostly been added in my local copy, but the places touched by the two points below may still need changes. Please take a look first.
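  // Step0: verify input args based on slice logic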
  auto input_shape = phi::vectorize(input.dims());
  int input_ndim = input_shape.size();
  auto input_dist_attr_src = input.dist_attr();
  std::vector<int64_t> input_dims_mapping = input_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      input_ndim,
      input_dims_mapping.size(),
      phi::errors::InvalidArgument("The Tensor Input's rank [%d] and Input's "
                                   "dims_mapping size [%d] are not matched.",
                                   input_ndim,
                                   input_dims_mapping.size()));

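  // Step1: build the Einsum notation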
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string input_axes = alphabet.substr(0, input_ndim);
  std::string special_axes = alphabet.substr(input_ndim);

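  // mark each sliced axis in the input notation with a special letter
  // (borrowed from the split rule; see the review thread below)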
  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    input_axes[axis] = special_axes[i];
Review comment: Why is special_axes needed? The sliced dimensions are '1' in out_axes, so they are not propagated to the output anyway.

Reply: Hello, this follows the sharding inference rule of split.

Reply: I think slice can do without the special markers; just print the axes in the final log. The other rules also use 'k' as a special marker, which is why split marks with 'k' too. But slice has multiple sliced axes, unlike split with its single special dimension, so marking this way does not show which axes are the special ones anyway.
  }

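  // build the output notation: copy the input notation and mark every
  // sliced axis as '1' so that its sharding is not propagated to the output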
  std::string out_axes(input_axes);

  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    out_axes[axis] = '1';
  }

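  // Step2: sharding propagation: merge the input sharding, then infer the
  // output dims mapping from the merged axis-to-dim map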
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{input_axes, input_dims_mapping}});

  std::vector<int64_t> out_dims_mapping =
      GetDimsMappingForAxes(out_axes, axis_to_dim_map);

  TensorDistAttr out_dist_attr =
      CopyTensorDistAttrForOutput(input_dist_attr_src);
  out_dist_attr.set_dims_mapping(out_dims_mapping);

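  // Step3: the sliced axes of the input cannot stay sharded; reset them to
  // -1 (replicated) in the input's dst dims mapping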
  TensorDistAttr input_dist_attr_dst(input_dist_attr_src);
  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    input_dims_mapping[axis] = -1;
  }
  input_dist_attr_dst.set_dims_mapping(input_dims_mapping);

  VLOG(4) << "SliceInferSpmd:";
  VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
  VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
          << "src_dims_mapping: ["
          << str_join(input_dist_attr_src.dims_mapping()) << "] "
          << "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]";
  VLOG(4) << "Output"
          << " dims_mapping: [" << str_join(out_dims_mapping) << "]";
  VLOG(4) << std::endl;

  return {{input_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
                               const DistMetaTensor& output,
                               const std::vector<int64_t>& axes,
                               const std::vector<int>& starts,
                               const std::vector<int>& ends,
                               const std::vector<int64_t>& infer_flags,
                               const std::vector<int64_t>& decrease_axis) {
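  // Step0: verify input args based on slice logic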
  auto output_shape = phi::vectorize(output.dims());
  int out_ndim = output_shape.size();
  auto out_dist_attr = output.dist_attr();
  int out_dims_mapping_size = out_dist_attr.dims_mapping().size();
  auto input_shape = phi::vectorize(input.dims());
  int input_ndim = input_shape.size();
  auto input_dist_attr = input.dist_attr();
  std::vector<int64_t> input_dims_mapping = input_dist_attr.dims_mapping();

  PADDLE_ENFORCE_EQ(
      input_ndim,
      out_ndim,
      phi::errors::InvalidArgument("The Tensor Input's rank [%d] is not equal "
                                   "to the Tensor Output's rank [%d]",
                                   input_ndim,
                                   out_ndim));

  PADDLE_ENFORCE_EQ(
      out_ndim,
      out_dims_mapping_size,
      phi::errors::InvalidArgument("The Tensor Output's rank [%d] and Its "
                                   "dims_mapping size [%d] are not matched.",
                                   out_ndim,
                                   out_dims_mapping_size));

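  // Step1: build the Einsum notation, using the same special-letter marking
  // as the forward rule (see the review thread below)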
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string input_axes = alphabet.substr(0, input_ndim);
  std::string special_axes = alphabet.substr(input_ndim);

  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    input_axes[axis] = special_axes[i];
  }

  std::string out_axes(input_axes);

  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    out_axes[axis] = special_axes[i];
  }
Review comment: The notation in the reverse inference should stay consistent with the forward one; having them differ is a bit confusing.

Reply: This also follows split's inference rule, where […]

Reply: For slice let's use the abcd --> abc1 style throughout, mainly because there are multiple special dimensions; that way it is immediately visible which axes are sliced. I'll check later whether split needs the same change, since the two are indeed inconsistent.

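  // Step2: merge the output sharding and propagate it back through the
  // notation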
  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
  std::vector<int64_t> out_dims_mapping = output.dist_attr().dims_mapping();
  axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));

  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors(axes_sharding_info);

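  // infer dims mappings from the merged map; the sliced axes cannot be
  // sharded, so reset them to -1 on both the input and the output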
  input_dims_mapping =
      GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    input_dims_mapping[axis] = -1;
  }
  input_dist_attr.set_dims_mapping(input_dims_mapping);
  out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true);
  for (int i = 0; i < static_cast<int>(axes.size()); i++) {
    int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
    out_dims_mapping[axis] = -1;
  }
  out_dist_attr.set_dims_mapping(out_dims_mapping);

  VLOG(4) << "SliceInferSpmdReverse:";
  VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
  VLOG(4) << "Output"
          << " shape: [" << str_join(phi::vectorize(output.dims())) << "] "
          << "src_dims_mapping: ["
          << str_join(output.dist_attr().dims_mapping()) << "] "
          << "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
          << "]";
  VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
          << "dims_mapping: [" << str_join(input_dims_mapping) << "]\n\n";

  return {{input_dist_attr}, {out_dist_attr}};
}

}  // namespace distributed
}  // namespace phi
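To make the rule's behavior concrete, here is a usage sketch in the style of Paddle's SPMD rule unit tests. The mesh, shapes, and attribute values are invented for illustration and are not taken from this PR; namespace usings are assumed to resolve as shown.

// Usage sketch (hypothetical values). Build a 2x3 process mesh and a 3-D
// input of shape [16, 32, 48] with dims mapping [0, 1, -1]: axis 0 is
// sharded on mesh dim "x" and axis 1 on mesh dim "y".
using phi::distributed::DistMetaTensor;
using phi::distributed::ProcessMesh;
using phi::distributed::SliceInferSpmd;
using phi::distributed::SpmdInfo;
using phi::distributed::TensorDistAttr;

ProcessMesh process_mesh({2, 3}, {0, 1, 2, 3, 4, 5}, {"x", "y"});

TensorDistAttr input_dist_attr;
input_dist_attr.set_process_mesh(process_mesh);
input_dist_attr.set_dims_mapping({0, 1, -1});
DistMetaTensor input(phi::make_ddim({16, 32, 48}), input_dist_attr);

// Slice axis 1 to the range [8, 24). A sliced axis cannot stay sharded, so
// the einsum notation is "adc" --> "a1c", and the expected results are an
// input dst dims mapping of [0, -1, -1] and an output dims mapping of
// [0, -1, -1].
std::vector<int64_t> axes = {1};
std::vector<int> starts = {8};
std::vector<int> ends = {24};
std::vector<int64_t> infer_flags = {1};
std::vector<int64_t> decrease_axis = {};

SpmdInfo spmd_info =
    SliceInferSpmd(input, axes, starts, ends, infer_flags, decrease_axis);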
paddle/phi/infermeta/spmd_rules/slice.h
@@ -0,0 +1,44 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
                        const std::vector<int64_t>& axes,
                        const std::vector<int>& starts,
                        const std::vector<int>& ends,
                        const std::vector<int64_t>& infer_flags,
                        const std::vector<int64_t>& decrease_axis);

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
                               const DistMetaTensor& output,
                               const std::vector<int64_t>& axes,
                               const std::vector<int>& starts,
                               const std::vector<int>& ends,
                               const std::vector<int64_t>& infer_flags,
                               const std::vector<int64_t>& decrease_axis);

}  // namespace distributed
}  // namespace phi
Review comment: If some attrs have no effect on the sharding inference and exist only to align with the op definition in the phi yaml, explain that with a comment; it can be updated in the next PR.
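A sketch of how that suggestion might look in a follow-up PR (the comment wording is hypothetical; as the implementation above shows, only axes is actually read by the rule):

// Sketch: documenting the attrs that do not affect the inference.
SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
                        const std::vector<int64_t>& axes,
                        // NOTE: 'starts', 'ends', 'infer_flags' and
                        // 'decrease_axis' do not affect the sharding
                        // inference; they are kept only so the signature
                        // aligns with the slice op definition in the phi
                        // yaml.
                        const std::vector<int>& starts,
                        const std::vector<int>& ends,
                        const std::vector<int64_t>& infer_flags,
                        const std::vector<int64_t>& decrease_axis);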