Skip to content

Commit

Permalink
Add paddle::variant and replace paddle::any (#42139)
Browse files Browse the repository at this point in the history
* add variant and replace any

* split attribute
  • Loading branch information
chenwhql authored Apr 24, 2022
1 parent b1c6378 commit 79f717d
Show file tree
Hide file tree
Showing 14 changed files with 2,993 additions and 28 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License. */
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h"

namespace paddle {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/core/kernel_context.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

#ifdef PADDLE_WITH_MKLDNN
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"

Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

namespace paddle {
Expand All @@ -55,6 +54,10 @@ class Variable;
} // namespace framework
} // namespace paddle

namespace phi {
class KernelContext;
}

DECLARE_int32(inner_op_parallelism);

namespace paddle {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"

DECLARE_bool(use_mkldnn);
Expand Down
50 changes: 50 additions & 0 deletions paddle/phi/core/attribute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/variant.h"

namespace phi {

class Place;

// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
int,
int64_t,
float,
double,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>,
std::vector<std::string>,
Scalar,
std::vector<Scalar>,
IntArray,
DataType,
DataLayout,
Place>;

} // namespace phi
32 changes: 31 additions & 1 deletion paddle/phi/core/kernel_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void KernelContext::EmplaceBackOutputsWithoutSetRange(
std::make_move_iterator(outputs.end()));
}

void KernelContext::EmplaceBackAttr(paddle::any attr) {
void KernelContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr));
}

Expand Down Expand Up @@ -113,4 +113,34 @@ const std::pair<int, int>& KernelContext::OutputRangeAt(size_t idx) const {
return output_range_.at(idx);
}

template <typename AttrType>
const AttrType& KernelContext::AttrAt(size_t idx) const {
try {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& ex) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in Op Kernel Context."));
}
}

template const bool& KernelContext::AttrAt(size_t idx) const;
template const int& KernelContext::AttrAt(size_t idx) const;
template const int64_t& KernelContext::AttrAt(size_t idx) const;
template const float& KernelContext::AttrAt(size_t idx) const;
template const double& KernelContext::AttrAt(size_t idx) const;
template const std::string& KernelContext::AttrAt(size_t idx) const;
template const std::vector<bool>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<int>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<float>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<double>& KernelContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& KernelContext::AttrAt(
size_t idx) const;
template const Scalar& KernelContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& KernelContext::AttrAt(size_t idx) const;
template const IntArray& KernelContext::AttrAt(size_t idx) const;
template const DataType& KernelContext::AttrAt(size_t idx) const;
template const DataLayout& KernelContext::AttrAt(size_t idx) const;
template const Place& KernelContext::AttrAt(size_t idx) const;

} // namespace phi
15 changes: 4 additions & 11 deletions paddle/phi/core/kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#include <iterator>
#include <utility>

#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h"

Expand Down Expand Up @@ -64,7 +64,7 @@ class KernelContext {
void EmplaceBackOutputsWithoutSetRange(
paddle::SmallVector<TensorBase*> outputs);

void EmplaceBackAttr(paddle::any attr);
void EmplaceBackAttr(Attribute attr);

const std::pair<int, int>& InputRangeAt(size_t idx) const;

Expand Down Expand Up @@ -128,14 +128,7 @@ class KernelContext {
}

template <typename AttrType>
AttrType AttrAt(size_t idx) const {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast&) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in Op Kernel Context."));
}
}
const AttrType& AttrAt(size_t idx) const;

size_t InputsSize() const { return inputs_.size(); }
size_t OutputsSize() const { return outputs_.size(); }
Expand All @@ -146,7 +139,7 @@ class KernelContext {

paddle::SmallVector<const TensorBase*> inputs_;
paddle::SmallVector<TensorBase*> outputs_;
paddle::SmallVector<paddle::any> attrs_;
paddle::SmallVector<Attribute> attrs_;

paddle::SmallVector<std::pair<int, int>> input_range_;
paddle::SmallVector<std::pair<int, int>> output_range_;
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/core/kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(const StringTensor&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
Expand Down Expand Up @@ -153,6 +158,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(StringTensor*))) {
args_def->AppendOutput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else {
// Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe
Expand Down
39 changes: 29 additions & 10 deletions paddle/phi/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@ namespace phi {
} \
}

#define PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
template <typename... Tail> \
struct KernelCallHelper<const attr_type&, Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"Kernel's Attributes should appear before Outputs."); \
const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>( \
ctx, pargs..., arg); \
} \
}

#define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<tensor_type*, Tail...> { \
Expand Down Expand Up @@ -270,19 +288,20 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const IntArray&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<Scalar>&);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<bool>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int64_t>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<float>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<double>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);
PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<Scalar>);

/* Output Helpers */

Expand Down
31 changes: 31 additions & 0 deletions paddle/phi/core/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,40 @@
#pragma once

#include <functional>
#include <string>
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"

#include "paddle/utils/variant.h"

namespace phi {

class Place;

// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
int,
int64_t,
float,
double,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>,
std::vector<std::string>,
Scalar,
std::vector<Scalar>,
IntArray,
DataType,
DataLayout,
Place>;

class Kernel;
class KernelKey;
class KernelArgsDef;
Expand Down
5 changes: 0 additions & 5 deletions paddle/phi/tests/core/test_custom_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ void FakeDot(const Context& dev_ctx,
float fake_attr_float,
double fake_attr_double,
int64_t fake_attr_int64,
phi::dtype::float16 fake_attr_f16,
phi::DataType fake_attr_dtype,
const phi::Scalar& fake_attr_scalar,
const phi::IntArray& fake_attr_int_array,
Expand All @@ -64,7 +63,6 @@ void FakeDot(const Context& dev_ctx,
std::cout << "fake_attr_float: " << fake_attr_float << std::endl;
std::cout << "fake_attr_double: " << fake_attr_double << std::endl;
std::cout << "fake_attr_int64: " << fake_attr_int64 << std::endl;
std::cout << "fake_attr_f16: " << fake_attr_f16 << std::endl;
std::cout << "fake_attr_dtype: " << fake_attr_dtype << std::endl;
std::cout << "fake_attr_int64_vec: " << fake_attr_int64_vec.size()
<< std::endl;
Expand All @@ -78,7 +76,6 @@ void FakeDot(const Context& dev_ctx,
assert(fake_attr_float == 2);
assert(fake_attr_double == 3);
assert(fake_attr_int64 == 4);
assert(fake_attr_f16 == phi::dtype::float16(5));
assert(fake_attr_dtype == phi::DataType::UINT32);
assert(fake_attr_int64_vec.size() == 0);
assert(fake_attr_int_vec.size() == 0);
Expand Down Expand Up @@ -248,7 +245,6 @@ TEST(CustomKernel, custom_kernel_dot) {
float fake_attr_float = 2.0;
double fake_attr_double = 3.0;
int64_t fake_attr_int64 = 4;
phi::dtype::float16 fake_attr_f16 = phi::dtype::float16(5);
phi::DataType fake_attr_dtype = phi::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, phi::TransToPhiPlace(backend));
Expand All @@ -262,7 +258,6 @@ TEST(CustomKernel, custom_kernel_dot) {
kernel_context.EmplaceBackAttr(fake_attr_float);
kernel_context.EmplaceBackAttr(fake_attr_double);
kernel_context.EmplaceBackAttr(fake_attr_int64);
kernel_context.EmplaceBackAttr(fake_attr_f16);
kernel_context.EmplaceBackAttr(fake_attr_dtype);
kernel_context.EmplaceBackAttr(fake_attr_scalar);
kernel_context.EmplaceBackAttr(fake_attr_int_array);
Expand Down
Loading

0 comments on commit 79f717d

Please sign in to comment.