Skip to content

Commit 0402ca1

Browse files
committed
Create the DialectBuilder infrastructure (#41)
This PR introduces a `DialectBuilder` as a infrastructure to generate several MLIR dialects operations. Currently we have a `MemRefBuilder` and a `LLVMBuilder` both derived from an abstract `DialectBuilder` class. The 2 classes can be used to generate operations in the `MemRef` MLIR dialect and the `LLVM` MLIR dialect respectively. Note: not all possible operations for either builder have been implemented in this PR. Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
1 parent f444338 commit 0402ca1

File tree

5 files changed

+247
-25
lines changed

5 files changed

+247
-25
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===- DialectBuilder.h - Dialect Builder ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Declares builders to construct several dialect operations.
10+
// Builders are derived from an abstract 'DialectBuilder' base class.
11+
// Several builders can be added to this file.
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_CONVERSION_SYCLTOLLVM_DIALECTBUILDER_H
15+
#define MLIR_CONVERSION_SYCLTOLLVM_DIALECTBUILDER_H
16+
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/Value.h"
21+
22+
namespace mlir {
23+
namespace sycl {
24+
25+
/// \class DialectBuilder
26+
/// Abstract base class for all dialect builders.
27+
class DialectBuilder {
28+
public:
29+
DialectBuilder(OpBuilder &b, Location loc) : builder(b), loc(loc) {}
30+
virtual ~DialectBuilder() = 0;
31+
32+
/// Inject a function declaration into the given module.
33+
FlatSymbolRefAttr getOrInsertFuncDecl(StringRef funcName, Type resultType,
34+
ArrayRef<Type> argsTypes,
35+
ModuleOp &module,
36+
bool isVarArg = false) const;
37+
38+
protected:
39+
/// Create a operation of type 'OP' given the argument list \p args, for the
40+
/// operation.
41+
template <typename OP, typename... Types> OP create(Types... args) const;
42+
43+
BoolAttr getBoolAttr(bool val) const;
44+
IntegerAttr getIntegerAttr(Type type, int64_t val) const;
45+
IntegerAttr getIntegerAttr(Type type, APInt val) const;
46+
FloatAttr getF16FloatAttr(float val) const;
47+
FloatAttr getF32FloatAttr(float val) const;
48+
FloatAttr getF64FloatAttr(double val) const;
49+
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t>) const;
50+
51+
private:
52+
OpBuilder &builder;
53+
Location loc;
54+
};
55+
56+
/// \class MemRefBuilder
57+
/// Construct operations in the MemRef dialect.
58+
class MemRefBuilder : public DialectBuilder {
59+
public:
60+
MemRefBuilder(OpBuilder &b, Location loc) : DialectBuilder(b, loc) {}
61+
62+
memref::AllocOp genAlloc(MemRefType type) const;
63+
memref::AllocaOp genAlloca(MemRefType type) const;
64+
memref::CastOp genCast(Value input, MemRefType outputType) const;
65+
memref::DeallocOp genDealloc(Value val) const;
66+
};
67+
68+
/// \class LLVMBuilder
69+
/// Construct operations in the LLVM dialect.
70+
class LLVMBuilder : public DialectBuilder {
71+
public:
72+
LLVMBuilder(OpBuilder &b, Location loc) : DialectBuilder(b, loc) {}
73+
74+
LLVM::AllocaOp genAlloca(Type type, Value size, int64_t align) const;
75+
LLVM::BitcastOp genBitcast(Type type, Value val) const;
76+
LLVM::ExtractValueOp genExtractValue(Type type, Value container,
77+
ArrayRef<int64_t> pos) const;
78+
LLVM::CallOp genCall(FlatSymbolRefAttr funcSym, ArrayRef<Type> resTypes,
79+
ArrayRef<Value> operands) const;
80+
LLVM::ConstantOp genConstant(Type type, double val) const;
81+
LLVM::SExtOp genSignExtend(Type type, Value val) const;
82+
};
83+
84+
} // namespace sycl
85+
} // namespace mlir
86+
87+
#endif // MLIR_CONVERSION_SYCLTOLLVM_DIALECTBUILDER_H

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class SYCLFuncRegistry;
2626
/// \class SYCLFuncDescriptor
2727
/// Represents a SYCL function (defined in a registry) that can be called by the
2828
/// compiler.
29-
/// Note: when a new enumerator is added the corresponding SYCLFuncDescriptor
29+
/// Note: when a new enumerator is added, the corresponding SYCLFuncDescriptor
3030
/// needs to be created in SYCLFuncRegistry constructor.
3131
class SYCLFuncDescriptor {
3232
friend class SYCLFuncRegistry;

mlir-sycl/lib/Conversion/SYCLToLLVM/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_conversion_library(MLIRSYCLToLLVM
2+
DialectBuilder.cpp
23
SYCLFuncRegistry.cpp
34
SYCLToLLVM.cpp
45
SYCLToLLVMPass.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//===- DialectBuilder.cpp - Dialect Builder -------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements builders for several dialects operations.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/SYCLToLLVM/DialectBuilder.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "llvm/ADT/TypeSwitch.h"
16+
#include "llvm/Support/Debug.h"
17+
#include <cassert>
18+
19+
using namespace mlir;
20+
using namespace mlir::sycl;
21+
22+
#define DEBUG_TYPE "dialect-builder"
23+
24+
//===----------------------------------------------------------------------===//
25+
// DialectBuilder
26+
//===----------------------------------------------------------------------===//
27+
28+
DialectBuilder::~DialectBuilder() {}
29+
30+
template<typename OP, typename... Types>
31+
OP DialectBuilder::create(Types... args) const {
32+
return builder.create<OP>(loc, args...);
33+
}
34+
35+
FlatSymbolRefAttr DialectBuilder::getOrInsertFuncDecl(StringRef funcName,
36+
Type resultType,
37+
ArrayRef<Type> argsTypes,
38+
ModuleOp &module,
39+
bool isVarArg) const {
40+
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
41+
return SymbolRefAttr::get(builder.getContext(), funcName);
42+
43+
OpBuilder::InsertionGuard guard(builder);
44+
builder.setInsertionPointToStart(module.getBody());
45+
auto funcType = LLVM::LLVMFunctionType::get(resultType, argsTypes, isVarArg);
46+
builder.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
47+
return SymbolRefAttr::get(builder.getContext(), funcName);
48+
}
49+
50+
BoolAttr DialectBuilder::getBoolAttr(bool val) const {
51+
return builder.getBoolAttr(val);
52+
}
53+
54+
IntegerAttr DialectBuilder::getIntegerAttr(Type type, int64_t val) const {
55+
return builder.getIntegerAttr(type, val);
56+
}
57+
58+
IntegerAttr DialectBuilder::getIntegerAttr(Type type, APInt val) const {
59+
return builder.getIntegerAttr(type, val);
60+
}
61+
62+
FloatAttr DialectBuilder::getF16FloatAttr(float val) const {
63+
return builder.getF16FloatAttr(val);
64+
}
65+
66+
FloatAttr DialectBuilder::getF32FloatAttr(float val) const {
67+
return builder.getF32FloatAttr(val);
68+
}
69+
70+
FloatAttr DialectBuilder::getF64FloatAttr(double val) const {
71+
return builder.getF64FloatAttr(val);
72+
}
73+
74+
ArrayAttr DialectBuilder::getI64ArrayAttr(ArrayRef<int64_t> vals) const {
75+
return builder.getI64ArrayAttr(vals);
76+
}
77+
78+
//===----------------------------------------------------------------------===//
79+
// MemRefBuilder
80+
//===----------------------------------------------------------------------===//
81+
82+
memref::AllocOp MemRefBuilder::genAlloc(MemRefType type) const {
83+
return create<memref::AllocOp>(type);
84+
}
85+
86+
memref::AllocaOp MemRefBuilder::genAlloca(MemRefType type) const {
87+
return create<memref::AllocaOp>(type);
88+
}
89+
90+
memref::CastOp MemRefBuilder::genCast(Value input,
91+
MemRefType outputType) const {
92+
return create<memref::CastOp>(outputType, input);
93+
}
94+
95+
memref::DeallocOp MemRefBuilder::genDealloc(Value val) const {
96+
return create<memref::DeallocOp>(val);
97+
}
98+
99+
//===----------------------------------------------------------------------===//
100+
// LLVMBuilder
101+
//===----------------------------------------------------------------------===//
102+
103+
LLVM::AllocaOp LLVMBuilder::genAlloca(Type type, Value size,
104+
int64_t align) const {
105+
return create<LLVM::AllocaOp>(type, size, align);
106+
}
107+
108+
LLVM::BitcastOp LLVMBuilder::genBitcast(Type type, Value val) const {
109+
return create<LLVM::BitcastOp>(type, val);
110+
}
111+
112+
LLVM::ExtractValueOp LLVMBuilder::genExtractValue(Type type, Value container,
113+
ArrayRef<int64_t> position) const {
114+
return create<LLVM::ExtractValueOp>(type, container,
115+
getI64ArrayAttr(position));
116+
}
117+
118+
LLVM::CallOp LLVMBuilder::genCall(FlatSymbolRefAttr funcSym, ArrayRef<Type> resTypes,
119+
ArrayRef<Value> operands) const {
120+
return create<LLVM::CallOp>(resTypes, funcSym, operands);
121+
}
122+
123+
LLVM::ConstantOp LLVMBuilder::genConstant(Type type, double val) const {
124+
return llvm::TypeSwitch<Type, LLVM::ConstantOp>(type)
125+
.Case<IndexType>([&](IndexType type) {
126+
return create<LLVM::ConstantOp>(type,
127+
getIntegerAttr(type, (int64_t)val));
128+
})
129+
.Case<IntegerType>([&](IntegerType type) {
130+
bool isBool = (type.getWidth() == 1);
131+
return (isBool) ? create<LLVM::ConstantOp>(type, getBoolAttr(val != 0))
132+
: create<LLVM::ConstantOp>(
133+
type, getIntegerAttr(type, APInt(type.getWidth(),
134+
(int64_t)val)));
135+
})
136+
.Case<Float16Type>([&](Type) {
137+
return create<LLVM::ConstantOp>(type, getF16FloatAttr(val));
138+
})
139+
.Case<Float32Type>([&](Type) {
140+
return create<LLVM::ConstantOp>(type, getF32FloatAttr(val));
141+
})
142+
.Case<Float64Type>([&](Type) {
143+
return create<LLVM::ConstantOp>(type, getF64FloatAttr(val));
144+
})
145+
.Default([&](Type) {
146+
llvm_unreachable("Missing support for type");
147+
return LLVM::ConstantOp();
148+
});
149+
}
150+
151+
LLVM::SExtOp LLVMBuilder::genSignExtend(Type type, Value val) const {
152+
return create<LLVM::SExtOp>(type, val);
153+
}

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp

+5-24
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,21 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h"
14-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Conversion/SYCLToLLVM/DialectBuilder.h"
1515
#include "llvm/Support/Debug.h"
1616

1717
#define DEBUG_TYPE "sycl-func-registry"
1818

1919
using namespace mlir;
2020
using namespace mlir::sycl;
2121

22-
// TODO: move in LLVMBuilder class when available.
23-
static FlatSymbolRefAttr getOrInsertFuncDecl(ModuleOp module, OpBuilder &b,
24-
StringRef funcName,
25-
Type resultType,
26-
ArrayRef<Type> argsTypes,
27-
bool isVarArg = false) {
28-
if (!module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
29-
OpBuilder::InsertionGuard guard(b);
30-
b.setInsertionPointToStart(module.getBody());
31-
LLVM::LLVMFunctionType funcType =
32-
LLVM::LLVMFunctionType::get(resultType, argsTypes, isVarArg);
33-
b.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
34-
}
35-
return SymbolRefAttr::get(b.getContext(), funcName);
36-
}
37-
3822
//===----------------------------------------------------------------------===//
3923
// SYCLFuncDescriptor
4024
//===----------------------------------------------------------------------===//
4125

4226
void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) {
43-
// TODO: use LLVMBuilder once available.
44-
funcRef = getOrInsertFuncDecl(module, b, name, outputTy, argTys);
27+
LLVMBuilder builder(b, module.getLoc());
28+
builder.getOrInsertFuncDecl(name, outputTy, argTys, module);
4529
}
4630

4731
Value SYCLFuncDescriptor::call(FuncId id, ArrayRef<Value> args,
@@ -52,13 +36,10 @@ Value SYCLFuncDescriptor::call(FuncId id, ArrayRef<Value> args,
5236
if (!funcDesc.outputTy.isa<LLVM::LLVMVoidType>())
5337
funcOutputTys.emplace_back(funcDesc.outputTy);
5438

55-
// TODO: generate the call via LLVMBuilder here
56-
// LLVMBuilder builder(b, loc);
57-
// return builder.call(funcDesc.funcRef, ArrayRef<Type>(funcOutputsTys), args);
39+
LLVMBuilder builder(b, loc);
40+
LLVM::CallOp callOp = builder.genCall(funcDesc.funcRef, funcOutputTys, args);
5841
// TODO: we could check here the arguments against the function signature and
5942
// assert if there is a mismatch.
60-
auto callOp = b.create<LLVM::CallOp>(loc, ArrayRef<Type>(funcOutputTys),
61-
funcDesc.funcRef, args);
6243
assert(callOp.getNumResults() == 1 && "expecting a single result");
6344

6445
return callOp.getResult(0);

0 commit comments

Comments
 (0)