Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 6d295aa

Browse files
author
Ivan Butygin
authored
[MLIR] Numpy broadcasting (#193)
* fixes to broadcasting * fix * work on broadcast * some work on broadcast * work on broadcast * broadcasting * broadcast fix * PostFusionOptPass
1 parent 22f3d43 commit 6d295aa

File tree

3 files changed

+194
-17
lines changed

3 files changed

+194
-17
lines changed

mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,8 +802,8 @@ void LowerLinalgPass::runOnOperation()
802802
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
803803
}
804804

805-
struct PostLinalgOptPass :
806-
public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
805+
struct PostFusionOptPass :
806+
public mlir::PassWrapper<PostFusionOptPass, mlir::OperationPass<mlir::ModuleOp>>
807807
{
808808
virtual void getDependentDialects(
809809
mlir::DialectRegistry &registry) const override
@@ -817,6 +817,26 @@ struct PostLinalgOptPass :
817817
void runOnOperation() override;
818818
};
819819

820+
void PostFusionOptPass::runOnOperation()
821+
{
822+
mlir::OwningRewritePatternList patterns;
823+
824+
auto& context = getContext();
825+
for (auto *op : context.getRegisteredOperations())
826+
{
827+
op->getCanonicalizationPatterns(patterns, &context);
828+
}
829+
830+
patterns.insert<
831+
// LoopInvariantCodeMotion, TODO
832+
plier::CSERewrite<mlir::FuncOp>
833+
>(&context);
834+
835+
plier::populate_index_propagate_patterns(context, patterns);
836+
837+
mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
838+
}
839+
820840
struct LoopInvariantCodeMotion : public mlir::OpRewritePattern<mlir::scf::ForOp>
821841
{
822842
using mlir::OpRewritePattern<mlir::scf::ForOp>::OpRewritePattern;
@@ -839,6 +859,21 @@ struct LoopInvariantCodeMotion : public mlir::OpRewritePattern<mlir::scf::ForOp>
839859
}
840860
};
841861

862+
// Final optimization pass over the whole module, run after lowering out of
// the linalg dialect (see PostLinalgOptPass::runOnOperation, defined below).
struct PostLinalgOptPass :
    public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
{
    // Declare every dialect whose ops the rewrite patterns may create, so the
    // pass manager loads them before this pass runs.
    virtual void getDependentDialects(
        mlir::DialectRegistry &registry) const override
    {
        registry.insert<mlir::StandardOpsDialect>();
        registry.insert<mlir::linalg::LinalgDialect>();
        registry.insert<mlir::scf::SCFDialect>();
        registry.insert<mlir::AffineDialect>();
    }

    void runOnOperation() override;
};
876+
842877
void PostLinalgOptPass::runOnOperation()
843878
{
844879
mlir::OwningRewritePatternList patterns;
@@ -871,6 +906,7 @@ void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
871906
void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
872907
{
873908
pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
909+
pm.addPass(std::make_unique<PostFusionOptPass>());
874910

875911
pm.addPass(mlir::createTensorConstantBufferizePass());
876912
pm.addNestedPass<mlir::FuncOp>(mlir::createSCFBufferizePass());

mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp

Lines changed: 137 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <mlir/Dialect/StandardOps/IR/Ops.h>
99
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
1010
#include <mlir/Dialect/Tensor/IR/Tensor.h>
11+
#include <mlir/Dialect/SCF/SCF.h>
1112
#include <mlir/Parser.h>
1213
#include <mlir/IR/BuiltinAttributes.h>
1314

@@ -206,8 +207,8 @@ struct PyLinalgResolver::Context
206207

207208
namespace
208209
{
209-
py::list get_args(py::handle inspect, py::handle func, llvm::function_ref<py::object(mlir::Value)> create_var,
210-
mlir::ValueRange args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs)
210+
py::object get_args(py::handle inspect, py::handle func, llvm::function_ref<py::object(mlir::Value)> create_var,
211+
mlir::ValueRange args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs)
211212
{
212213
auto sig_func = inspect.attr("signature");
213214
auto sig = sig_func(func);
@@ -258,7 +259,11 @@ py::list get_args(py::handle inspect, py::handle func, llvm::function_ref<py::ob
258259
return py::none();
259260
}
260261
}
261-
return ret;
262+
if (!args.empty())
263+
{
264+
return py::none();
265+
}
266+
return std::move(ret);
262267
}
263268

264269
PyBuilderContext& get_py_context(py::capsule& ctx)
@@ -409,6 +414,97 @@ mlir::Type broadcast_type(mlir::Type type1, mlir::Type type2)
409414
llvm_unreachable("Unable to broadcast type");
410415
}
411416

417+
// Emits IR computing the broadcasted extent of two dimensions: if the first
// extent is 1 the second one wins, otherwise the first extent is kept
// (numpy-style broadcasting; inputs are assumed broadcast-compatible).
// Both values must already be of index type.
mlir::Value broadcast_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val1, mlir::Value val2)
{
    assert(val1.getType().isa<mlir::IndexType>());
    assert(val2.getType().isa<mlir::IndexType>());
    auto const_one = builder.create<mlir::ConstantIndexOp>(loc, 1);
    auto first_is_unit = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, val1, const_one);
    return builder.create<mlir::SelectOp>(loc, first_is_unit, val2, val1);
}
425+
426+
// Emits IR that broadcasts dimension `dim` of tensor `src` up to the runtime
// extent target_shape[dim]. The decision is made at runtime via an scf.if:
// when dim's extent is 1 and differs from the target, a linalg.generic copies
// the single slice across the whole target extent; otherwise the value is
// simply cast to the (dynamic-dim) target type. Returns the scf.if result,
// whose type always has -1 (dynamic) in position `dim`.
mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
{
    auto context = builder.getContext();
    auto src_type = src.getType().cast<mlir::ShapedType>();
    auto num_dims = static_cast<unsigned>(src_type.getRank());
    auto shape = llvm::to_vector<8>(src_type.getShape());
    // Both scf.if branches must yield the same type, so `dim` is made dynamic.
    shape[dim] = -1;
    mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
    auto dim_val = builder.create<mlir::DimOp>(loc, src, dim);
    auto one = builder.create<mlir::ConstantIndexOp>(loc, 1);
    // Expand only when the source extent is 1 AND the target extent differs.
    mlir::Value cond = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, one, dim_val);
    mlir::Value cond2 = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val);
    cond = builder.create<mlir::AndOp>(loc, cond, cond2);
    // Runtime shape of the expanded result: target extent at `dim`, the
    // source's own extents everywhere else.
    llvm::SmallVector<mlir::Value, 8> new_shape(num_dims);
    for (unsigned i = 0 ; i < num_dims; ++i)
    {
        if (i == dim)
        {
            new_shape[i] = target_shape[i];
        }
        else
        {
            new_shape[i] = builder.create<mlir::DimOp>(loc, src, i);
        }
    }
    auto true_body = [&](mlir::OpBuilder &builder, mlir::Location loc)
    {
        assert(dim < shape.size());
        // Cast the source to a type with static extent 1 at `dim` so the
        // constant-0 affine expression below is valid for the input map.
        shape[dim] = 1;
        mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
        auto casted = builder.create<mlir::tensor::CastOp>(loc, casted_type, src).getResult();
        auto init = builder.create<mlir::linalg::InitTensorOp>(loc, new_shape, src_type.getElementType()).getResult();
        // Input indexing map reads index 0 along `dim` (replicating the single
        // slice) and the identity index on every other dimension.
        llvm::SmallVector<mlir::AffineExpr, 8> exprs(num_dims);
        for (unsigned i = 0; i < num_dims; ++i)
        {
            if (i == dim)
            {
                exprs[i] = mlir::getAffineConstantExpr(0, context);
            }
            else
            {
                exprs[i] = mlir::getAffineDimExpr(i, context);
            }
        }
        const mlir::AffineMap maps[] = {
            mlir::AffineMap::get(num_dims, 0, exprs, context),
            mlir::AffineMap::getMultiDimIdentityMap(num_dims, context),
        };
        llvm::SmallVector<mlir::StringRef, 8> iterators(num_dims, "parallel");

        // Plain copy: yield the input element into the output slot.
        auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
        {
            assert(values.size() == 2);
            builder.create<mlir::linalg::YieldOp>(loc, values[0]);
        };

        auto expanded = builder.create<mlir::linalg::GenericOp>(loc, target_type, casted, init, maps, iterators, body);
        auto res = builder.create<mlir::tensor::CastOp>(loc, target_type, expanded.getResult(0));
        builder.create<mlir::scf::YieldOp>(loc, res.getResult());
    };
    auto false_body = [&](mlir::OpBuilder &builder, mlir::Location loc)
    {
        // No expansion needed: just cast to the common (dynamic-dim) type.
        auto res = builder.create<mlir::tensor::CastOp>(loc, target_type, src);
        builder.create<mlir::scf::YieldOp>(loc, res.getResult());
    };
    return builder.create<mlir::scf::IfOp>(loc, target_type, cond, true_body, false_body).getResult(0);
}
493+
494+
// Emits IR broadcasting every dimension of `val` (a ranked tensor of rank
// `num_dims`) toward `target_shape`. Per numpy rules the source rank is
// aligned with the TRAILING dimensions of the target, so any leading excess
// of target_shape is dropped first; each remaining dimension is then
// expanded individually via expand_dim.
mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val, unsigned num_dims, mlir::ValueRange target_shape)
{
    assert(num_dims <= target_shape.size());
    auto const excess = target_shape.size() - num_dims;
    if (excess != 0)
    {
        target_shape = target_shape.drop_front(excess);
    }
    for (unsigned d = 0; d < num_dims; ++d)
    {
        val = expand_dim(builder, loc, val, d, target_shape);
    }
    return val;
}
507+
412508
py::object broadcast_impl(py::capsule context, py::tuple args)
413509
{
414510
if (1 == args.size())
@@ -467,14 +563,22 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
467563
py::none();
468564
}
469565
res_type = broadcast_type(res_type, shape_and_type->second);
470-
if (shape_and_type->first.size() > shape_vals.size())
566+
auto new_shape_vals = shape_and_type->first;
567+
for (auto it : llvm::zip(llvm::reverse(shape_vals), llvm::reverse(new_shape_vals)))
471568
{
472-
shape_vals = shape_and_type->first; // TODO
569+
auto& old_val = std::get<0>(it);
570+
auto new_val = std::get<1>(it);
571+
old_val = broadcast_dim(builder, loc, old_val, new_val);
572+
}
573+
if (new_shape_vals.size() > shape_vals.size())
574+
{
575+
auto front = llvm::makeArrayRef(new_shape_vals).drop_back(shape_vals.size());
576+
assert(!front.empty());
577+
shape_vals.insert(shape_vals.begin(), front.begin(), front.end());
473578
}
474579
}
475580

476-
llvm::SmallVector<int64_t, 8> shape(static_cast<size_t>(shape_vals.size()), -1);
477-
py::tuple ret(args.size());
581+
py::tuple ret(mlir_args.size());
478582
if (shape_vals.empty())
479583
{
480584
for (auto it : llvm::enumerate(mlir_args))
@@ -489,24 +593,31 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
489593
return std::move(ret);
490594
}
491595

596+
llvm::SmallVector<int64_t, 8> shape(static_cast<size_t>(shape_vals.size()), -1);
492597
auto tensor_type = mlir::RankedTensorType::get(shape, res_type);
493598
for (auto it : llvm::enumerate(mlir_args))
494599
{
495600
mlir::Value val = it.value();
496-
auto type = val.getType();
497-
if (type != tensor_type)
601+
if (auto src_type = val.getType().dyn_cast<mlir::ShapedType>())
602+
{
603+
assert(src_type.hasRank());
604+
val = expand_dims(builder, loc, val, static_cast<unsigned>(src_type.getRank()), shape_vals);
605+
}
606+
if (val.getType() != tensor_type)
498607
{
608+
auto type = val.getType();
499609
if (auto src_type = type.dyn_cast<mlir::ShapedType>())
500610
{
501611
assert(src_type.hasRank());
502-
auto num_dims = static_cast<unsigned>(src_type.getRank());
612+
auto src_num_dims = static_cast<unsigned>(src_type.getRank());
613+
auto num_dims = static_cast<unsigned>(tensor_type.getRank());
503614
auto init = builder.create<mlir::linalg::InitTensorOp>(loc, shape_vals, tensor_type.getElementType()).getResult();
504-
llvm::SmallVector<llvm::StringRef, 8> iterators(num_dims, "parallel");
505-
auto map = mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext());
506615
mlir::AffineMap maps[] = {
507-
map,
508-
map,
616+
mlir::AffineMap::getMinorIdentityMap(num_dims, src_num_dims, builder.getContext()),
617+
// mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()).getMajorSubMap(src_num_dims),
618+
mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()),
509619
};
620+
llvm::SmallVector<llvm::StringRef, 8> iterators(num_dims, "parallel");
510621
auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
511622
{
512623
assert(values.size() == 2);
@@ -559,9 +670,15 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
559670
{
560671
auto index_type = builder.getIndexType();
561672
llvm::SmallVector<mlir::Value, 8> shape_val(count);
673+
llvm::SmallVector<int64_t> static_shape(count, -1);
562674
for (size_t i = 0; i < count; ++i)
563675
{
564-
shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, shape[py::int_(i)]), index_type);
676+
auto elem = shape[py::int_(i)];
677+
if (py::isinstance<py::int_>(elem))
678+
{
679+
static_shape[i] = elem.cast<int64_t>();
680+
}
681+
shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, elem), index_type);
565682
}
566683

567684
if (init_val.is_none())
@@ -579,6 +696,11 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
579696
auto type = mlir::RankedTensorType::get(shape, elem_type);
580697
init = builder.create<mlir::tensor::GenerateOp>(loc, type, shape_val, body);
581698
}
699+
if (llvm::any_of(static_shape, [](auto val){ return val >= 0;}))
700+
{
701+
auto new_type = mlir::RankedTensorType::get(static_shape, elem_type);
702+
init = builder.create<mlir::tensor::CastOp>(loc, new_type, init);
703+
}
582704
}
583705
return ctx.context.create_var(context, init);
584706
}

numba/mlir/tests/test_numpy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,5 +276,24 @@ def test_reshape(self):
276276
for a in [arr1]:
277277
assert_equal(py_func(a), jit_func(a))
278278

279+
def test_broadcast(self):
280+
def py_func(a, b):
281+
return np.add(a, b)
282+
283+
jit_func = njit(py_func)
284+
285+
test_data = [
286+
1,
287+
np.array([1]),
288+
np.array([[1]]),
289+
np.array([[1,2],[3,4]]),
290+
np.array([5,6]),
291+
np.array([[5],[6]]),
292+
np.array([[5,6]]),
293+
]
294+
295+
for a, b in itertools.product(test_data, test_data):
296+
assert_equal(py_func(a,b), jit_func(a,b))
297+
279298
if __name__ == '__main__':
280299
unittest.main()

0 commit comments

Comments
 (0)