Skip to content

Commit

Permalink
[NewIR] Add shapedtype interface (PaddlePaddle#56427)
Browse files Browse the repository at this point in the history
* add shapedtype interface

* add storage_manager_support

* delete DECLARE_TYPE_UTILITY_FUNCTOR

* split interfaceValue

* add dyn_cast_interface

* add unit_test

* change cast utils
  • Loading branch information
zhangbopd authored Sep 12, 2023
1 parent 36bec31 commit e6938c6
Show file tree
Hide file tree
Showing 24 changed files with 818 additions and 256 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
namespace paddle {
namespace dialect {

class AllocatedDenseTensorType : public pir::Type {
class AllocatedDenseTensorType
: public pir::Type::TypeBase<AllocatedDenseTensorType,
pir::Type,
AllocatedDenseTensorTypeStorage> {
public:
using Type::Type;

DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedDenseTensorType,
AllocatedDenseTensorTypeStorage);
using Base::Base;

static AllocatedDenseTensorType get(pir::IrContext *ctx,
const phi::Place &place,
Expand Down Expand Up @@ -62,12 +62,12 @@ class AllocatedDenseTensorType : public pir::Type {
const size_t &offset() const;
};

class AllocatedSelectedRowsType : public pir::Type {
class AllocatedSelectedRowsType
: public pir::Type::TypeBase<AllocatedSelectedRowsType,
pir::Type,
AllocatedSelectedRowsTypeStorage> {
public:
using Type::Type;

DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedSelectedRowsType,
AllocatedSelectedRowsTypeStorage);
using Base::Base;

static AllocatedSelectedRowsType get(pir::IrContext *ctx,
const phi::Place &place,
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/interface/infermeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace paddle {
namespace dialect {
class InferMetaInterface : public pir::OpInterfaceBase<InferMetaInterface> {
public:
/// Defined these methods with the interface.
struct Concept {
explicit Concept(void (*infer_meta)(phi::InferMetaContext *))
: infer_meta_(infer_meta) {}
Expand All @@ -28,13 +29,14 @@ class InferMetaInterface : public pir::OpInterfaceBase<InferMetaInterface> {

template <class ConcreteOp>
struct Model : public Concept {
static void InferMeta(phi::InferMetaContext *infer_meta) {
static inline void InferMeta(phi::InferMetaContext *infer_meta) {
return ConcreteOp::InferMeta(infer_meta);
}

Model() : Concept(InferMeta) {}
};

/// Constructor
InferMetaInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferMetaInterface>(op), impl_(impl) {}

Expand Down
11 changes: 7 additions & 4 deletions paddle/fluid/pir/dialect/operator/ir/op_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@

#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/type.h"

namespace paddle {
namespace dialect {

using DenseTensorType = pir::DenseTensorType;
class SelectedRowsType : public pir::Type {
class SelectedRowsType : public pir::Type::TypeBase<SelectedRowsType,
pir::Type,
SelectedRowsTypeStorage,
pir::ShapedTypeInterface> {
public:
using Type::Type;

DECLARE_TYPE_UTILITY_FUNCTOR(SelectedRowsType, SelectedRowsTypeStorage);
using Base::Base;

const pir::Type &dtype() const;

Expand Down
36 changes: 17 additions & 19 deletions paddle/pir/core/builtin_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@

#pragma once

#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/builtin_type_storage.h"
#include "paddle/pir/core/type.h"

namespace pir {
///
/// \brief Define built-in parameterless types. Please add the necessary
/// interface functions for built-in types through the macro
/// DECLARE_TYPE_UTILITY_FUNCTOR.
/// \brief Define built-in parameterless types.
///
/// NOTE(zhangbo9674): If you need to directly
/// cache the object of this built-in type in IrContext, please overload the get
Expand All @@ -39,11 +38,10 @@ namespace pir {
// NOTE(dev): Currently Int8 are not considered as a cached member
// in IrContextImpl because it is not widely used.

class IR_API VectorType : public Type {
class IR_API VectorType
: public pir::Type::TypeBase<VectorType, pir::Type, VectorTypeStorage> {
public:
using Type::Type;

DECLARE_TYPE_UTILITY_FUNCTOR(VectorType, VectorTypeStorage);
using Base::Base;

std::vector<Type> data() const;

Expand All @@ -54,11 +52,12 @@ class IR_API VectorType : public Type {
Type operator[](size_t index) const { return data()[index]; }
};

class DenseTensorType : public pir::Type {
class DenseTensorType : public pir::Type::TypeBase<DenseTensorType,
pir::Type,
DenseTensorTypeStorage,
pir::ShapedTypeInterface> {
public:
using Type::Type;

DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
using Base::Base;

const pir::Type &dtype() const;

Expand All @@ -71,14 +70,13 @@ class DenseTensorType : public pir::Type {
const size_t &offset() const;
};

#define DECLARE_BUILTIN_TYPE(__name) \
class IR_API __name : public Type { \
public: \
using Type::Type; \
\
DECLARE_TYPE_UTILITY_FUNCTOR(__name, TypeStorage); \
\
static __name get(IrContext *context); \
#define DECLARE_BUILTIN_TYPE(__name) \
class IR_API __name : public ::pir::Type::TypeBase<__name, \
::pir::Type, \
::pir::TypeStorage> { \
public: \
using Base::Base; \
static __name get(IrContext *context); \
};

#define FOREACH_BUILTIN_TYPE(__macro) \
Expand Down
18 changes: 18 additions & 0 deletions paddle/pir/core/builtin_type_interfaces.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/type_id.h"

IR_DEFINE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface)
153 changes: 153 additions & 0 deletions paddle/pir/core/builtin_type_interfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>
#include "paddle/phi/core/tensor_base.h"
#include "paddle/pir/core/cast_utils.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/type.h"

namespace details {

template <typename RangeT>
constexpr auto begin_impl(RangeT &&range)
-> decltype(std::begin(std::forward<RangeT>(range))) {
return std::begin(std::forward<RangeT>(range));
}

template <typename RangeT>
constexpr auto end_impl(RangeT &&range)
-> decltype(std::end(std::forward<RangeT>(range))) {
return std::end(std::forward<RangeT>(range));
}

/// Returns the begin iterator to \p range using `std::begin` and
/// function found through Argument-Dependent Lookup (ADL).
template <typename RangeT>
constexpr auto adl_begin(RangeT &&range)
-> decltype(begin_impl(std::forward<RangeT>(range))) {
return begin_impl(std::forward<RangeT>(range));
}

/// Returns the end iterator to \p range using `std::end` and
/// functions found through Argument-Dependent Lookup (ADL).
template <typename RangeT>
constexpr auto adl_end(RangeT &&range)
-> decltype(end_impl(std::forward<RangeT>(range))) {
return end_impl(std::forward<RangeT>(range));
}

/// Provide wrappers to std::any_of which take ranges instead of having to pass
/// begin/end explicitly.
template <typename R, typename UnaryPredicate>
bool any_of(R &&Range, UnaryPredicate P) {
return std::any_of(adl_begin(Range), adl_end(Range), P);
}

/// Wrapper function around std::count_if to count the number of times an
/// element satisfying a given predicate occurs in a range.
template <typename R, typename UnaryPredicate>
auto count_if(R &&Range, UnaryPredicate P) {
return std::count_if(adl_begin(Range), adl_end(Range), P);
}

} // namespace details
namespace pir {
class ShapedTypeInterface : public pir::TypeInterfaceBase<ShapedTypeInterface> {
public:
using DDim = phi::DDim;
using DataType = pir::Type;
struct Concept {
/// Defined these methods with the interface.
explicit Concept(DataType (*get_element_type)(pir::Type),
DDim (*get_shape)(pir::Type))
: get_element_type_(get_element_type), get_shape_(get_shape) {}

DataType (*get_element_type_)(pir::Type);
DDim (*get_shape_)(pir::Type);
};

template <class ConcreteType>
struct Model : public Concept {
static inline DataType getElementType(pir::Type type) {
return pir::cast<ConcreteType>(type).dtype();
}

static inline DDim getShape(pir::Type type) {
return pir::cast<ConcreteType>(type).dims();
}

Model() : Concept(getElementType, getShape) {}
};

/// Constructor
ShapedTypeInterface(pir::Type type, Concept *impl)
: pir::TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}

/// Get the element type.
DataType getElementType() const { return impl_->get_element_type_(*this); }

/// Get the shape of this type.
DDim getShape() const { return impl_->get_shape_(*this); }

static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

/// Check whether this type is ranked, currently return true.
bool hasRank() const { return true; }

/// If this is a ranked type, return the rank. Otherwise, abort.
int64_t getRank() const {
IR_ENFORCE((*this).hasRank(), "Cannot query rank of unranked shaped type.");
return (*this).getShape().size();
}

/// Check whether the given dimension size is a dynamic dimension.
static constexpr bool isDynamic(int64_t dValue) { return dValue == kDynamic; }

/// Check whether the given shape has any size indicating a dynamic dimension.
static bool isDynamicShape(DDim dSizes) {
return ::details::any_of(vectorize(dSizes),
[](int64_t dSize) { return isDynamic(dSize); });
}

/// Check whether the given dimension has a dynamic size.
/// Aborts for unranked types.
bool isDynamicDim(unsigned idx) const {
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
return pir::ShapedTypeInterface::isDynamic((*this).getShape()[idx]);
}

/// Get the number of dimensions with dynamic size for a ranked type.
/// Aborts for unranked types.
int64_t getNumDynamicDims() const {
return ::details::count_if(vectorize((*this).getShape()),
pir::ShapedTypeInterface::isDynamic);
}

/// Get the size of the specified dimension for a ranked type.
/// Aborts for unranked types.
int64_t getDimSize(unsigned idx) const {
IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
return (*this).getShape()[idx];
}

private:
Concept *impl_;
};

} // namespace pir

IR_DECLARE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface)
28 changes: 22 additions & 6 deletions paddle/pir/core/cast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <memory>
#include <type_traits>

namespace pir {
Expand Down Expand Up @@ -114,7 +115,7 @@ struct ReturnTypeDuduction {
///
/// cast From to To
///
template <typename To, typename From>
template <typename To, typename From, typename Enable = void>
struct cast_impl {
// This _is_ a simple type, just cast it.
static typename ReturnTypeDuduction<To, From>::type call(const From &Val) {
Expand All @@ -125,32 +126,47 @@ struct cast_impl {
};

template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From>::type cast(From &Val) { // NOLINT
inline decltype(auto) cast(const From &Val) {
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, const From>::call(Val);
}

template <typename To, typename From>
inline decltype(auto) cast(From &Val) { // NOLINT
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, From>::call(Val);
}

template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From *>::type cast(From *Val) {
inline decltype(auto) cast(From *Val) {
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, From *>::call(Val);
}

template <typename To, typename From>
inline decltype(auto) cast(std::unique_ptr<From> &&Val) {
if (!isa<To>(Val)) {
throw("cast<To>() argument of incompatible type!");
}
return cast_impl<To, std::unique_ptr<From>>::call(std::move(Val));
}

///
/// \brief dyn_cast From to To.
///
template <typename To, typename From>
inline std::decay_t<typename ReturnTypeDuduction<To, From>::type> dyn_cast(
From &Val) { // NOLINT
inline decltype(auto) dyn_cast(From &Val) { // NOLINT
return isa<To>(Val) ? cast<To>(Val) : nullptr;
}

template <typename To, typename From>
inline typename ReturnTypeDuduction<To, From *>::type dyn_cast(From *Val) {
inline decltype(auto) dyn_cast(From *Val) {
return isa<To>(Val) ? cast<To>(Val) : nullptr;
}

Expand Down
Loading

0 comments on commit e6938c6

Please sign in to comment.