88#include < mlir/Dialect/StandardOps/IR/Ops.h>
99#include < mlir/Dialect/Linalg/IR/LinalgOps.h>
1010#include < mlir/Dialect/Tensor/IR/Tensor.h>
11+ #include < mlir/Dialect/SCF/SCF.h>
1112#include < mlir/Parser.h>
1213#include < mlir/IR/BuiltinAttributes.h>
1314
@@ -206,8 +207,8 @@ struct PyLinalgResolver::Context
206207
207208namespace
208209{
209- py::list get_args (py::handle inspect, py::handle func, llvm::function_ref<py::object(mlir::Value)> create_var,
210- mlir::ValueRange args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs)
210+ py::object get_args (py::handle inspect, py::handle func, llvm::function_ref<py::object(mlir::Value)> create_var,
211+ mlir::ValueRange args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs)
211212{
212213 auto sig_func = inspect.attr (" signature" );
213214 auto sig = sig_func (func);
@@ -258,7 +259,11 @@ py::list get_args(py::handle inspect, py::handle func, llvm::function_ref<py::ob
258259 return py::none ();
259260 }
260261 }
261- return ret;
262+ if (!args.empty ())
263+ {
264+ return py::none ();
265+ }
266+ return std::move (ret);
262267}
263268
264269PyBuilderContext& get_py_context (py::capsule& ctx)
@@ -409,6 +414,97 @@ mlir::Type broadcast_type(mlir::Type type1, mlir::Type type2)
409414 llvm_unreachable (" Unable to broadcast type" );
410415}
411416
417+ mlir::Value broadcast_dim (mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val1, mlir::Value val2)
418+ {
419+ assert (val1.getType ().isa <mlir::IndexType>());
420+ assert (val2.getType ().isa <mlir::IndexType>());
421+ auto one = builder.create <mlir::ConstantIndexOp>(loc, 1 );
422+ auto cond = builder.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, val1, one);
423+ return builder.create <mlir::SelectOp>(loc, cond, val2, val1);
424+ }
425+
// Expand dimension `dim` of tensor `src` to the extent given by
// target_shape[dim], but only at runtime when that dimension is 1 and the
// target extent differs (numpy-style unit-dim broadcasting). The decision is
// made with an scf.if; both branches yield a tensor of the same ranked type
// with a dynamic extent (-1) in position `dim`.
426+ mlir::Value expand_dim (mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
427+ {
428+ auto context = builder.getContext ();
429+ auto src_type = src.getType ().cast <mlir::ShapedType>();
430+ auto num_dims = static_cast <unsigned >(src_type.getRank ());
431+ auto shape = llvm::to_vector<8 >(src_type.getShape ());
// Result type: same as src but with the expanded dimension made dynamic,
// so both scf.if branches can produce it regardless of the runtime extent.
432+ shape[dim] = -1 ;
433+ mlir::Type target_type = mlir::RankedTensorType::get (shape, src_type.getElementType ());
434+ auto dim_val = builder.create <mlir::DimOp>(loc, src, dim);
435+ auto one = builder.create <mlir::ConstantIndexOp>(loc, 1 );
// Expand only when src's extent is 1 AND the target extent actually differs.
436+ mlir::Value cond = builder.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, one, dim_val);
437+ mlir::Value cond2 = builder.create <mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val);
438+ cond = builder.create <mlir::AndOp>(loc, cond, cond2);
// Runtime shape of the expanded result: target extent at `dim`, src extents
// elsewhere.
439+ llvm::SmallVector<mlir::Value, 8 > new_shape (num_dims);
440+ for (unsigned i = 0 ; i < num_dims; ++i)
441+ {
442+ if (i == dim)
443+ {
444+ new_shape[i] = target_shape[i];
445+ }
446+ else
447+ {
448+ new_shape[i] = builder.create <mlir::DimOp>(loc, src, i);
449+ }
450+ }
// True branch: materialize the broadcast by copying the single slice at
// index 0 of `dim` into every position of a freshly init'ed tensor via a
// linalg.generic whose input map pins `dim` to the constant 0.
451+ auto true_body = [&](mlir::OpBuilder &builder, mlir::Location loc)
452+ {
453+ assert (dim < shape.size ());
// Cast src to a type with a static 1 in position `dim` so the input map's
// constant-0 expression is valid for it.
454+ shape[dim] = 1 ;
455+ mlir::Type casted_type = mlir::RankedTensorType::get (shape, src_type.getElementType ());
456+ auto casted = builder.create <mlir::tensor::CastOp>(loc, casted_type, src).getResult ();
457+ auto init = builder.create <mlir::linalg::InitTensorOp>(loc, new_shape, src_type.getElementType ()).getResult ();
458+ llvm::SmallVector<mlir::AffineExpr, 8 > exprs (num_dims);
459+ for (unsigned i = 0 ; i < num_dims; ++i)
460+ {
461+ if (i == dim)
462+ {
// Always read index 0 along the broadcast dimension.
463+ exprs[i] = mlir::getAffineConstantExpr (0 , context);
464+ }
465+ else
466+ {
467+ exprs[i] = mlir::getAffineDimExpr (i, context);
468+ }
469+ }
// maps[0]: input (dim pinned to 0); maps[1]: identity for the output.
470+ const mlir::AffineMap maps[] = {
471+ mlir::AffineMap::get (num_dims, 0 , exprs, context),
472+ mlir::AffineMap::getMultiDimIdentityMap (num_dims, context),
473+ };
474+ llvm::SmallVector<mlir::StringRef, 8 > iterators (num_dims, " parallel" );
475+
// Element-wise body: forward the (broadcast) input value; values[1] is the
// init tensor's element, which is ignored.
476+ auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
477+ {
478+ assert (values.size () == 2 );
479+ builder.create <mlir::linalg::YieldOp>(loc, values[0 ]);
480+ };
481+
482+ auto expanded = builder.create <mlir::linalg::GenericOp>(loc, target_type, casted, init, maps, iterators, body);
483+ auto res = builder.create <mlir::tensor::CastOp>(loc, target_type, expanded.getResult (0 ));
484+ builder.create <mlir::scf::YieldOp>(loc, res.getResult ());
485+ };
// False branch: no expansion needed — just cast src to the common (dynamic)
// result type.
486+ auto false_body = [&](mlir::OpBuilder &builder, mlir::Location loc)
487+ {
488+ auto res = builder.create <mlir::tensor::CastOp>(loc, target_type, src);
489+ builder.create <mlir::scf::YieldOp>(loc, res.getResult ());
490+ };
491+ return builder.create <mlir::scf::IfOp>(loc, target_type, cond, true_body, false_body).getResult (0 );
492+ }
493+
494+ mlir::Value expand_dims (mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val, unsigned num_dims, mlir::ValueRange target_shape)
495+ {
496+ assert (num_dims <= target_shape.size ());
497+ if (num_dims < target_shape.size ())
498+ {
499+ target_shape = target_shape.drop_front (target_shape.size () - num_dims);
500+ }
501+ for (unsigned i = 0 ; i < num_dims; ++i)
502+ {
503+ val = expand_dim (builder, loc, val, i, target_shape);
504+ }
505+ return val;
506+ }
507+
412508py::object broadcast_impl (py::capsule context, py::tuple args)
413509{
414510 if (1 == args.size ())
@@ -467,14 +563,22 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
467563 py::none ();
468564 }
469565 res_type = broadcast_type (res_type, shape_and_type->second );
470- if (shape_and_type->first .size () > shape_vals.size ())
566+ auto new_shape_vals = shape_and_type->first ;
567+ for (auto it : llvm::zip (llvm::reverse (shape_vals), llvm::reverse (new_shape_vals)))
471568 {
472- shape_vals = shape_and_type->first ; // TODO
569+ auto & old_val = std::get<0 >(it);
570+ auto new_val = std::get<1 >(it);
571+ old_val = broadcast_dim (builder, loc, old_val, new_val);
572+ }
573+ if (new_shape_vals.size () > shape_vals.size ())
574+ {
575+ auto front = llvm::makeArrayRef (new_shape_vals).drop_back (shape_vals.size ());
576+ assert (!front.empty ());
577+ shape_vals.insert (shape_vals.begin (), front.begin (), front.end ());
473578 }
474579 }
475580
476- llvm::SmallVector<int64_t , 8 > shape (static_cast <size_t >(shape_vals.size ()), -1 );
477- py::tuple ret (args.size ());
581+ py::tuple ret (mlir_args.size ());
478582 if (shape_vals.empty ())
479583 {
480584 for (auto it : llvm::enumerate (mlir_args))
@@ -489,24 +593,31 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
489593 return std::move (ret);
490594 }
491595
596+ llvm::SmallVector<int64_t , 8 > shape (static_cast <size_t >(shape_vals.size ()), -1 );
492597 auto tensor_type = mlir::RankedTensorType::get (shape, res_type);
493598 for (auto it : llvm::enumerate (mlir_args))
494599 {
495600 mlir::Value val = it.value ();
496- auto type = val.getType ();
497- if (type != tensor_type)
601+ if (auto src_type = val.getType ().dyn_cast <mlir::ShapedType>())
602+ {
603+ assert (src_type.hasRank ());
604+ val = expand_dims (builder, loc, val, static_cast <unsigned >(src_type.getRank ()), shape_vals);
605+ }
606+ if (val.getType () != tensor_type)
498607 {
608+ auto type = val.getType ();
499609 if (auto src_type = type.dyn_cast <mlir::ShapedType>())
500610 {
501611 assert (src_type.hasRank ());
502- auto num_dims = static_cast <unsigned >(src_type.getRank ());
612+ auto src_num_dims = static_cast <unsigned >(src_type.getRank ());
613+ auto num_dims = static_cast <unsigned >(tensor_type.getRank ());
503614 auto init = builder.create <mlir::linalg::InitTensorOp>(loc, shape_vals, tensor_type.getElementType ()).getResult ();
504- llvm::SmallVector<llvm::StringRef, 8 > iterators (num_dims, " parallel" );
505- auto map = mlir::AffineMap::getMultiDimIdentityMap (num_dims, builder.getContext ());
506615 mlir::AffineMap maps[] = {
507- map,
508- map,
616+ mlir::AffineMap::getMinorIdentityMap (num_dims, src_num_dims, builder.getContext ()),
617+ // mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()).getMajorSubMap(src_num_dims),
618+ mlir::AffineMap::getMultiDimIdentityMap (num_dims, builder.getContext ()),
509619 };
620+ llvm::SmallVector<llvm::StringRef, 8 > iterators (num_dims, " parallel" );
510621 auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
511622 {
512623 assert (values.size () == 2 );
@@ -559,9 +670,15 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
559670 {
560671 auto index_type = builder.getIndexType ();
561672 llvm::SmallVector<mlir::Value, 8 > shape_val (count);
673+ llvm::SmallVector<int64_t > static_shape (count, -1 );
562674 for (size_t i = 0 ; i < count; ++i)
563675 {
564- shape_val[i] = do_cast (loc, builder, ctx.context .unwrap_val (loc, builder, shape[py::int_ (i)]), index_type);
676+ auto elem = shape[py::int_ (i)];
677+ if (py::isinstance<py::int_>(elem))
678+ {
679+ static_shape[i] = elem.cast <int64_t >();
680+ }
681+ shape_val[i] = do_cast (loc, builder, ctx.context .unwrap_val (loc, builder, elem), index_type);
565682 }
566683
567684 if (init_val.is_none ())
@@ -579,6 +696,11 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
579696 auto type = mlir::RankedTensorType::get (shape, elem_type);
580697 init = builder.create <mlir::tensor::GenerateOp>(loc, type, shape_val, body);
581698 }
699+ if (llvm::any_of (static_shape, [](auto val){ return val >= 0 ;}))
700+ {
701+ auto new_type = mlir::RankedTensorType::get (static_shape, elem_type);
702+ init = builder.create <mlir::tensor::CastOp>(loc, new_type, init);
703+ }
582704 }
583705 return ctx.context .create_var (context, init);
584706}
0 commit comments