Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 4c147f6

Browse files
author
Ivan Butygin
authored
[MLIR] Some linalg fixes and proper increfs/decrefs for array arguments (#199)
* some broadcasting opt * rework broadcast * add pass * move force inline to opt pass * Proper increfs/decrefs for input arrays
1 parent bf50f4f commit 4c147f6

File tree

6 files changed

+127
-37
lines changed

6 files changed

+127
-37
lines changed

mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -610,12 +610,6 @@ struct ApplyFastmathFlags : public mlir::OpRewritePattern<Op>
610610
};
611611

612612
// Copypaste from StandardToLLVM
613-
mlir::Value createIndexAttrConstant(mlir::OpBuilder &builder, mlir::Location loc,
614-
mlir::Type resultType, int64_t value) {
615-
return builder.create<mlir::LLVM::ConstantOp>(
616-
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
617-
}
618-
619613
struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern {
620614
using ConvertToLLVMPattern::createIndexConstant;
621615
using ConvertToLLVMPattern::getIndexType;
@@ -625,19 +619,6 @@ struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern {
625619
: ConvertToLLVMPattern(opName, &converter.getContext(), converter, /*benefit*/99) {}
626620

627621
protected:
628-
// Returns 'input' aligned up to 'alignment'. Computes
629-
// bumped = input + alignement - 1
630-
// aligned = bumped - bumped % alignment
631-
// static mlir::Value createAligned(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc,
632-
// mlir::Value input, mlir::Value alignment) {
633-
// using namespace mlir;
634-
// Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
635-
// Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
636-
// Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
637-
// Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
638-
// return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
639-
// }
640-
641622
// Creates a call to an allocation function with params and casts the
642623
// resulting void pointer to ptrType.
643624
mlir::Value createAllocCall(mlir::Location loc, mlir::StringRef name, mlir::Type ptrType,
@@ -1227,6 +1208,51 @@ struct PostLLVMLowering :
12271208
}
12281209
};
12291210

1211+
struct LowerRetain : public mlir::OpConversionPattern<plier::RetainOp>
1212+
{
1213+
using mlir::OpConversionPattern<plier::RetainOp>::OpConversionPattern;
1214+
1215+
mlir::LogicalResult
1216+
matchAndRewrite(plier::RetainOp op, llvm::ArrayRef<mlir::Value> operands,
1217+
mlir::ConversionPatternRewriter &rewriter) const override {
1218+
assert(operands.size() == 1);
1219+
auto arg = operands[0];
1220+
if (!arg.getType().isa<mlir::LLVM::LLVMStructType>())
1221+
{
1222+
return mlir::failure();
1223+
}
1224+
1225+
auto llvmVoidPointerType =
1226+
mlir::LLVM::LLVMPointerType::get(rewriter.getIntegerType(8));
1227+
auto incref_func = [&]()
1228+
{
1229+
auto mod = op->getParentOfType<mlir::ModuleOp>();
1230+
assert(mod);
1231+
auto func = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>("NRT_incref");
1232+
if (!func)
1233+
{
1234+
mlir::OpBuilder::InsertionGuard guard(rewriter);
1235+
rewriter.setInsertionPointToStart(mod.getBody());
1236+
auto llvmVoidType = mlir::LLVM::LLVMVoidType::get(rewriter.getContext());
1237+
func = rewriter.create<mlir::LLVM::LLVMFuncOp>(
1238+
rewriter.getUnknownLoc(), "NRT_incref",
1239+
mlir::LLVM::LLVMFunctionType::get(llvmVoidType, llvmVoidPointerType));
1240+
}
1241+
return func;
1242+
}();
1243+
1244+
auto loc = op.getLoc();
1245+
auto index = rewriter.getI64ArrayAttr(0);
1246+
auto elemType = arg.getType().cast<mlir::LLVM::LLVMStructType>().getBody()[0];
1247+
mlir::Value ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemType, arg, index);
1248+
ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmVoidPointerType, ptr);
1249+
rewriter.create<mlir::LLVM::CallOp>(loc, incref_func, ptr);
1250+
rewriter.replaceOp(op, arg);
1251+
1252+
return mlir::success();
1253+
}
1254+
};
1255+
12301256
struct LowerCasts : public mlir::OpConversionPattern<plier::CastOp>
12311257
{
12321258
using mlir::OpConversionPattern<plier::CastOp>::OpConversionPattern;
@@ -1277,7 +1303,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper<LLVMLoweringPass, mlir::Opera
12771303

12781304
OwningRewritePatternList patterns;
12791305
populateStdToLLVMConversionPatterns(typeConverter, patterns);
1280-
patterns.insert<LowerCasts>(typeConverter, &getContext());
1306+
patterns.insert<LowerCasts, LowerRetain>(typeConverter, &getContext());
12811307
patterns.insert<AllocOpLowering, DeallocOpLowering>(typeConverter);
12821308

12831309
LLVMConversionTarget target(getContext());

mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,7 @@ void PlierToLinalgPass::runOnOperation()
759759
patterns.insert<
760760
GetitemOpLowering<plier::GetItemOp>,
761761
GetitemOpLowering<plier::StaticGetItemOp>,
762-
SetitemOpLowering<plier::SetItemOp>,
763-
plier::ForceInline
762+
SetitemOpLowering<plier::SetItemOp>
764763
>(&getContext());
765764

766765
// range/prange lowering need dead branch pruning to properly
@@ -802,8 +801,8 @@ void LowerLinalgPass::runOnOperation()
802801
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
803802
}
804803

805-
struct PostFusionOptPass :
806-
public mlir::PassWrapper<PostFusionOptPass, mlir::OperationPass<mlir::ModuleOp>>
804+
struct CommonOptPass :
805+
public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
807806
{
808807
virtual void getDependentDialects(
809808
mlir::DialectRegistry &registry) const override
@@ -817,7 +816,7 @@ struct PostFusionOptPass :
817816
void runOnOperation() override;
818817
};
819818

820-
void PostFusionOptPass::runOnOperation()
819+
void CommonOptPass::runOnOperation()
821820
{
822821
mlir::OwningRewritePatternList patterns;
823822

@@ -829,6 +828,7 @@ void PostFusionOptPass::runOnOperation()
829828

830829
patterns.insert<
831830
// LoopInvariantCodeMotion, TODO
831+
plier::ForceInline,
832832
plier::CSERewrite<mlir::FuncOp>
833833
>(&context);
834834

@@ -859,6 +859,41 @@ struct LoopInvariantCodeMotion : public mlir::OpRewritePattern<mlir::scf::ForOp>
859859
}
860860
};
861861

862+
struct RetainArgsPass :
863+
public mlir::PassWrapper<RetainArgsPass, mlir::FunctionPass>
864+
{
865+
virtual void getDependentDialects(
866+
mlir::DialectRegistry &registry) const override
867+
{
868+
registry.insert<plier::PlierDialect>();
869+
}
870+
871+
void runOnFunction() override;
872+
};
873+
874+
void RetainArgsPass::runOnFunction()
875+
{
876+
auto func = getFunction();
877+
if (func.isPrivate() || func.isDeclaration() || func.body().empty())
878+
{
879+
return;
880+
}
881+
882+
mlir::OpBuilder builder(&getContext());
883+
auto loc = builder.getUnknownLoc();
884+
auto block = &func.body().front();
885+
builder.setInsertionPointToStart(block);
886+
for (auto arg : block->getArguments())
887+
{
888+
if (arg.getType().isa<mlir::MemRefType>())
889+
{
890+
auto retained = builder.create<plier::RetainOp>(loc, arg);
891+
llvm::SmallPtrSet<mlir::Operation*, 1> except({retained});
892+
arg.replaceAllUsesExcept(retained, except);
893+
}
894+
}
895+
}
896+
862897
struct PostLinalgOptPass :
863898
public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
864899
{
@@ -900,13 +935,14 @@ void PostLinalgOptPass::runOnOperation()
900935
void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
901936
{
902937
pm.addPass(std::make_unique<PlierToLinalgPass>());
938+
pm.addPass(std::make_unique<CommonOptPass>());
903939
pm.addPass(mlir::createSymbolDCEPass());
904940
}
905941

906942
void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
907943
{
908944
pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
909-
pm.addPass(std::make_unique<PostFusionOptPass>());
945+
pm.addPass(std::make_unique<CommonOptPass>());
910946

911947
pm.addPass(mlir::createTensorConstantBufferizePass());
912948
pm.addNestedPass<mlir::FuncOp>(mlir::createSCFBufferizePass());
@@ -920,6 +956,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
920956
pm.addNestedPass<mlir::FuncOp>(mlir::createBufferLoopHoistingPass());
921957
pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass());
922958

959+
pm.addNestedPass<mlir::FuncOp>(std::make_unique<RetainArgsPass>());
923960
pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());
924961

925962
pm.addPass(std::make_unique<LowerLinalgPass>());

mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ void container_iterate(py::handle obj, F&& func)
104104

105105
llvm::Optional<py::object> make_py_literal(mlir::Value val)
106106
{
107+
assert(val);
107108
if (auto int_val = plier::getConstVal<mlir::IntegerAttr>(val))
108109
{
109110
return py::int_(int_val.getInt());
@@ -144,6 +145,7 @@ struct PyLinalgResolver::Context
144145

145146
py::object create_var(py::capsule context, mlir::Value value)
146147
{
148+
assert(value);
147149
if (auto literal = make_py_literal(value))
148150
{
149151
return *literal;
@@ -423,19 +425,17 @@ mlir::Value broadcast_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Va
423425
return builder.create<mlir::SelectOp>(loc, cond, val2, val1);
424426
}
425427

426-
mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
428+
mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value initial, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
427429
{
428430
auto context = builder.getContext();
429431
auto src_type = src.getType().cast<mlir::ShapedType>();
430432
auto num_dims = static_cast<unsigned>(src_type.getRank());
431433
auto shape = llvm::to_vector<8>(src_type.getShape());
432434
shape[dim] = -1;
433435
mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
434-
auto dim_val = builder.create<mlir::DimOp>(loc, src, dim);
436+
auto dim_val = builder.create<mlir::DimOp>(loc, initial, dim);
435437
auto one = builder.create<mlir::ConstantIndexOp>(loc, 1);
436438
mlir::Value cond = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, one, dim_val);
437-
mlir::Value cond2 = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val);
438-
cond = builder.create<mlir::AndOp>(loc, cond, cond2);
439439
llvm::SmallVector<mlir::Value> new_shape(num_dims);
440440
for (unsigned i = 0 ; i < num_dims; ++i)
441441
{
@@ -498,11 +498,12 @@ mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Valu
498498
{
499499
target_shape = target_shape.drop_front(target_shape.size() - num_dims);
500500
}
501+
mlir::Value current = val;
501502
for (unsigned i = 0; i < num_dims; ++i)
502503
{
503-
val = expand_dim(builder, loc, val, i, target_shape);
504+
current = expand_dim(builder, loc, val, current, i, target_shape);
504505
}
505-
return val;
506+
return current;
506507
}
507508

508509
py::object broadcast_impl(py::capsule context, py::tuple args)
@@ -632,11 +633,20 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
632633
{
633634
val = builder.create<plier::CastOp>(loc, tensor_type.getElementType(), val);
634635
}
635-
auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/)
636+
val = builder.create<mlir::tensor::FromElementsOp>(loc, val);
637+
auto num_dims = static_cast<unsigned>(tensor_type.getRank());
638+
auto init = builder.create<mlir::linalg::InitTensorOp>(loc, shape_vals, tensor_type.getElementType()).getResult();
639+
mlir::AffineMap maps[] = {
640+
mlir::AffineMap::get(num_dims, 0, mlir::getAffineConstantExpr(0, builder.getContext())),
641+
mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()),
642+
};
643+
llvm::SmallVector<llvm::StringRef> iterators(num_dims, "parallel");
644+
auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
636645
{
637-
builder.create<mlir::tensor::YieldOp>(loc, val);
646+
assert(values.size() == 2);
647+
builder.create<mlir::linalg::YieldOp>(loc, values[0]);
638648
};
639-
val = builder.create<mlir::tensor::GenerateOp>(loc, tensor_type, shape_vals, body);
649+
val = builder.create<mlir::linalg::GenericOp>(loc, tensor_type, val, init, maps, iterators, body).getResult(0);
640650
}
641651
}
642652
ret[it.index()] = ctx.context.create_var(context, val);
@@ -688,12 +698,12 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
688698
else
689699
{
690700
auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type);
701+
llvm::SmallVector<int64_t> shape(count, -1);
702+
auto type = mlir::RankedTensorType::get(shape, elem_type);
691703
auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/)
692704
{
693705
builder.create<mlir::tensor::YieldOp>(loc, val);
694706
};
695-
llvm::SmallVector<int64_t> shape(count, -1);
696-
auto type = mlir::RankedTensorType::get(shape, elem_type);
697707
init = builder.create<mlir::tensor::GenerateOp>(loc, type, shape_val, body);
698708
}
699709
if (llvm::any_of(static_shape, [](auto val){ return val >= 0;}))

mlir-compiler/plier/include/plier/PlierOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,16 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> {
214214
];
215215
}
216216

217+
def RetainOp : Plier_Op<"retain"> {
218+
let arguments = (ins AnyMemRef:$value);
219+
220+
let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$memref);
221+
222+
let builders = [
223+
OpBuilderDAG<(ins "::mlir::Value":$value)>
224+
];
225+
}
226+
217227
def ParallelOp : Plier_Op<"parallel",
218228
[AttrSizedOperandSegments,
219229
DeclareOpInterfaceMethods<LoopLikeOpInterface>,

mlir-compiler/plier/include/plier/dialect.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <mlir/IR/BuiltinTypes.h>
34
#include <mlir/IR/Dialect.h>
45
#include <mlir/IR/Types.h>
56
#include <mlir/IR/OpDefinition.h>
@@ -14,6 +15,7 @@ using Value = ::mlir::Value;
1415
using Region = ::mlir::Region;
1516
using LogicalResult = ::mlir::LogicalResult;
1617
using Operation = ::mlir::Operation;
18+
namespace MemoryEffects = ::mlir::MemoryEffects;
1719

1820
template<typename T>
1921
using ArrayRef = ::mlir::ArrayRef<T>;

mlir-compiler/plier/src/dialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ void GetattrOp::getCanonicalizationPatterns(
336336
results.insert<GetattrGlobalRewrite>(context);
337337
}
338338

339+
void RetainOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
340+
mlir::Value value) {
341+
RetainOp::build(builder, state, value.getType(), value);
342+
}
343+
339344
mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef<mlir::Operation *> ops)
340345
{
341346
for (mlir::Operation *op : ops)

0 commit comments

Comments
 (0)