diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h index e2d851020c1977..adb78639d65c00 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h @@ -21,12 +21,12 @@ namespace paddle { namespace dialect { -class AllocatedDenseTensorType : public pir::Type { +class AllocatedDenseTensorType + : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedDenseTensorType, - AllocatedDenseTensorTypeStorage); + using Base::Base; static AllocatedDenseTensorType get(pir::IrContext *ctx, const phi::Place &place, @@ -62,12 +62,12 @@ class AllocatedDenseTensorType : public pir::Type { const size_t &offset() const; }; -class AllocatedSelectedRowsType : public pir::Type { +class AllocatedSelectedRowsType + : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedSelectedRowsType, - AllocatedSelectedRowsTypeStorage); + using Base::Base; static AllocatedSelectedRowsType get(pir::IrContext *ctx, const phi::Place &place, diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index 2c01a006a0cdf8..958d2df369ed9b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -20,6 +20,7 @@ namespace paddle { namespace dialect { class InferMetaInterface : public pir::OpInterfaceBase { public: + /// Defined these methods with the interface. struct Concept { explicit Concept(void (*infer_meta)(phi::InferMetaContext *)) : infer_meta_(infer_meta) {} @@ -28,13 +29,14 @@ class InferMetaInterface : public pir::OpInterfaceBase { template struct Model : public Concept { - static void InferMeta(phi::InferMetaContext *infer_meta) { + static inline void InferMeta(phi::InferMetaContext *infer_meta) { return ConcreteOp::InferMeta(infer_meta); } Model() : Concept(InferMeta) {} }; + /// Constructor InferMetaInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index a09a84c31d84ad..3ee0d642e2e478 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -16,16 +16,19 @@ #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/type.h" namespace paddle { namespace dialect { + using DenseTensorType = pir::DenseTensorType; -class SelectedRowsType : public pir::Type { +class SelectedRowsType : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(SelectedRowsType, SelectedRowsTypeStorage); + using Base::Base; const pir::Type &dtype() const; diff --git a/paddle/pir/core/builtin_type.h b/paddle/pir/core/builtin_type.h index 29c99f382ff52f..3f0e7a14717039 100644 --- a/paddle/pir/core/builtin_type.h +++ b/paddle/pir/core/builtin_type.h @@ -15,14 +15,13 @@ #pragma once +#include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/core/builtin_type_storage.h" #include "paddle/pir/core/type.h" namespace pir { /// -/// \brief Define built-in parameterless types. Please add the necessary -/// interface functions for built-in types through the macro -/// DECLARE_TYPE_UTILITY_FUNCTOR. +/// \brief Define built-in parameterless types. /// /// NOTE(zhangbo9674): If you need to directly /// cache the object of this built-in type in IrContext, please overload the get @@ -39,11 +38,10 @@ namespace pir { // NOTE(dev): Currently Int8 are not considered as a cached member // in IrContextImpl because it is not widely used. -class IR_API VectorType : public Type { +class IR_API VectorType + : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(VectorType, VectorTypeStorage); + using Base::Base; std::vector data() const; @@ -54,11 +52,12 @@ class IR_API VectorType : public Type { Type operator[](size_t index) const { return data()[index]; } }; -class DenseTensorType : public pir::Type { +class DenseTensorType : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage); + using Base::Base; const pir::Type &dtype() const; @@ -71,14 +70,13 @@ class DenseTensorType : public pir::Type { const size_t &offset() const; }; -#define DECLARE_BUILTIN_TYPE(__name) \ - class IR_API __name : public Type { \ - public: \ - using Type::Type; \ - \ - DECLARE_TYPE_UTILITY_FUNCTOR(__name, TypeStorage); \ - \ - static __name get(IrContext *context); \ +#define DECLARE_BUILTIN_TYPE(__name) \ + class IR_API __name : public ::pir::Type::TypeBase<__name, \ + ::pir::Type, \ + ::pir::TypeStorage> { \ + public: \ + using Base::Base; \ + static __name get(IrContext *context); \ }; #define FOREACH_BUILTIN_TYPE(__macro) \ diff --git a/paddle/pir/core/builtin_type_interfaces.cc b/paddle/pir/core/builtin_type_interfaces.cc new file mode 100644 index 00000000000000..9084bffc7a1977 --- /dev/null +++ b/paddle/pir/core/builtin_type_interfaces.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/type_id.h" + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface) diff --git a/paddle/pir/core/builtin_type_interfaces.h b/paddle/pir/core/builtin_type_interfaces.h new file mode 100644 index 00000000000000..79db4d12a3d150 --- /dev/null +++ b/paddle/pir/core/builtin_type_interfaces.h @@ -0,0 +1,153 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/tensor_base.h" +#include "paddle/pir/core/cast_utils.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/type.h" + +namespace details { + +template +constexpr auto begin_impl(RangeT &&range) + -> decltype(std::begin(std::forward(range))) { + return std::begin(std::forward(range)); +} + +template +constexpr auto end_impl(RangeT &&range) + -> decltype(std::end(std::forward(range))) { + return std::end(std::forward(range)); +} + +/// Returns the begin iterator to \p range using `std::begin` and +/// function found through Argument-Dependent Lookup (ADL). +template +constexpr auto adl_begin(RangeT &&range) + -> decltype(begin_impl(std::forward(range))) { + return begin_impl(std::forward(range)); +} + +/// Returns the end iterator to \p range using `std::end` and +/// functions found through Argument-Dependent Lookup (ADL). +template +constexpr auto adl_end(RangeT &&range) + -> decltype(end_impl(std::forward(range))) { + return end_impl(std::forward(range)); +} + +/// Provide wrappers to std::any_of which take ranges instead of having to pass +/// begin/end explicitly. +template +bool any_of(R &&Range, UnaryPredicate P) { + return std::any_of(adl_begin(Range), adl_end(Range), P); +} + +/// Wrapper function around std::count_if to count the number of times an +/// element satisfying a given predicate occurs in a range. +template +auto count_if(R &&Range, UnaryPredicate P) { + return std::count_if(adl_begin(Range), adl_end(Range), P); +} + +} // namespace details +namespace pir { +class ShapedTypeInterface : public pir::TypeInterfaceBase { + public: + using DDim = phi::DDim; + using DataType = pir::Type; + struct Concept { + /// Defined these methods with the interface. + explicit Concept(DataType (*get_element_type)(pir::Type), + DDim (*get_shape)(pir::Type)) + : get_element_type_(get_element_type), get_shape_(get_shape) {} + + DataType (*get_element_type_)(pir::Type); + DDim (*get_shape_)(pir::Type); + }; + + template + struct Model : public Concept { + static inline DataType getElementType(pir::Type type) { + return pir::cast(type).dtype(); + } + + static inline DDim getShape(pir::Type type) { + return pir::cast(type).dims(); + } + + Model() : Concept(getElementType, getShape) {} + }; + + /// Constructor + ShapedTypeInterface(pir::Type type, Concept *impl) + : pir::TypeInterfaceBase(type), impl_(impl) {} + + /// Get the element type. + DataType getElementType() const { return impl_->get_element_type_(*this); } + + /// Get the shape of this type. + DDim getShape() const { return impl_->get_shape_(*this); } + + static constexpr int64_t kDynamic = std::numeric_limits::min(); + + /// Check whether this type is ranked, currently return true. + bool hasRank() const { return true; } + + /// If this is a ranked type, return the rank. Otherwise, abort. + int64_t getRank() const { + IR_ENFORCE((*this).hasRank(), "Cannot query rank of unranked shaped type."); + return (*this).getShape().size(); + } + + /// Check whether the given dimension size is a dynamic dimension. + static constexpr bool isDynamic(int64_t dValue) { return dValue == kDynamic; } + + /// Check whether the given shape has any size indicating a dynamic dimension. + static bool isDynamicShape(DDim dSizes) { + return ::details::any_of(vectorize(dSizes), + [](int64_t dSize) { return isDynamic(dSize); }); + } + + /// Check whether the given dimension has a dynamic size. + /// Aborts for unranked types. + bool isDynamicDim(unsigned idx) const { + IR_ENFORCE(idx < getRank(), "Invalid index for shaped type."); + return pir::ShapedTypeInterface::isDynamic((*this).getShape()[idx]); + } + + /// Get the number of dimensions with dynamic size for a ranked type. + /// Aborts for unranked types. + int64_t getNumDynamicDims() const { + return ::details::count_if(vectorize((*this).getShape()), + pir::ShapedTypeInterface::isDynamic); + } + + /// Get the size of the specified dimension for a ranked type. + /// Aborts for unranked types. + int64_t getDimSize(unsigned idx) const { + IR_ENFORCE(idx < getRank(), "Invalid index for shaped type."); + return (*this).getShape()[idx]; + } + + private: + Concept *impl_; +}; + +} // namespace pir + +IR_DECLARE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface) diff --git a/paddle/pir/core/cast_utils.h b/paddle/pir/core/cast_utils.h index 3cc6e9abd09c43..db9f864aaabc3b 100644 --- a/paddle/pir/core/cast_utils.h +++ b/paddle/pir/core/cast_utils.h @@ -14,6 +14,7 @@ #pragma once +#include #include namespace pir { @@ -114,7 +115,7 @@ struct ReturnTypeDuduction { /// /// cast From to To /// -template +template struct cast_impl { // This _is_ a simple type, just cast it. static typename ReturnTypeDuduction::type call(const From &Val) { @@ -125,7 +126,15 @@ struct cast_impl { }; template -inline typename ReturnTypeDuduction::type cast(From &Val) { // NOLINT +inline decltype(auto) cast(const From &Val) { + if (!isa(Val)) { + throw("cast() argument of incompatible type!"); + } + return cast_impl::call(Val); +} + +template +inline decltype(auto) cast(From &Val) { // NOLINT if (!isa(Val)) { throw("cast() argument of incompatible type!"); } @@ -133,24 +142,31 @@ inline typename ReturnTypeDuduction::type cast(From &Val) { // NOLINT } template -inline typename ReturnTypeDuduction::type cast(From *Val) { +inline decltype(auto) cast(From *Val) { if (!isa(Val)) { throw("cast() argument of incompatible type!"); } return cast_impl::call(Val); } +template +inline decltype(auto) cast(std::unique_ptr &&Val) { + if (!isa(Val)) { + throw("cast() argument of incompatible type!"); + } + return cast_impl>::call(std::move(Val)); +} + /// /// \brief dyn_cast From to To. /// template -inline std::decay_t::type> dyn_cast( - From &Val) { // NOLINT +inline decltype(auto) dyn_cast(From &Val) { // NOLINT return isa(Val) ? cast(Val) : nullptr; } template -inline typename ReturnTypeDuduction::type dyn_cast(From *Val) { +inline decltype(auto) dyn_cast(From *Val) { return isa(Val) ? cast(Val) : nullptr; } diff --git a/paddle/pir/core/op_base.cc b/paddle/pir/core/interface_support.cc similarity index 75% rename from paddle/pir/core/op_base.cc rename to paddle/pir/core/interface_support.cc index a7ebd9febe973c..19cba9de0bd852 100644 --- a/paddle/pir/core/op_base.cc +++ b/paddle/pir/core/interface_support.cc @@ -12,19 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/interface_support.h" + namespace pir { -InterfaceValue::~InterfaceValue() { +details::InterfaceValue::~InterfaceValue() { if (model_) free(model_); } -InterfaceValue::InterfaceValue(InterfaceValue&& val) noexcept { +details::InterfaceValue::InterfaceValue(InterfaceValue&& val) noexcept { type_id_ = val.type_id_; model_ = val.model_; val.model_ = nullptr; } -InterfaceValue& InterfaceValue::operator=(InterfaceValue&& val) noexcept { +details::InterfaceValue& details::InterfaceValue::operator=( + InterfaceValue&& val) noexcept { swap(std::move(val)); return *this; } diff --git a/paddle/pir/core/interface_support.h b/paddle/pir/core/interface_support.h new file mode 100644 index 00000000000000..df8f776d7b87bf --- /dev/null +++ b/paddle/pir/core/interface_support.h @@ -0,0 +1,122 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/interface_value.h" + +namespace pir { +namespace details { +template +class ConstructInterfacesOrTraits { + public: + /// Construct method for interfaces. + static details::InterfaceValue *interface( + details::InterfaceValue *p_interface) { + (void)std::initializer_list{ + 0, (PlacementConstrctInterface(p_interface), 0)...}; + return p_interface; + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + (void)std::initializer_list{ + 0, (PlacementConstrctTrait(p_trait), 0)...}; + return p_trait; + } + + private: + /// Placement new interface. + template + static void PlacementConstrctInterface( + details::InterfaceValue *&p_interface) { // NOLINT + p_interface->swap(details::InterfaceValue::get()); + VLOG(6) << "New a interface: id[" + << (p_interface->type_id()).AsOpaquePointer() << "]."; + ++p_interface; + } + + /// Placement new trait. + template + static void PlacementConstrctTrait(pir::TypeId *&p_trait) { // NOLINT + *p_trait = TypeId::get(); + VLOG(6) << "New a trait: id[" << p_trait->AsOpaquePointer() << "]."; + ++p_trait; + } +}; + +/// Specialized for tuple type. +template +class ConstructInterfacesOrTraits> { + public: + /// Construct method for interfaces. + static details::InterfaceValue *interface( + details::InterfaceValue *p_interface) { + return ConstructInterfacesOrTraits::interface( + p_interface); + } + + /// Construct method for traits. + static TypeId *trait(TypeId *p_trait) { + return ConstructInterfacesOrTraits::trait(p_trait); + } +}; + +template +void *LookUp(const TypeId &interface_id, + const uint32_t num_interfaces, + const uint32_t num_traits, + const T *t) { + if (num_interfaces > 0) { + const details::InterfaceValue *p_first_interface = + reinterpret_cast( + reinterpret_cast(t) - sizeof(TypeId) * num_traits - + sizeof(details::InterfaceValue) * num_interfaces); + size_t left = 0, right = num_interfaces; + while (left < right) { + size_t mid = (left + right) / 2; + if ((p_first_interface + mid)->type_id() == interface_id) { + return (p_first_interface + mid)->model(); + } else if ((p_first_interface + mid)->type_id() < interface_id) { + left = mid + 1; + } else { + right = mid; + } + } + } + return nullptr; +} + +template +std::vector GetInterfaceMap() { + constexpr size_t interfaces_num = std::tuple_size::value; + std::vector interfaces_map(interfaces_num); + ConstructInterfacesOrTraits::interface( + interfaces_map.data()); + return interfaces_map; +} + +template +std::vector GetTraitSet() { + constexpr size_t traits_num = std::tuple_size::value; + std::vector trait_set(traits_num); + auto p_first_trait = trait_set.data(); + ConstructInterfacesOrTraits::trait(p_first_trait); + return trait_set; +} + +} // namespace details + +} // namespace pir diff --git a/paddle/pir/core/interface_value.h b/paddle/pir/core/interface_value.h new file mode 100644 index 00000000000000..fe7bc6d9ca2a82 --- /dev/null +++ b/paddle/pir/core/interface_value.h @@ -0,0 +1,67 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/pir/core/type_id.h" +#include "paddle/pir/core/utils.h" + +namespace pir { + +namespace details { +class IR_API InterfaceValue { + public: + template + static InterfaceValue get() { + InterfaceValue val; + val.type_id_ = TypeId::get(); + val.model_ = malloc(sizeof(typename T::template Model)); + if (val.model_ == nullptr) { + throw("Alloc memory for interface failed."); + } + static_assert(std::is_trivially_destructible< + typename T::template Model>::value, + "interface models must be trivially destructible"); + new (val.model_) typename T::template Model(); + return val; + } + TypeId type_id() const { return type_id_; } + void *model() const { return model_; } + + InterfaceValue() = default; + explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} + InterfaceValue(const InterfaceValue &) = delete; + InterfaceValue(InterfaceValue &&) noexcept; + InterfaceValue &operator=(const InterfaceValue &) = delete; + InterfaceValue &operator=(InterfaceValue &&) noexcept; + ~InterfaceValue(); + void swap(InterfaceValue &&val) { + using std::swap; + swap(type_id_, val.type_id_); + swap(model_, val.model_); + } + + /// + /// \brief Comparison operations. + /// + inline bool operator<(const InterfaceValue &other) const { + return type_id_ < other.type_id_; + } + + private: + TypeId type_id_; + void *model_{nullptr}; +}; + +} // namespace details +} // namespace pir diff --git a/paddle/pir/core/ir_context.cc b/paddle/pir/core/ir_context.cc index b7aca14e8f60b3..bfc05fabcf35b1 100644 --- a/paddle/pir/core/ir_context.cc +++ b/paddle/pir/core/ir_context.cc @@ -284,14 +284,15 @@ void IrContext::RegisterAbstractType(pir::TypeId type_id, } } -void IrContext::RegisterOpInfo(Dialect *dialect, - TypeId op_id, - const char *name, - std::vector &&interface_map, - const std::vector &trait_set, - size_t attributes_num, - const char **attributes_name, - VerifyPtr verify) { +void IrContext::RegisterOpInfo( + Dialect *dialect, + TypeId op_id, + const char *name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char **attributes_name, + VerifyPtr verify) { if (impl().IsOpInfoRegistered(name)) { LOG(WARNING) << name << " op already registered."; } else { diff --git a/paddle/pir/core/ir_context.h b/paddle/pir/core/ir_context.h index a9b45d2cb82927..a68c87f3bee0b1 100644 --- a/paddle/pir/core/ir_context.h +++ b/paddle/pir/core/ir_context.h @@ -28,12 +28,13 @@ class AbstractAttribute; class TypeId; class Dialect; class OpInfo; -class InterfaceValue; class Type; class OpResult; class Attribute; class Operation; - +namespace details { +class InterfaceValue; +} using OpInfoMap = std::unordered_map; /// @@ -109,7 +110,7 @@ class IR_API IrContext { void RegisterOpInfo(Dialect *dialect, TypeId op_id, const char *name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char **attributes_name, diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index a19f906cbe8ca0..5eb637ea52b541 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -16,55 +16,12 @@ #include #include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/interface_support.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/core/utils.h" namespace pir { -class IR_API InterfaceValue { - public: - template - static InterfaceValue get() { - InterfaceValue val; - val.type_id_ = TypeId::get(); - val.model_ = malloc(sizeof(typename T::template Model)); - if (val.model_ == nullptr) { - throw("Alloc memory for interface failed."); - } - static_assert(std::is_trivially_destructible< - typename T::template Model>::value, - "interface models must be trivially destructible"); - new (val.model_) typename T::template Model(); - return val; - } - TypeId type_id() const { return type_id_; } - void *model() const { return model_; } - - InterfaceValue() = default; - explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} - InterfaceValue(const InterfaceValue &) = delete; - InterfaceValue(InterfaceValue &&) noexcept; - InterfaceValue &operator=(const InterfaceValue &) = delete; - InterfaceValue &operator=(InterfaceValue &&) noexcept; - ~InterfaceValue(); - void swap(InterfaceValue &&val) { - using std::swap; - swap(type_id_, val.type_id_); - swap(model_, val.model_); - } - - /// - /// \brief Comparison operations. - /// - inline bool operator<(const InterfaceValue &other) const { - return type_id_ < other.type_id_; - } - - private: - TypeId type_id_; - void *model_{nullptr}; -}; - class IR_API OpBase { public: explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} @@ -133,6 +90,7 @@ class OpInterfaceBase : public OpBase { public: explicit OpInterfaceBase(Operation *op) : OpBase(op) {} + // Accessor for the ID of this interface. static TypeId GetInterfaceId() { return TypeId::get(); } static ConcreteInterface dyn_cast(Operation *op) { @@ -144,59 +102,6 @@ class OpInterfaceBase : public OpBase { } }; -template -class ConstructInterfacesOrTraits { - public: - /// Construct method for interfaces. - static InterfaceValue *interface(InterfaceValue *p_interface) { - (void)std::initializer_list{ - 0, (PlacementConstrctInterface(p_interface), 0)...}; - return p_interface; - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - (void)std::initializer_list{ - 0, (PlacementConstrctTrait(p_trait), 0)...}; - return p_trait; - } - - private: - /// Placement new interface. - template - static void PlacementConstrctInterface( - InterfaceValue *&p_interface) { // NOLINT - p_interface->swap(InterfaceValue::get()); - VLOG(6) << "New a interface: id[" - << (p_interface->type_id()).AsOpaquePointer() << "]."; - ++p_interface; - } - - /// Placement new trait. - template - static void PlacementConstrctTrait(pir::TypeId *&p_trait) { // NOLINT - *p_trait = TypeId::get(); - VLOG(6) << "New a trait: id[" << p_trait->AsOpaquePointer() << "]."; - ++p_trait; - } -}; - -/// Specialized for tuple type. -template -class ConstructInterfacesOrTraits> { - public: - /// Construct method for interfaces. - static InterfaceValue *interface(InterfaceValue *p_interface) { - return ConstructInterfacesOrTraits::interface( - p_interface); - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - return ConstructInterfacesOrTraits::trait(p_trait); - } -}; - template class Op : public OpBase { public: @@ -219,26 +124,22 @@ class Op : public OpBase { return op && op->info().id() == TypeId::get(); } - static std::vector GetInterfaceMap() { - constexpr size_t interfaces_num = std::tuple_size::value; - std::vector interfaces_map(interfaces_num); - ConstructInterfacesOrTraits::interface( - interfaces_map.data()); - return interfaces_map; + static std::vector GetInterfaceMap() { + return pir::details::GetInterfaceMap(); } static std::vector GetTraitSet() { - constexpr size_t traits_num = std::tuple_size::value; - std::vector trait_set(traits_num); - auto p_first_trait = trait_set.data(); - ConstructInterfacesOrTraits::trait(p_first_trait); - return trait_set; + return pir::details::GetTraitSet(); } + + // Checking that the derived class does not define any member by comparing + // its size to an ad-hoc EmptyOp. static constexpr bool HasNoDataMembers() { class EmptyOp : public Op {}; return sizeof(ConcreteOp) == sizeof(EmptyOp); } + // Implementation of `VerifyInvariantsFn` OperationName hook. static void VerifyInvariants(Operation *op) { static_assert(HasNoDataMembers(), "Op class shouldn't define new data members"); diff --git a/paddle/pir/core/op_info.h b/paddle/pir/core/op_info.h index 322229d027c344..130c05037d8ae0 100644 --- a/paddle/pir/core/op_info.h +++ b/paddle/pir/core/op_info.h @@ -61,15 +61,15 @@ class IR_API OpInfo { bool HasTrait(TypeId trait_id) const; - template + template bool HasInterface() const { - return HasInterface(TypeId::get()); + return HasInterface(TypeId::get()); } bool HasInterface(TypeId interface_id) const; - template - typename Interface::Concept *GetInterfaceImpl() const; + template + typename InterfaceT::Concept *GetInterfaceImpl() const; void *AsOpaquePointer() const { return impl_; } static OpInfo RecoverFromOpaquePointer(void *pointer) { @@ -84,13 +84,19 @@ class IR_API OpInfo { void *GetInterfaceImpl(TypeId interface_id) const; private: - OpInfoImpl *impl_{nullptr}; // not owned + /// The internal implementation of the operation name. + /// Not owned. + OpInfoImpl *impl_{nullptr}; }; -template -typename Interface::Concept *OpInfo::GetInterfaceImpl() const { - void *model = GetInterfaceImpl(TypeId::get()); - return reinterpret_cast(model); +/// +/// \brief Returns an instance of the concept object for the given interface if +/// it was registered to this operation, null otherwise. +/// +template +typename InterfaceT::Concept *OpInfo::GetInterfaceImpl() const { + void *model = GetInterfaceImpl(TypeId::get()); + return reinterpret_cast(model); } } // namespace pir diff --git a/paddle/pir/core/op_info_impl.cc b/paddle/pir/core/op_info_impl.cc index e77bf4342f5860..fa91d3173389a0 100644 --- a/paddle/pir/core/op_info_impl.cc +++ b/paddle/pir/core/op_info_impl.cc @@ -14,12 +14,13 @@ #include "paddle/pir/core/op_info_impl.h" #include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/interface_support.h" namespace pir { OpInfo OpInfoImpl::Create(Dialect *dialect, TypeId op_id, const char *op_name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], // NOLINT @@ -29,7 +30,7 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, size_t traits_num = trait_set.size(); VLOG(6) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " << traits_num << " traits, " << attributes_num << " attributes."; - size_t base_size = sizeof(InterfaceValue) * interfaces_num + + size_t base_size = sizeof(details::InterfaceValue) * interfaces_num + sizeof(TypeId) * traits_num + sizeof(OpInfoImpl); char *base_ptr = static_cast(::operator new(base_size)); VLOG(6) << "Malloc " << base_size << " Bytes at " @@ -37,10 +38,10 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, if (interfaces_num > 0) { std::sort(interface_map.begin(), interface_map.end()); for (size_t index = 0; index < interfaces_num; ++index) { - new (base_ptr + index * sizeof(InterfaceValue)) - InterfaceValue(std::move(interface_map[index])); + new (base_ptr + index * sizeof(details::InterfaceValue)) + details::InterfaceValue(std::move(interface_map[index])); } - base_ptr += interfaces_num * sizeof(InterfaceValue); + base_ptr += interfaces_num * sizeof(details::InterfaceValue); } if (traits_num > 0) { auto p_first_trait = reinterpret_cast(base_ptr); @@ -86,38 +87,21 @@ bool OpInfoImpl::HasTrait(TypeId trait_id) const { bool OpInfoImpl::HasInterface(TypeId interface_id) const { if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( + const details::InterfaceValue *p_first_interface = + reinterpret_cast( reinterpret_cast(this) - sizeof(pir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); + sizeof(details::InterfaceValue) * num_interfaces_); return std::binary_search(p_first_interface, p_first_interface + num_interfaces_, - InterfaceValue(interface_id)); + details::InterfaceValue(interface_id)); } return false; } void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const { - if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( - reinterpret_cast(this) - - sizeof(TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); - size_t left = 0, right = num_interfaces_; - while (left < right) { - size_t mid = (left + right) / 2; - if ((p_first_interface + mid)->type_id() == interface_id) { - return (p_first_interface + mid)->model(); - } else if ((p_first_interface + mid)->type_id() < interface_id) { - left = mid + 1; - } else { - right = mid; - } - } - } - return nullptr; + return pir::details::LookUp( + interface_id, num_interfaces_, num_traits_, this); } void OpInfoImpl::Destroy() { @@ -125,10 +109,10 @@ void OpInfoImpl::Destroy() { // (1) free interfaces char *base_ptr = reinterpret_cast(this) - sizeof(pir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_; + sizeof(details::InterfaceValue) * num_interfaces_; if (num_interfaces_ > 0) { - InterfaceValue *p_interface_val = - reinterpret_cast(base_ptr); + details::InterfaceValue *p_interface_val = + reinterpret_cast(base_ptr); for (size_t i = 0; i < num_interfaces_; i++) { (p_interface_val + i)->~InterfaceValue(); } diff --git a/paddle/pir/core/op_info_impl.h b/paddle/pir/core/op_info_impl.h index cc63a52d40064a..410c9ef3719898 100644 --- a/paddle/pir/core/op_info_impl.h +++ b/paddle/pir/core/op_info_impl.h @@ -38,7 +38,7 @@ class OpInfoImpl { static OpInfo Create(Dialect *dialect, TypeId op_id, const char *op_name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], diff --git a/paddle/pir/core/storage_manager_support.h b/paddle/pir/core/storage_manager_support.h new file mode 100644 index 00000000000000..a54e066a0e2a6a --- /dev/null +++ b/paddle/pir/core/storage_manager_support.h @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/interface_support.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/type_base.h" +#include "paddle/pir/core/type_id.h" + +namespace pir { +template +class TypeInterfaceBase; + +namespace detail { + +namespace storage_helper_base_impl { +/// Returns true if this given Trait ID matches the IDs of any of the provided +/// trait types `Traits`. +template