Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【auto parallel】custom op spmd rule register #60509

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,13 @@ copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/)

copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/
)

copy(
inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h
Expand Down
16 changes: 16 additions & 0 deletions paddle/phi/api/ext/op_meta_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/common/exception.h"
#include "paddle/phi/api/include/dll_decl.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/utils/any.h"
#include "paddle/utils/none.h"
#include "paddle/utils/optional.h"
Expand Down Expand Up @@ -996,6 +997,15 @@ struct TrtGetOutputDimsFuncImpl<Return (*)(Args...), impl_fn> {

////////////////////// Op Meta Info //////////////////////

using CustomSpmdInferTensorArg =
paddle::variant<phi::distributed::DistMetaTensor,
std::vector<phi::distributed::DistMetaTensor>>;
using CustomSpmdInferAttrArg = paddle::any;

using InferSpmdFunc = phi::distributed::SpmdInfo (*)(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs);

class PADDLE_API OpMetaInfo {
public:
explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {}
Expand Down Expand Up @@ -1023,6 +1033,9 @@ class PADDLE_API OpMetaInfo {
// format: PD_INFER_DTYPE(...)
OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);

// format: PD_INFER_SPMD_RULE(...)
OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func);

#ifdef PADDLE_WITH_TENSORRT
// format: PD_TRT_INFER_SHAPE(...)
OpMetaInfo& SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func);
Expand All @@ -1045,6 +1058,7 @@ class PADDLE_API OpMetaInfo {
KernelFunc kernel_fn_{nullptr};
InferShapeFunc infer_shape_fn_{nullptr};
InferDtypeFunc infer_dtype_fn_{nullptr};
InferSpmdFunc infer_spmd_fn_{nullptr};
#ifdef PADDLE_WITH_TENSORRT
TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr};
std::vector<std::string> trt_supports_format_config_;
Expand All @@ -1068,6 +1082,7 @@ class OpMetaInfoHelper {
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info);
static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info);
static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info);
static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info);

#ifdef PADDLE_WITH_TENSORRT
static const TrtGetOutputDimsFunc& GetTrtInferShapeFn(
Expand Down Expand Up @@ -1108,6 +1123,7 @@ class PADDLE_API OpMetaInfoBuilder {
OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func);

#ifdef PADDLE_WITH_TENSORRT
OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func);
Expand Down
140 changes: 140 additions & 0 deletions paddle/phi/api/ext/spmd_infer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace paddle {

using CustomSpmdInferTensorArg =
paddle::variant<phi::distributed::DistMetaTensor,
std::vector<phi::distributed::DistMetaTensor>>;

using CustomSpmdInferAttrArg = paddle::any;
template <typename T>
struct SpmdInferHelperTypeEnd {};

#define PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(attr_type) \
template <typename... Tail> \
struct SpmdInferHelper<attr_type, Tail...> { \
template <int in_idx, int attr_idx, typename... PreviousArgs> \
static phi::distributed::SpmdInfo InferSpmd( \
const std::vector<CustomSpmdInferTensorArg>& inputs, \
const std::vector<CustomSpmdInferAttrArg>& attrs, \
const PreviousArgs&... pargs) { \
try { \
attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx, \
attr_idx + 1>( \
inputs, attrs, pargs..., arg); \
} catch (paddle::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator SpmdInferFunc " \
"function. " \
"Expected " #attr_type \
" value. SpmdInferFunc's attribute list must be exactly " \
"same " \
"as " \
"Forward " \
"KernelFn's attribute list except std::vector<int64_t> " \
"attribute."); \
} \
} \
}

template <typename F, F f>
struct SpmdInferImpl;

template <typename... Args, phi::distributed::SpmdInfo (*impl_fn)(Args...)>
struct SpmdInferImpl<phi::distributed::SpmdInfo (*)(Args...), impl_fn> {
static phi::distributed::SpmdInfo InferSpmd(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs) {
return SpmdInferHelper<Args..., SpmdInferHelperTypeEnd<int>>::
template InferSpmd<0, 0>(inputs, attrs);
}

private:
template <typename... RemainingArgs>
struct SpmdInferHelper;

// Handle args for general tensor input case
template <typename... Tail>
struct SpmdInferHelper<const phi::distributed::DistMetaTensor&, Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static phi::distributed::SpmdInfo InferSpmd(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
auto& arg =
PADDLE_GET_CONST(phi::distributed::DistMetaTensor, inputs[in_idx]);
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
inputs, attrs, pargs..., arg);
}
};

// Handle args for vector of Tensor input case
template <typename... Tail>
struct SpmdInferHelper<const std::vector<phi::distributed::DistMetaTensor>&,
Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static phi::distributed::SpmdInfo InferSpmd(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
auto& arg = PADDLE_GET_CONST(
std::vector<phi::distributed::DistMetaTensor>, inputs[in_idx]);
return SpmdInferHelper<Tail...>::template InferSpmd<in_idx + 1, attr_idx>(
inputs, attrs, pargs..., arg);
}
};

PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(bool);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(float);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int64_t);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::string&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int>&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<float>&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<std::string>&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int64_t>&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const bool&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const float&);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int64_t&);

PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::string);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<int>);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<float>);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector<std::string>);
PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector<int64_t>);

// end: base template
template <typename T>
struct SpmdInferHelper<SpmdInferHelperTypeEnd<T>> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static phi::distributed::SpmdInfo InferSpmd(
const std::vector<CustomSpmdInferTensorArg>& inputs,
const std::vector<CustomSpmdInferAttrArg>& attrs,
PreviousArgs&... pargs) {
return impl_fn(pargs...);
}
};
};

#define PD_INFER_SPMD_RULE(...) \
::paddle::SpmdInferImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::InferSpmd

} // namespace paddle
15 changes: 15 additions & 0 deletions paddle/phi/api/lib/op_meta_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) {
return *this;
}

OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) {
infer_spmd_fn_ = std::forward<InferSpmdFunc>(func);
return *this;
}

#ifdef PADDLE_WITH_TENSORRT
OpMetaInfo& OpMetaInfo::SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func) {
trt_infer_shape_fn_ = std::forward<TrtGetOutputDimsFunc>(func);
Expand Down Expand Up @@ -407,6 +412,11 @@ const InferDtypeFunc& OpMetaInfoHelper::GetInferDtypeFn(
return info.infer_dtype_fn_;
}

const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn(
const paddle::OpMetaInfo& info) {
return info.infer_spmd_fn_;
}

#ifdef PADDLE_WITH_TENSORRT
const TrtGetOutputDimsFunc& OpMetaInfoHelper::GetTrtInferShapeFn(
const paddle::OpMetaInfo& info) {
Expand Down Expand Up @@ -559,6 +569,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) {
info_ptr_->SetInferSpmdFn(std::forward<InferSpmdFunc>(func));
return *this;
}

#ifdef PADDLE_WITH_TENSORRT
OpMetaInfoBuilder& OpMetaInfoBuilder::SetTrtInferShapeFn(
TrtGetOutputDimsFunc func) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace phi {
namespace distributed {
class TensorDistAttr;
class DistMetaTensor;

using ArgDistAttr =
paddle::variant<TensorDistAttr, std::vector<TensorDistAttr>>;
Expand Down
Loading