diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 1c4569ecfa5848..9788e24e4a1d91 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(Math) add_subdirectory(MemRef) add_subdirectory(Mesh) add_subdirectory(MLProgram) +add_subdirectory(MPI) add_subdirectory(NVGPU) add_subdirectory(OpenACC) add_subdirectory(OpenACCMPCommon) diff --git a/mlir/include/mlir/Dialect/MPI/CMakeLists.txt b/mlir/include/mlir/Dialect/MPI/CMakeLists.txt new file mode 100644 index 00000000000000..f33061b2d87cff --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt new file mode 100644 index 00000000000000..dfec2ea486cb29 --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_dialect(MPI mpi) +add_mlir_doc(MPIOps MPI Dialects/ -gen-dialect-doc) + +# Add MPI operations +set(LLVM_TARGET_DEFINITIONS MPIOps.td) +mlir_tablegen(MPIOps.h.inc -gen-op-decls) +mlir_tablegen(MPIOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRMPIOpsIncGen) + +# Add MPI types +set(LLVM_TARGET_DEFINITIONS MPITypes.td) +mlir_tablegen(MPITypesGen.h.inc -gen-typedef-decls) +mlir_tablegen(MPITypesGen.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRMPITypesIncGen) + +# Add MPI attributes +set(LLVM_TARGET_DEFINITIONS MPI.td) +mlir_tablegen(MPIEnums.h.inc -gen-enum-decls) +mlir_tablegen(MPIEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(MPIAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(MPIAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRMPIAttrsIncGen) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h new file mode 100644 index 00000000000000..f06b911ce3fe31 --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h @@ -0,0 +1,33 @@ +//===- MPI.h - MPI dialect ----------------------------------------*- C++-*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_MPI_IR_MPI_H_ +#define MLIR_DIALECT_MPI_IR_MPI_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// MPIDialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MPI/IR/MPIDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/MPI/IR/MPITypesGen.h.inc" + +#include "mlir/Dialect/MPI/IR/MPIEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/MPI/IR/MPIAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/MPI/IR/MPIOps.h.inc" + +#endif // MLIR_DIALECT_MPI_IR_MPI_H_ diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td new file mode 100644 index 00000000000000..643612e1e2ee89 --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -0,0 +1,218 @@ +//===- MPI.td - Base defs for mpi dialect ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MPI_IR_MPI_TD +#define MLIR_DIALECT_MPI_IR_MPI_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" + +def MPI_Dialect : Dialect { + let name = "mpi"; + let cppNamespace = "::mlir::mpi"; + let description = [{ + This dialect models the Message Passing Interface (MPI), version + 4.0. It is meant to serve as an interfacing dialect that is targeted + by higher-level dialects. The MPI dialect itself can be lowered to + multiple MPI implementations and hide differences in ABI. The dialect + models the functions of the MPI specification as close to 1:1 as possible + while preserving SSA value semantics where it makes sense, and uses + `memref` types instead of bare pointers. + + This dialect is under active development, and while stability is an + eventual goal, it is not guaranteed at this juncture. Given the early + state, it is recommended to inquire further prior to using this dialect. + + For an in-depth documentation of the MPI library interface, please refer + to official documentation such as the + [OpenMPI online documentation](https://www.open-mpi.org/doc/current/). + }]; + + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// Error classes enum: +//===----------------------------------------------------------------------===// + +def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">; +def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">; +def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">; +def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">; +def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">; +def MPI_CodeErrBadFile + : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">; +def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">; +def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">; +def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">; +def MPI_CodeErrConversion + : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">; +def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">; +def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">; +def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">; +def MPI_CodeErrDupDatarep + : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">; +def MPI_CodeErrErrhandler + : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">; +def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">; +def MPI_CodeErrFileExists + : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">; +def MPI_CodeErrFileInUse + : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">; +def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">; +def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">; +def MPI_CodeErrInfoKey + : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">; +def MPI_CodeErrInfoNokey + : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">; +def MPI_CodeErrInfoValue + : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">; +def MPI_CodeErrInStatus + : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">; +def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">; +def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">; +def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">; +def MPI_CodeErrLocktype + : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">; +def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">; +def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">; +def MPI_CodeErrNoSpace + : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">; +def MPI_CodeErrNoSuchFile + : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">; +def MPI_CodeErrNotSame + : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">; +def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">; +def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">; +def MPI_CodeErrPending + : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">; +def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">; +def MPI_CodeErrProcAborted + : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">; +def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">; +def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">; +def MPI_CodeErrReadOnly + : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">; +def MPI_CodeErrRequest + : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">; +def MPI_CodeErrRmaAttach + : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">; +def MPI_CodeErrRmaConflict + : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">; +def MPI_CodeErrRmaFlavor + : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">; +def MPI_CodeErrRmaRange + : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">; +def MPI_CodeErrRmaShared + : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">; +def MPI_CodeErrRmaSync + : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">; +def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">; +def MPI_CodeErrService + : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">; +def MPI_CodeErrSession + : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">; +def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">; +def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">; +def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">; +def MPI_CodeErrTopology + : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">; +def MPI_CodeErrTruncate + : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">; +def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">; +def MPI_CodeErrUnknown + : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">; +def MPI_CodeErrUnsupportedDatarep + : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58, + "MPI_ERR_UNSUPPORTED_DATAREP">; +def MPI_CodeErrUnsupportedOperation + : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59, + "MPI_ERR_UNSUPPORTED_OPERATION">; +def MPI_CodeErrValueTooLarge + : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">; +def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">; +def MPI_CodeErrLastcode + : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">; + +def MPI_ErrorClassEnum + : I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [ + MPI_CodeSuccess, + MPI_CodeErrAccess, + MPI_CodeErrAmode, + MPI_CodeErrArg, + MPI_CodeErrAssert, + MPI_CodeErrBadFile, + MPI_CodeErrBase, + MPI_CodeErrBuffer, + MPI_CodeErrComm, + MPI_CodeErrConversion, + MPI_CodeErrCount, + MPI_CodeErrDims, + MPI_CodeErrDisp, + MPI_CodeErrDupDatarep, + MPI_CodeErrErrhandler, + MPI_CodeErrFile, + MPI_CodeErrFileExists, + MPI_CodeErrFileInUse, + MPI_CodeErrGroup, + MPI_CodeErrInfo, + MPI_CodeErrInfoKey, + MPI_CodeErrInfoNokey, + MPI_CodeErrInfoValue, + MPI_CodeErrInStatus, + MPI_CodeErrIntern, + MPI_CodeErrIo, + MPI_CodeErrKeyval, + MPI_CodeErrLocktype, + MPI_CodeErrName, + MPI_CodeErrNoMem, + MPI_CodeErrNoSpace, + MPI_CodeErrNoSuchFile, + MPI_CodeErrNotSame, + MPI_CodeErrOp, + MPI_CodeErrOther, + MPI_CodeErrPending, + MPI_CodeErrPort, + MPI_CodeErrProcAborted, + MPI_CodeErrQuota, + MPI_CodeErrRank, + MPI_CodeErrReadOnly, + MPI_CodeErrRequest, + MPI_CodeErrRmaAttach, + MPI_CodeErrRmaConflict, + MPI_CodeErrRmaFlavor, + MPI_CodeErrRmaRange, + MPI_CodeErrRmaShared, + MPI_CodeErrRmaSync, + MPI_CodeErrRoot, + MPI_CodeErrService, + MPI_CodeErrSession, + MPI_CodeErrSize, + MPI_CodeErrSpawn, + MPI_CodeErrTag, + MPI_CodeErrTopology, + MPI_CodeErrTruncate, + MPI_CodeErrType, + MPI_CodeErrUnknown, + MPI_CodeErrUnsupportedDatarep, + MPI_CodeErrUnsupportedOperation, + MPI_CodeErrValueTooLarge, + MPI_CodeErrWin, + MPI_CodeErrLastcode + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mpi"; +} + +def MPI_ErrorClassAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // MLIR_DIALECT_MPI_IR_MPI_TD diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td new file mode 100644 index 00000000000000..768f376e24da4c --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -0,0 +1,189 @@ +//===- MPIops.td - Message Passing Interface Ops -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MPI_MLIR_IR_MPIOPS_TD +#define MPI_MLIR_IR_MPIOPS_TD + +include "mlir/Dialect/MPI/IR/MPI.td" +include "mlir/Dialect/MPI/IR/MPITypes.td" + +class MPI_Op traits = []> + : Op; + +//===----------------------------------------------------------------------===// +// InitOp +//===----------------------------------------------------------------------===// + +def MPI_InitOp : MPI_Op<"init", []> { + let summary = + "Initialize the MPI library, equivalent to `MPI_Init(NULL, NULL)`"; + let description = [{ + This operation must preceed most MPI calls (except for very few exceptions, + please consult with the MPI specification on these). + + Passing &argc, &argv is not supported currently. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = (outs Optional:$retval); + + let assemblyFormat = "attr-dict (`:` type($retval)^)?"; +} + +//===----------------------------------------------------------------------===// +// CommRankOp +//===----------------------------------------------------------------------===// + +def MPI_CommRankOp : MPI_Op<"comm_rank", []> { + let summary = "Get the current rank, equivalent to " + "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; + let description = [{ + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = ( + outs Optional : $retval, + I32 : $rank + ); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// +// SendOp +//===----------------------------------------------------------------------===// + +def MPI_SendOp : MPI_Op<"send", []> { + let summary = + "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + let description = [{ + MPI_Send performs a blocking send of `size` elements of type `dtype` to rank + `dest`. The `tag` value and communicator enables the library to determine + the matching of multiple sends and receives between the same ranks. + + Communicators other than `MPI_COMM_WORLD` are not supprted for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + "type($ref) `,` type($tag) `,` type($rank)" + "(`->` type($retval)^)?"; +} + +//===----------------------------------------------------------------------===// +// RecvOp +//===----------------------------------------------------------------------===// + +def MPI_RecvOp : MPI_Op<"recv", []> { + let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, " + "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + let description = [{ + MPI_Recv performs a blocking receive of `size` elements of type `dtype` + from rank `dest`. The `tag` value and communicator enables the library to + determine the matching of multiple sends and receives between the same + ranks. + + Communicators other than `MPI_COMM_WORLD` are not supprted for now. + The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object + is not yet ported to MLIR. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + "type($ref) `,` type($tag) `,` type($rank)" + "(`->` type($retval)^)?"; +} + + +//===----------------------------------------------------------------------===// +// FinalizeOp +//===----------------------------------------------------------------------===// + +def MPI_FinalizeOp : MPI_Op<"finalize", []> { + let summary = "Finalize the MPI library, equivalent to `MPI_Finalize()`"; + let description = [{ + This function cleans up the MPI state. Afterwards, no MPI methods may + be invoked (excpet for MPI_Get_version, MPI_Initialized, and MPI_Finalized). + Notably, MPI_Init cannot be called again in the same program. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = (outs Optional:$retval); + + let assemblyFormat = "attr-dict (`:` type($retval)^)?"; +} + + +//===----------------------------------------------------------------------===// +// RetvalCheckOp +//===----------------------------------------------------------------------===// + +def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> { + let summary = "Check an MPI return value against an error class"; + let description = [{ + This operation compares MPI status codes to known error class + constants such as `MPI_SUCCESS`, or `MPI_ERR_COMM`. + }]; + + let arguments = ( + ins MPI_Retval:$val, + MPI_ErrorClassAttr:$errclass + ); + + let results = ( + outs I1:$res + ); + + let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)"; +} + + + +//===----------------------------------------------------------------------===// +// RetvalCheckOp +//===----------------------------------------------------------------------===// + +def MPI_ErrorClassOp : MPI_Op<"error_class", []> { + let summary = "Get the error class from an error code, equivalent to " + "the `MPI_Error_class` function"; + let description = [{ + `MPI_Error_class` maps return values from MPI calls to a set of well-known + MPI error classes. + }]; + + let arguments = ( + ins MPI_Retval:$val + ); + + let results = ( + outs MPI_Retval:$errclass + ); + + let assemblyFormat = "$val attr-dict `:` type($val)"; +} + +#endif // MPI_MLIR_IR_MPIOPS_TD diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td new file mode 100644 index 00000000000000..87eefa719d45c0 --- /dev/null +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -0,0 +1,43 @@ +//===- MPITypes.td - Message Passing Interface types -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Message Passing Interface dialect types. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MPI_IR_MPITYPES_TD +#define MLIR_DIALECT_MPI_IR_MPITYPES_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/MPI/IR/MPI.td" + +//===----------------------------------------------------------------------===// +// MPI Types +//===----------------------------------------------------------------------===// + +class MPI_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// mpi::RetvalType +//===----------------------------------------------------------------------===// + +def MPI_Retval : MPI_Type<"Retval", "retval"> { + let summary = "MPI function call return value"; + let description = [{ + This type represents a return value from an MPI function call. + This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code. + + This return value can be compared agains the known MPI error classes + represented by `#mpi.errclass` using the `mpi.retval_check` operation. + }]; +} + +#endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 3c145540356bd7..e508d51205f347 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -49,6 +49,7 @@ #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" @@ -123,6 +124,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { memref::MemRefDialect, mesh::MeshDialect, ml_program::MLProgramDialect, + mpi::MPIDialect, nvgpu::NVGPUDialect, NVVM::NVVMDialect, omp::OpenMPDialect, diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 68776a695cac4d..c72107939cf42b 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(Math) add_subdirectory(MemRef) add_subdirectory(Mesh) add_subdirectory(MLProgram) +add_subdirectory(MPI) add_subdirectory(NVGPU) add_subdirectory(OpenACC) add_subdirectory(OpenACCMPCommon) diff --git a/mlir/lib/Dialect/MPI/CMakeLists.txt b/mlir/lib/Dialect/MPI/CMakeLists.txt new file mode 100644 index 00000000000000..f33061b2d87cff --- /dev/null +++ b/mlir/lib/Dialect/MPI/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/MPI/IR/CMakeLists.txt b/mlir/lib/Dialect/MPI/IR/CMakeLists.txt new file mode 100644 index 00000000000000..29d0b2379da747 --- /dev/null +++ b/mlir/lib/Dialect/MPI/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRMPIDialect + MPIOps.cpp + MPI.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MPI + + DEPENDS + MLIRMPIIncGen + MLIRMPIOpsIncGen + MLIRMPITypesIncGen + MLIRMPIAttrsIncGen + + LINK_LIBS PUBLIC + MLIRDialect + MLIRIR + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/MPI/IR/MPI.cpp b/mlir/lib/Dialect/MPI/IR/MPI.cpp new file mode 100644 index 00000000000000..d64850d5b706c6 --- /dev/null +++ b/mlir/lib/Dialect/MPI/IR/MPI.cpp @@ -0,0 +1,53 @@ +//===- MPI.cpp - MPI dialect implementation -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::mpi; + +//===----------------------------------------------------------------------===// +/// Tablegen Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MPI/IR/MPI.cpp.inc" + +#include "mlir/Dialect/MPI/IR/MPIDialect.cpp.inc" + +void MPIDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc" + >(); + + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/MPI/IR/MPITypesGen.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/MPI/IR/MPIAttrDefs.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd dialect, type, and op definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/MPI/IR/MPITypesGen.cpp.inc" + +#include "mlir/Dialect/MPI/IR/MPIEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/MPI/IR/MPIAttrDefs.cpp.inc" diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp new file mode 100644 index 00000000000000..ddd77b8f586ee0 --- /dev/null +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -0,0 +1,21 @@ +//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" + +using namespace mlir; +using namespace mlir::mpi; + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc" diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir new file mode 100644 index 00000000000000..8f2421a73396c2 --- /dev/null +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +func.func @mpi_test(%ref : memref<100xf32>) -> () { + // Note: the !mpi.retval result is optional on all operations except mpi.error_class + + // CHECK: %0 = mpi.init : !mpi.retval + %err = mpi.init : !mpi.retval + + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 + + // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 + + // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval + %rval = mpi.finalize : !mpi.retval + + // CHECK-NEXT: %4 = mpi.retval_check %retval = : i1 + %res = mpi.retval_check %retval = : i1 + + // CHECK-NEXT: %5 = mpi.error_class %0 : !mpi.retval + %errclass = mpi.error_class %err : !mpi.retval + + // CHECK-NEXT: return + func.return +}