Skip to content

Commit

Permalink
replace any by variant in infermeta (#42181)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql authored Apr 25, 2022
1 parent a3a6f0c commit c2a05a9
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 96 deletions.
34 changes: 33 additions & 1 deletion paddle/phi/core/infermeta_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
void InferMetaContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr));
}

Expand Down Expand Up @@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
return result;
}

template <typename AttrType>
const AttrType& InferMetaContext::AttrAt(size_t idx) const {
try {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`.",
std::type_index(typeid(AttrType)).name()));
}
}

template const bool& InferMetaContext::AttrAt(size_t idx) const;
template const int& InferMetaContext::AttrAt(size_t idx) const;
template const int64_t& InferMetaContext::AttrAt(size_t idx) const;
template const float& InferMetaContext::AttrAt(size_t idx) const;
template const double& InferMetaContext::AttrAt(size_t idx) const;
template const std::string& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<bool>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<float>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<double>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& InferMetaContext::AttrAt(
size_t idx) const;
template const Scalar& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& InferMetaContext::AttrAt(size_t idx) const;
template const IntArray& InferMetaContext::AttrAt(size_t idx) const;
template const DataType& InferMetaContext::AttrAt(size_t idx) const;
template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
template const Place& InferMetaContext::AttrAt(size_t idx) const;

MetaFnFactory& MetaFnFactory::Instance() {
static MetaFnFactory g_meta_fn_map;
return g_meta_fn_map;
Expand Down
60 changes: 33 additions & 27 deletions paddle/phi/core/infermeta_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h"
Expand All @@ -41,7 +42,7 @@ class InferMetaContext {

void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr);
void EmplaceBackAttr(Attribute attr);

void EmplaceBackInputs(
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
Expand All @@ -61,17 +62,7 @@ class InferMetaContext {
size_t end);

template <typename AttrType>
AttrType AttrAt(size_t idx) {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`.",
std::type_index(typeid(AttrType)).name(),
std::type_index(attrs_.at(idx).type()).name()));
}
}
const AttrType& AttrAt(size_t idx) const;

const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
Expand All @@ -81,7 +72,7 @@ class InferMetaContext {
protected:
MetaConfig config_;

paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;

paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
Expand Down Expand Up @@ -111,6 +102,21 @@ class InferMetaContext {
} \
}

#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type&, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}

template <typename T>
struct InferMetaTypeTag {};

Expand Down Expand Up @@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};

// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<std::string>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&);

// TODO(chenweihang): support vector<MetaTensor> input later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<float>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<double>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);

template <typename... Tail>
struct InferMetaFnCallHelper<MetaTensor*, Tail...> {
Expand Down
29 changes: 0 additions & 29 deletions paddle/phi/core/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,8 @@
#include <string>
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"

#include "paddle/utils/variant.h"

namespace phi {

class Place;

// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
int,
int64_t,
float,
double,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>,
std::vector<std::string>,
Scalar,
std::vector<Scalar>,
IntArray,
DataType,
DataLayout,
Place>;

class Kernel;
class KernelKey;
class KernelArgsDef;
Expand Down
8 changes: 0 additions & 8 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
out->set_dtype(x.dtype());
}

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out) {
UnchangedInferMeta(x, out);
}

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
Expand Down Expand Up @@ -3008,6 +3001,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {

} // namespace phi

PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
5 changes: 0 additions & 5 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumsumInferMeta(const MetaTensor& x,
Expand Down
26 changes: 0 additions & 26 deletions paddle/phi/tests/core/test_meta_fn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) {
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, CopyInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({3, 4});

phi::MetaTensor meta_x(&dense_x);
phi::DenseTensor dense_out1;
phi::MetaTensor meta_out(&dense_out1);
phi::UnchangedInferMeta(meta_x, &meta_out);

auto shared_meat_x = phi::MetaTensor(&dense_x);
phi::DenseTensor dense_out2;
auto shared_meta_out = phi::MetaTensor(&dense_out2);

phi::InferMetaContext ctx;
ctx.EmplaceBackInput(shared_meat_x);
ctx.EmplaceBackAttr(Backend::CPU);
ctx.EmplaceBackAttr(false);
ctx.EmplaceBackOutput(shared_meta_out);
ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false});
phi::MetaFnFactory::Instance().Get("copy_to")(&ctx);

EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, SplitInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({4, 10});
Expand Down

0 comments on commit c2a05a9

Please sign in to comment.