Initialize Paddle MLIR Dialect (PaddlePaddle#308)
Superjomn authored Dec 25, 2020
1 parent 0aa1e74 commit ce70b28
Showing 13 changed files with 318 additions and 34 deletions.
34 changes: 19 additions & 15 deletions CMakeLists.txt
@@ -11,6 +11,7 @@ set(DOWNLOAD_MODEL_DIR "${THIRD_PARTY_PATH}/model")
option(WITH_TESTING "Compile with Unittests" ON)
option(WITH_CUDA "Compile with CUDA support" OFF)
option(WITH_DEBUG "Compile with debug information" OFF)
option(PUBLISH_LIBS "Whether to publish compiled libraries" OFF)

if (WITH_TESTING)
add_definitions(-DCINN_WITH_TEST)
@@ -89,6 +90,7 @@ add_dependencies(cinncore GEN_LLVM_RUNTIME_IR_HEADER
cinn_base_inc
tensor_shape_inc
dense_tensor_inc
pd_ops_inc
${core_deps}
)

@@ -97,21 +99,23 @@ if (WITH_CUDA)
endif()

# --------distribute cinncore lib and include begin--------
foreach(header ${core_includes})
get_filename_component(prefix ${header} DIRECTORY)
file(COPY ${header} DESTINATION ${CMAKE_BINARY_DIR}/dist/cinn/include/${prefix})
endforeach()

add_custom_command(TARGET cinncore POST_BUILD
COMMAND cmake -E copy ${CMAKE_SOURCE_DIR}/tutorials/demo.cc ${CMAKE_BINARY_DIR}/dist/demo.cc
COMMAND cmake -E copy ${CMAKE_SOURCE_DIR}/tutorials/build_demo.sh ${CMAKE_BINARY_DIR}/dist/build_demo.sh
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinncore.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinncore.a
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/frontend/paddle/libframework_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libframework_proto.a
COMMAND cmake -E copy_directory ${LLVM11_DIR} ${CMAKE_BINARY_DIR}/dist/third_party/llvm11
COMMAND cmake -E copy_directory ${CMAKE_BINARY_DIR}/thirds/install ${CMAKE_BINARY_DIR}/dist/third_party
COMMENT "distribute libcinncore.a and related header files."
DEPENDS cinncore
)
if (PUBLISH_LIBS)
foreach(header ${core_includes})
get_filename_component(prefix ${header} DIRECTORY)
file(COPY ${header} DESTINATION ${CMAKE_BINARY_DIR}/dist/cinn/include/${prefix})
endforeach()

add_custom_command(TARGET cinncore POST_BUILD
COMMAND cmake -E copy ${CMAKE_SOURCE_DIR}/tutorials/demo.cc ${CMAKE_BINARY_DIR}/dist/demo.cc
COMMAND cmake -E copy ${CMAKE_SOURCE_DIR}/tutorials/build_demo.sh ${CMAKE_BINARY_DIR}/dist/build_demo.sh
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinncore.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinncore.a
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/frontend/paddle/libframework_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libframework_proto.a
COMMAND cmake -E copy_directory ${LLVM11_DIR} ${CMAKE_BINARY_DIR}/dist/third_party/llvm11
COMMAND cmake -E copy_directory ${CMAKE_BINARY_DIR}/thirds/install ${CMAKE_BINARY_DIR}/dist/third_party
COMMENT "distribute libcinncore.a and related header files."
DEPENDS cinncore
)
endif()
# --------distribute cinncore lib and include end--------

add_subdirectory(python)
2 changes: 1 addition & 1 deletion build.sh
@@ -73,7 +73,7 @@ function cmake_ {
echo "set(WITH_CUDA $cuda_config)" >> $build_dir/config.cmake
echo "set(WITH_MKL_CBLAS ON)" >> $build_dir/config.cmake
cd $build_dir
cmake .. -DLLVM11_DIR=${LLVM11_DIR} -DLLVM_DIR=${LLVM11_DIR}/lib/cmake/llvm -DMLIR_DIR=${LLVM11_DIR}/lib/cmake/mlir
cmake .. -DLLVM11_DIR=${LLVM11_DIR} -DLLVM_DIR=${LLVM11_DIR}/lib/cmake/llvm -DMLIR_DIR=${LLVM11_DIR}/lib/cmake/mlir -DPUBLISH_LIBS=ON

make GEN_LLVM_RUNTIME_IR_HEADER
# make the generated code compilable
11 changes: 10 additions & 1 deletion cinnrt/dialect/CMakeLists.txt
@@ -8,26 +8,35 @@ set(srcs
dense_tensor.cc
mlir_loader.cc
diagnostic_utils.cc
pd_types.cc
pd_ops.cc
)

mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(cinn_base DIALECT cinn)
mlir_tablegen_on(tensor_shape DIALECT ts)
mlir_tablegen_on(dense_tensor DIALECT dt)
mlir_tablegen_on(pd_op_base DIALECT pd)
mlir_tablegen_on(pd_ops)

# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
add_executable(cinn-opt opt.cc)
target_link_libraries(cinn-opt cinncore ${mlir_libs})
add_dependencies(cinn-opt pd_ops_inc)


# MLIR opt tests
# %{
add_test(test_mlir_opt_on_basic ${CMAKE_BINARY_DIR}/cinnrt/dialect/cinn-opt
${CMAKE_SOURCE_DIR}/cinnrt/dialect/mlir_tests/basic.mlir)

add_test(test_mlir_opt_on_tensor_shape ${CMAKE_BINARY_DIR}/cinnrt/dialect/cinn-opt
set(cinn_opt_path ${CMAKE_BINARY_DIR}/cinnrt/dialect/cinn-opt)
add_test(test_mlir_opt_on_tensor_shape ${cinn_opt_path}
${CMAKE_SOURCE_DIR}/cinnrt/dialect/mlir_tests/tensor_shape.mlir)
add_test(test_mlir_opt_on_paddle_ops
${cinn_opt_path}
${CMAKE_SOURCE_DIR}/cinnrt/dialect/mlir_tests/paddle_ops.mlir)
# %}

cc_test(test_mlir_loader SRCS mlir_loader_test.cc DEPS cinncore ${MLIR_IR_LIBS})
3 changes: 2 additions & 1 deletion cinnrt/dialect/init_cinn_dialects.cc
@@ -5,12 +5,13 @@
#include "cinnrt/dialect/basic_kernels.h"
#include "cinnrt/dialect/cinn_base.h"
#include "cinnrt/dialect/dense_tensor.h"
#include "cinnrt/dialect/pd_ops.h"
#include "cinnrt/dialect/tensor_shape.h"

namespace cinnrt {

void RegisterCinnDialects(mlir::DialectRegistry& registry) {
registry.insert<ts::TensorShapeDialect, dialect::CINNDialect, dt::DTDialect>();
registry.insert<ts::TensorShapeDialect, dialect::CINNDialect, dt::DTDialect, mlir::PD::PaddleDialect>();
}

} // namespace cinnrt
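
With the Paddle dialect added to RegisterCinnDialects, a caller only needs to route a DialectRegistry into an MLIRContext to make pd.* ops visible. Below is a minimal sketch, assuming an init_cinn_dialects.h header matching the .cc above and an MLIR snapshot that provides MLIRContext::appendDialectRegistry and loadAllAvailableDialects (older snapshots wire the registry up differently):

#include "cinnrt/dialect/init_cinn_dialects.h"  // assumed header for the .cc shown above
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"

// Sketch only: collect the CINN/Paddle dialects in a registry and load them
// into a context so that pd.* operations can be parsed and created.
void SetUpCinnContext(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  cinnrt::RegisterCinnDialects(registry);   // now also inserts mlir::PD::PaddleDialect
  context.appendDialectRegistry(registry);  // assumed API; the exact hookup varies by LLVM version
  context.loadAllAvailableDialects();
}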
8 changes: 8 additions & 0 deletions cinnrt/dialect/mlir_tests/paddle_ops.mlir
@@ -0,0 +1,8 @@
func @ops() {
%a = pd.Feed() : tensor<?xf32>
%b = pd.Feed() : tensor<?xf32>

%c = "pd.Matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

cinn.return
}
77 changes: 77 additions & 0 deletions cinnrt/dialect/pd_op_base.td
@@ -0,0 +1,77 @@
// This file defines some basic elements of the Paddle (alias pd) dialect.
// We learned much from the TensorFlow MLIR dialect: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td

#ifndef PD_OP_BASE
#define PD_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def PD_Dialect : Dialect {
let name = "pd";

let description = [{
The PaddlePaddle dialect.

This dialect contains the PaddlePaddle operators.
}];

let cppNamespace = "::mlir::PD";
}

class PD_Op<string mnemonic, list<OpTrait> traits = []> :
Op<PD_Dialect, mnemonic, traits>;


class PD_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::PD::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">;


//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
//===----------------------------------------------------------------------===//

def PD_PDDialectType : Type<CPred<"$_self.isa<mlir::PD::PDType>()">, "PaddlePaddle type">;

class PD_PaddleType <string name, string description> :
Type<CPred<"$_self.isa<mlir::PD::" # name #"Type>()">,
"Paddle " # description # " type">,
BuildableType<"getType<mlir::PD::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Integer types
def PD_Bool : AnyTypeOf<[I<1>], "bool">;
def PD_Int8 : AnyTypeOf<[I8], "8-bit integer">;
def PD_Int16 : AnyTypeOf<[I16], "16-bit integer">;
def PD_Int32 : AnyTypeOf<[I32], "32-bit integer">;
def PD_Int64 : AnyTypeOf<[I64], "64-bit integer">;

def PD_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">;
def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;

def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64], "signed integer">;
def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64], "unsigned integer">;
def PD_Int : AnyTypeOf<[PD_SInt, PD_UInt], "integer">;

// Float types
def PD_Float16 : AnyTypeOf<[F16], "16-bit float">;
def PD_Float32 : AnyTypeOf<[F32], "32-bit float">;
def PD_Float64 : AnyTypeOf<[F64], "64-bit float">;

def PD_Float : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64], "floating-point">;


// Tensor types

def PD_ElementType : Type<Or<[PD_Float.predicate,
PD_Bool.predicate,
PD_Int.predicate]>,
"pd.dtype">;

def PD_Tensor : TensorOf<[PD_ElementType]>;


#endif // PD_OP_BASE
26 changes: 26 additions & 0 deletions cinnrt/dialect/pd_ops.cc
@@ -0,0 +1,26 @@
#include "cinnrt/dialect/pd_ops.h"

namespace mlir {
namespace PD {

#define GET_OP_CLASSES
#include "cinnrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES

PaddleDialect::PaddleDialect(MLIRContext *context) : Dialect("pd", context, TypeID::get<PaddleDialect>()) {
addOperations<
#define GET_OP_LIST
#include "cinnrt/dialect/pd_ops.cpp.inc"
>();
#undef GET_OP_LIST

// Support unknown operations because not all Paddle operations are registered.
allowUnknownOperations();
}

#define GET_OP_CLASSES
#include "cinnrt/dialect/pd_ops.cpp.inc"
#undef GET_OP_CLASSES

} // namespace PD
} // namespace mlir
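
Because the constructor enables allowUnknownOperations(), a pd.* op that is not defined in pd_ops.td still parses once the dialect is loaded. A minimal sketch of that behavior, assuming the parseSourceString/OwningModuleRef entry points from mlir/Parser.h and MLIRContext::getOrLoadDialect available in MLIR snapshots of this period; pd.Sigmoid below is a hypothetical op used only for illustration:

#include "cinnrt/dialect/pd_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"

// Sketch only: parsing succeeds even though "pd.Sigmoid" is unregistered,
// because PaddleDialect::allowUnknownOperations() is called in the constructor.
bool ParsesUnknownPaddleOp() {
  mlir::MLIRContext context;
  context.getOrLoadDialect<mlir::PD::PaddleDialect>();   // registers the "pd" namespace
  context.getOrLoadDialect<mlir::StandardOpsDialect>();  // needed for the std.return terminator

  const char *ir = R"mlir(
    func @f(%x: tensor<?xf32>) -> tensor<?xf32> {
      %y = "pd.Sigmoid"(%x) : (tensor<?xf32>) -> tensor<?xf32>
      return %y : tensor<?xf32>
    }
  )mlir";
  mlir::OwningModuleRef module = mlir::parseSourceString(ir, &context);
  return static_cast<bool>(module);  // true iff the module parsed and verified
}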
33 changes: 33 additions & 0 deletions cinnrt/dialect/pd_ops.h
@@ -0,0 +1,33 @@
#pragma once

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir {
namespace PD {

class PaddleDialect : public Dialect {
public:
explicit PaddleDialect(MLIRContext* context);

static StringRef getDialectNamespace() { return "pd"; }

Type parseType(DialectAsmParser& parser) const override { return Dialect::parseType(parser); }
void printType(Type type, DialectAsmPrinter& printer) const override { Dialect::printType(type, printer); }
};

} // namespace PD
} // namespace mlir
86 changes: 86 additions & 0 deletions cinnrt/dialect/pd_ops.td
@@ -0,0 +1,86 @@
#ifndef PD_OPS
#define PD_OPS

include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "cinnrt/dialect/pd_op_base.td"

def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
let summary = "Feed Op";

let description = [{
Feed a tensor into the model.
}];

let arguments = (ins);
let results = (outs PD_Tensor:$out);

let assemblyFormat = [{
`(` `)` attr-dict `:` type($out)
}];
}

def PD_AbsOp : PD_Op<"Abs", [NoSideEffect]> {
let summary = "Computes the absolute value of a tensor";

let description = [{
}];

let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}

def PD_ReluOp : PD_Op<"Relu", [NoSideEffect]> {
let summary = "Computes the Relu of a tensor";

let description = [{
}];

let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}

def PD_Relu6Op : PD_Op<"Relu6", [NoSideEffect]> {
let summary = "Computes the Relu6 of a tensor";

let description = [{
}];

let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}

def PD_ElementwiseAdd : PD_Op<"ElementwiseAdd", [NoSideEffect]> {
let summary = "ElementwiseAdd Op";
let description = [{
}];

let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, I32Attr:$axis);
let results = (outs PD_Tensor:$out);
}

def PD_ElementwiseMul : PD_Op<"ElementwiseMul", [NoSideEffect]> {
let summary = "ElementwiseMul Op";
let description = [{
}];

let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, I32Attr:$axis);
let results = (outs PD_Tensor:$out);
}

def PD_MatmulOp : PD_Op<"Matmul", [NoSideEffect]> {
let summary = "Computes the matrix mulplication result of two tensors";
let description = [{
}];

let arguments = (ins PD_Tensor:$x, PD_Tensor:$y,
DefaultValuedAttr<BoolAttr, "false">:$transpose_x,
DefaultValuedAttr<BoolAttr, "false">:$transpose_y,
DefaultValuedAttr<F32Attr, "1.0">:$alpha);
let results = (outs PD_Tensor:$out);

//let hasCanonicalizer = 1;
}

#endif // PD_OPS
Empty file added cinnrt/dialect/pd_types.cc
42 changes: 42 additions & 0 deletions cinnrt/dialect/pd_types.h
@@ -0,0 +1,42 @@
// This file defines the types used in the PaddlePaddle MLIR dialect.
// We borrowed many ideas from the TensorFlow MLIR dialect (tf_types.h in TensorFlow).

#pragma once

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace PD {

class PaddleType : public Type {
public:
using Type::Type;

static bool classof(Type type);
};

namespace detail {

template <typename Derived>
class PaddleTypeImpl : public Type::TypeBase<Derived, PaddleType, TypeStorage> {
public:
using Base = typename Type::TypeBase<Derived, PaddleType, TypeStorage>;
using PDBase = PaddleTypeImpl<Derived>;
using Base::Base;
};

} // namespace detail

#define HANDLE_PD_TYPE(pdtype, enumerant, name) \
class pdtype##Type : public detail::PaddleTypeImpl<pdtype##Type> { \
public: \
using PDBase::PDBase; \
};

} // namespace PD
} // namespace mlir
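
HANDLE_PD_TYPE only captures the declaration pattern; this commit does not instantiate it yet. A hypothetical instantiation (the String/STRING names below are placeholders, not part of this change) would look like:

// Hypothetical use of the macro above; "String" is an assumed type name.
HANDLE_PD_TYPE(String, STRING, string)
//
// ...which expands roughly to:
//
//   class StringType : public detail::PaddleTypeImpl<StringType> {
//    public:
//     using PDBase::PDBase;
//   };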