【Hackathon 5th No.50】 Add the SPMD sharding inference rule for slice to Paddle #57866

Merged · 11 commits · Oct 27, 2023
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/reduction.h"
#include "paddle/phi/infermeta/spmd_rules/replicated.h"
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
#include "paddle/phi/infermeta/spmd_rules/slice.h"
#include "paddle/phi/infermeta/spmd_rules/softmax.h"
#include "paddle/phi/infermeta/spmd_rules/split.h"
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
@@ -517,6 +518,11 @@ PD_REGISTER_SPMD_RULE(
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd),
PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse));

// slice rule
PD_REGISTER_SPMD_RULE(slice,
PD_INFER_SPMD(phi::distributed::SliceInferSpmd),
PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse));

// transpose rule
PD_REGISTER_SPMD_RULE(
transpose,
176 changes: 176 additions & 0 deletions paddle/phi/infermeta/spmd_rules/slice.cc
@@ -0,0 +1,176 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/slice.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
Contributor: If some of these attrs have no effect on the sharding inference and are only kept to align with the definitions in the phi yaml, please note that in a comment; it can be updated in the next PR.

const std::vector<int64_t>& decrease_axis) {
Contributor: Please add comments in the code.

Contributor Author: @pkuzyc The comments have mostly been added in my local copy, but the places touching the two points below may need changes. Please take a look first.

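// Note: starts, ends, infer_flags and decrease_axis do not affect the
// sharding inference; they are only kept to align the signature with the
// slice op definition in the phi yaml.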
auto input_shape = phi::vectorize(input.dims());
int input_ndim = input_shape.size();
auto input_dist_attr_src = input.dist_attr();
std::vector<int64_t> input_dims_mapping = input_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
input_ndim,
input_dims_mapping.size(),
phi::errors::InvalidArgument("The Tensor Input's rank [%d] and Input's "
"dims_mapping size [%d] are not matched.",
input_ndim,
input_dims_mapping.size()));

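// Build an einsum-like notation: one letter per input dimension, with the
// letters after input_ndim reserved to mark the sliced axes.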
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
std::string input_axes = alphabet.substr(0, input_ndim);
std::string special_axes = alphabet.substr(input_ndim);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
Contributor: Why is special_axes needed? The sliced dimensions in out_axes are '1', so those dimensions are not propagated to the output.

Contributor Author: Hello! This follows the sharding inference rule of split. split involves only one axis and marks it with the reserved letter 'k'. By analogy, slice involves multiple axes, so it needs multiple reserved (unused) letters as special markers.

Contributor: I think slice can do without the special markers; just log the axes at the end. Other rules also use 'k' as a special marker, which is why split uses 'k' too. Since slice has multiple axes, unlike split's single special dimension, marking them this way does not show which dimensions are the special ones anyway.

}

std::string out_axes(input_axes);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_axes[axis] = '1';
}

std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{input_axes, input_dims_mapping}});

std::vector<int64_t> out_dims_mapping =
GetDimsMappingForAxes(out_axes, axis_to_dim_map);

TensorDistAttr out_dist_attr =
CopyTensorDistAttrForOutput(input_dist_attr_src);
out_dist_attr.set_dims_mapping(out_dims_mapping);

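// A sliced axis cannot stay sharded, so reset it to -1 (replicated) in the
// suggested input dist_attr.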
TensorDistAttr input_dist_attr_dst(input_dist_attr_src);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_dims_mapping[axis] = -1;
}
input_dist_attr_dst.set_dims_mapping(input_dims_mapping);

VLOG(4) << "SliceInferSpmd:";
VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
<< "src_dims_mapping: ["
<< str_join(input_dist_attr_src.dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(input_dims_mapping) << "]";
VLOG(4) << "Output"
<< " dims_mapping: [" << str_join(out_dims_mapping) << "]";
VLOG(4) << std::endl;

return {{input_dist_attr_dst}, {out_dist_attr}};
}
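
To make the forward rule concrete, here is a minimal standalone sketch (not part of this diff) that replays only the notation-building and mapping steps on a hypothetical 4-D input with dims_mapping [0, -1, 1, -1] sliced on axes = {1, 2}. The map-based merge below is a simplification of what ShardingMergeForTensors and GetDimsMappingForAxes do; all names local to the sketch are illustrative.

// ----- standalone illustration, not part of slice.cc -----
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<int64_t> dims_mapping = {0, -1, 1, -1};  // hypothetical input sharding
  std::vector<int64_t> axes = {1, 2};                  // sliced axes
  int ndim = static_cast<int>(dims_mapping.size());

  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string input_axes = alphabet.substr(0, ndim);  // "abcd"
  std::string special_axes = alphabet.substr(ndim);   // "efgh..."

  // Mark the sliced axes with reserved letters, as the PR's forward rule does.
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i] < 0 ? axes[i] + ndim : axes[i]);
    input_axes[axis] = special_axes[i];
  }
  std::string out_axes(input_axes);
  for (size_t i = 0; i < axes.size(); ++i) {
    int axis = static_cast<int>(axes[i] < 0 ? axes[i] + ndim : axes[i]);
    out_axes[axis] = '1';
  }

  // Simplified merge: each letter keeps the mesh dim of its input axis,
  // and '1' always resolves to -1 (replicated).
  std::unordered_map<char, int64_t> axis_to_dim;
  for (int i = 0; i < ndim; ++i) axis_to_dim[input_axes[i]] = dims_mapping[i];

  std::cout << input_axes << " --> " << out_axes << "\n";   // prints: aefd --> a11d
  for (char c : out_axes)
    std::cout << (c == '1' ? -1 : axis_to_dim[c]) << " ";   // prints: 0 -1 -1 -1
  std::cout << "\n";
  return 0;
}
// ----- end standalone illustration -----

Since the sliced axes are also reset to -1 on the input side, the rule would suggest [0, -1, -1, -1] for the input's dims_mapping as well.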

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
const DistMetaTensor& output,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
auto output_shape = phi::vectorize(output.dims());
int out_ndim = output_shape.size();
auto out_dist_attr = output.dist_attr();
int out_dims_mapping_size = out_dist_attr.dims_mapping().size();
auto input_shape = phi::vectorize(input.dims());
int input_ndim = input_shape.size();
auto input_dist_attr = input.dist_attr();
std::vector<int64_t> input_dims_mapping = input_dist_attr.dims_mapping();

PADDLE_ENFORCE_EQ(
input_ndim,
out_ndim,
phi::errors::InvalidArgument("The Tensor Input's rank [%d] is not equal "
"to the Tensor Output's rank [%d]",
input_ndim,
out_ndim));

PADDLE_ENFORCE_EQ(
out_ndim,
out_dims_mapping_size,
phi::errors::InvalidArgument("The Tensor Output's rank [%d] and Its "
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping_size));

std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
std::string input_axes = alphabet.substr(0, input_ndim);
std::string special_axes = alphabet.substr(input_ndim);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
}

std::string out_axes(input_axes);

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_axes[axis] = special_axes[i];
}
Contributor: Keep the notation of the reverse inference consistent with the forward one; having them differ is a bit confusing.

Contributor Author (@WintersMontagne10335, Oct 27, 2023): This also follows split. Looking at split's log, one of its notations reads:

  • forward: abck-->abc1
  • reverse: abck-->abck

where k is the special letter. Does this count as consistent? If not, going by this example, what should both be changed to?

Contributor: For slice, use the abcd --> abc1 form in both directions. Mainly because there are multiple special dimensions, this makes it directly visible which ones are sliced. I'll check later whether split needs to be changed; it is indeed not quite consistent.


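// Merge the output's dims_mapping through the shared notation, then project
// the merged result back onto both the input and the output.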
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
std::vector<int64_t> out_dims_mapping = output.dist_attr().dims_mapping();
axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));

std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

input_dims_mapping = GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_dims_mapping[axis] = -1;
}
input_dist_attr.set_dims_mapping(input_dims_mapping);
out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true);
for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_dims_mapping[axis] = -1;
}
out_dist_attr.set_dims_mapping(out_dims_mapping);

VLOG(4) << "SliceInferSpmdReverse:";
VLOG(4) << "Einsum Notation: " << input_axes << "-->" << out_axes;
VLOG(4) << "Output"
<< " shape: [" << str_join(phi::vectorize(output.dims())) << "] "
<< "src_dims_mapping: ["
<< str_join(output.dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(out_dist_attr.dims_mapping())
<< "]";
VLOG(4) << "Input shape: [" << str_join(input_shape) << "] "
<< "dims_mapping: [" << str_join(input_dims_mapping) << "]\n\n";

return {{input_dist_attr}, {out_dist_attr}};
}
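
For a concrete reverse trace under the same hypothetical setup (4-D tensors, axes = {1, 2}, so the notation is aefd --> aefd): if the output arrives with dims_mapping [-1, 0, -1, 1], the merge yields a->-1, e->0, f->-1, d->1, so the input is first assigned [-1, 0, -1, 1]; the two loops above then force the sliced axes (e and f) back to -1, leaving both the input and the output with [-1, -1, -1, 1].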

} // namespace distributed
} // namespace phi
44 changes: 44 additions & 0 deletions paddle/phi/infermeta/spmd_rules/slice.h
@@ -0,0 +1,44 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo SliceInferSpmd(const DistMetaTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis);

SpmdInfo SliceInferSpmdReverse(const DistMetaTensor& input,
const DistMetaTensor& output,
const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis);

} // namespace distributed
} // namespace phi
1 change: 1 addition & 0 deletions test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -18,6 +18,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_default_data_parallel_rule MODULES
test_default_data_parallel_rule)
py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule)
py_test_modules(test_slice_rule MODULES test_slice_rule)
py_test_modules(test_flatten_rule MODULES test_flatten_rule)
# End of unittests WITH single card WITHOUT timeout
