Skip to content

[MLIR][Wasm] Introduce the WasmSSA MLIR dialect #149233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ add_subdirectory(Transform)
add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(WasmSSA)
add_subdirectory(X86Vector)
add_subdirectory(XeGPU)
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/WasmSSA/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/WasmSSA/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
set(LLVM_TARGET_DEFINITIONS WasmSSATypes.td)
mlir_tablegen(WasmSSATypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(WasmSSATypeConstraints.cpp.inc -gen-type-constraint-defs)

set (LLVM_TARGET_DEFINITIONS WasmSSAInterfaces.td)
mlir_tablegen(WasmSSAInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(WasmSSAInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRWasmSSAInterfacesIncGen)

set(LLVM_TARGET_DEFINITIONS WasmSSAOps.td)

add_mlir_dialect(WasmSSAOps wasmssa)
add_mlir_doc(WasmSSAOps WasmSSAOps Dialects/ -gen-dialect-doc)
55 changes: 55 additions & 0 deletions mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSA.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- WasmSSA.h - WasmSSA dialect ------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_
#define MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"

//===----------------------------------------------------------------------===//
// WebAssemblyDialect
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsDialect.h.inc"

//===----------------------------------------------------------------------===//
// WebAssembly Dialect Types
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/WasmSSA/IR/WasmSSAOpsTypes.h.inc"

//===----------------------------------------------------------------------===//
// WebAssembly Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h"

//===----------------------------------------------------------------------===//
// WebAssembly Dialect Operations
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

//===----------------------------------------------------------------------===//
// WebAssembly Constraints
//===----------------------------------------------------------------------===//

namespace mlir {
namespace wasmssa {
#include "mlir/Dialect/WasmSSA/IR/WasmSSATypeConstraints.h.inc"
}
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.h.inc"

#endif // MLIR_DIALECT_WasmSSA_IR_WasmSSA_H_
25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSABase.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- WasmSSABase.td - Base defs for wasmssa dialect -*- tablegen -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef WasmSSA_BASE
#define WasmSSA_BASE

include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"

def WasmSSA_Dialect : Dialect {
let name = "wasmssa";
let cppNamespace = "::mlir::wasmssa";
let description = [{
The `wasmssa` dialect is intended to represent WebAssembly
modules in SSA form for easier manipulation.
}];
let useDefaultTypePrinterParser = true;
}

#endif //WasmSSA_BASE
60 changes: 60 additions & 0 deletions mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===- WasmSSAInterfaces.h - WasmSSA Interfaces ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines op interfaces for the WasmSSA dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_
#define MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir::wasmssa {
namespace detail {
/// Verify that `op` conforms to the ConstantExpressionInterface.
/// `op` must be initialized with valid constant expressions.
LogicalResult verifyConstantExpressionInterface(Operation *op);

/// Verify that `op` conforms to the LabelBranchingOpInterface
/// Checks that the branching is targetting something within its scope.
LogicalResult verifyLabelBranchingOpInterface(Operation *op);

/// Verify that `op` conforms to LabelLevelInterfaceIsTerminator
template <typename OpType>
LogicalResult verifyLabelLevelInterfaceIsTerminator() {
static_assert(OpType::template hasTrait<::mlir::OpTrait::IsTerminator>(),
"LabelLevelOp should be terminator ops");
return success();
}

/// Verify that `op` conforms to the LabelLevelInterface
/// `op`'s target should defined at the same scope level.
LogicalResult verifyLabelLevelInterface(Operation *op);
} // namespace detail

/// Operations implementing this trait are considered as valid
/// constant expressions in any context (In contrast of ConstantExprCheckOpInterface
/// which are sometimes considered valid constant expressions.
template <class OperationType>
struct AlwaysValidConstantExprOpTrait
: public OpTrait::TraitBase<OperationType, AlwaysValidConstantExprOpTrait> {};

/// Trait used to verify operations that need a constant expression initializer.
template<typename OpType>
struct ConstantExpressionInitializerOpTrait : public OpTrait::TraitBase<OpType, ConstantExpressionInitializerOpTrait>{
static LogicalResult verifyTrait(Operation* op) {
return detail::verifyConstantExpressionInterface(op);
}
};

} // namespace mlir::wasmssa
#include "mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.h.inc"

#endif // MLIR_DIALECT_WasmSSA_IR_WasmSSAINTERFACES_H_
162 changes: 162 additions & 0 deletions mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//===-- WasmSSAInterfaces.td - WasmSSA Interfaces -*- tablegen -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines interfaces for the WasmSSA dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#ifndef WasmSSA_INTERFACES
#define WasmSSA_INTERFACES

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"

def LabelLevelOpInterface : OpInterface<"LabelLevelOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that defines one level of nesting for wasm branching.

These ops defines Wasm control flow nesting levels (Wasm Labels) that Wasm
branching operations can target.
The branching operations specify a number of nesting level they want to exit,
and are redirected to the target of the corresponding nesting LabelLevelOp.

As multiple level can be escaped at once, the level defining ops need themselves
to be `Terminator` ops.
}];
let methods = [
InterfaceMethod<
/*desc=*/ "Returns the target block address",
/*returnType=*/ "::mlir::Block*",
/*methodName=*/ "getLabelTarget",
/*args=*/ (ins)
>
];

let verify = [{
return success(
succeeded(verifyLabelLevelInterfaceIsTerminator<ConcreteOp>()) &&
succeeded(verifyLabelLevelInterface($_op)));
}];
}

def LabelBranchingOpInterface : OpInterface<"LabelBranchingOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Wasm operation that targets a label for a jump.
}];
let methods = [
InterfaceMethod<
/*desc=*/ "Returns the number of context to break from",
/*returnType=*/ "size_t",
/*methodName=*/ "getExitLevel",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "Returns the destination of this operation",
/*returnType=*/ "LabelLevelOpInterface",
/*methodName=*/ "getTargetOp",
/*args=*/ (ins),
/*methodBody=*/ [{
return *LabelBranchingOpInterface::getTargetOpFromBlock($_op.getOperation()->getBlock(), $_op.getExitLevel());
}]
>,
InterfaceMethod<
/*desc=*/ "Return the target control flow ops that defined the label of this operation",
/*returnType=*/ "::mlir::Block*",
/*methodName=*/ "getTarget",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = mlir::cast<LabelBranchingOpInterface>(this->getOperation());
return op.getTargetOp().getLabelTarget();
}]
>
];

let extraClassDeclaration = [{
static ::llvm::FailureOr<LabelLevelOpInterface> getTargetOpFromBlock(::mlir::Block *block, uint32_t level);
}];
let verify = [{return verifyLabelBranchingOpInterface($_op);}];
}

def ImportOpInterface : OpInterface<"ImportOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Operation that imports a symbol from an external wasm module;
}];

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the module name for the import",
/*returnType=*/ "::llvm::StringRef",
/*methodName=*/ "getModuleName",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "Returns the import name for the import",
/*returnType=*/ "::llvm::StringRef",
/*methodName=*/ "getImportName",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "Returns the wasm index based symbol of the op",
/*returnType=*/ "::mlir::StringAttr",
/*methodName=*/ "getSymbolName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = mlir::cast<ConcreteOp>(this->getOperation());
return op.getSymNameAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns the qualified name of the import",
/*returnType=*/ "std::string",
/*methodName=*/ "getQualifiedImportName",
/*args=*/ (ins),
/*methodBody=*/ [{
return ($_op.getModuleName() + ::llvm::Twine{"::"} + $_op.getImportName()).str();
}]
>,
];
}

def ConstantExpressionInitializerOpTrait : NativeOpTrait<"ConstantExpressionInitializerOpTrait"> {
let cppNamespace = "::mlir::wasmssa";
}

def ConstantExprCheckOpInterface :
OpInterface<"ConstantExprCheckOpInterface"> {
let cppNamespace = "::mlir::wasmssa";
let description = [{
Interface for allowing to verify that operations can be used in a Wasm Constant Expression.
}];

let methods = [
InterfaceMethod<
/*desc=*/ [{
Returns success if the current operation is valid in a constant expression context.
A diagnostic is emitted on error.
}],
/*returnType=*/ "::mlir::LogicalResult",
/*methodName=*/ "CheckValidInConstantExpr",
/*args=*/ (ins)
>
];
}

def AlwaysValidInConstantExprOpTrait : NativeOpTrait<"AlwaysValidConstantExprOpTrait", [], [{
::mlir::LogicalResult CheckValidInConstantExpr() {
return success();
}
}]> {
let cppNamespace = "::mlir::wasmssa";
}

#endif // WEBASSEMBLY_INTERFACES
Loading
Loading