diff --git a/cinn/ir/ir.h b/cinn/ir/ir.h index ac8b36d67ba0a8..a9f1fa269b4fe2 100644 --- a/cinn/ir/ir.h +++ b/cinn/ir/ir.h @@ -19,10 +19,6 @@ namespace poly { class Element; } // namespace poly -namespace lang { -class Tensor; -} // namespace lang - namespace ir { using common::Object; @@ -541,4 +537,11 @@ struct Builder { }; } // namespace ir + +// Expose the following to cinn namespace for easier usage. +// @{ +using ir::Expr; +using ir::Var; +// @} + } // namespace cinn diff --git a/cinn/lang/compute.cc b/cinn/lang/compute.cc index 82085afb227f79..f5c93428277519 100644 --- a/cinn/lang/compute.cc +++ b/cinn/lang/compute.cc @@ -1,5 +1,6 @@ #include "cinn/lang/compute.h" +#include "cinn/common/common.h" #include "cinn/poly/dim.h" #include "cinn/poly/domain.h" #include "cinn/poly/element.h" @@ -7,53 +8,61 @@ namespace cinn { namespace lang { -using ir::Expr; - -template <> -ir::Tensor Compute(const std::vector &dims, compute_handle_1_t handle) { - CHECK_EQ(dims.size(), 1); - Var i(common::axis_name(0), Int(32)); - auto expr = handle(i); - - std::vector shape; - for (int v : dims) shape.emplace_back(v); +ir::Tensor Compute(const std::vector &dims, std::function fn) { + return Compute(dims, [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 1); + return fn(axis[0]); + }); +} - ir::Tensor tensor(shape, {i}, expr.type(), expr); - return std::move(tensor); +ir::Tensor Compute(const std::vector &dims, std::function fn) { + return Compute(dims, [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 2); + return fn(axis[0], axis[1]); + }); } -template <> -ir::Tensor Compute(const std::vector &dims, compute_handle_2_t handle) { - CHECK_EQ(dims.size(), 2); - poly::Dim dim("i", 0, dims[0] - 1); - Var i(common::axis_name(0), Int(32)); - Var j(common::axis_name(1), Int(32)); - auto expr = handle(i, j); +ir::Tensor Compute(const std::vector &dims, std::function fn) { + return Compute(dims, [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 3); + return fn(axis[0], axis[1], axis[2]); + }); +} - std::vector shape; - for (int v : dims) shape.emplace_back(v); +ir::Tensor Compute(const std::vector &dims, std::function fn) { + return Compute(dims, [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 4); + return fn(axis[0], axis[1], axis[2], axis[3]); + }); +} - ir::Tensor tensor(shape, {i, j}, expr.type(), expr); - CHECK(tensor.get()); - return std::move(tensor); +ir::Tensor Compute(const std::vector &dims, std::function fn) { + return Compute(dims, [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 5); + return fn(axis[0], axis[1], axis[2], axis[3], axis[4]); + }); } -template <> -ir::Tensor Compute(const std::vector &dims, compute_handle_3_t handle) { - CHECK_EQ(dims.size(), 3); - Var i(common::axis_name(0), Int(32)); - Var j(common::axis_name(1), Int(32)); - Var k(common::axis_name(2), Int(32)); - auto expr = handle(i, j, k); +ir::Tensor Compute(const std::vector &dims, std::function &)> fn) { + auto axis = detail::GenDefaultAxis(dims.size()); + Expr expr = fn(axis); std::vector shape; for (int v : dims) shape.emplace_back(v); - ir::Tensor tensor(shape, {i, j}, expr.type(), expr); - return std::move(tensor); + ir::Tensor tensor(shape, axis, expr.type(), expr); + return tensor; } -} // namespace lang +namespace detail { +std::vector GenDefaultAxis(int naxis) { + std::vector axis; + for (int i = 0; i < naxis; i++) { + axis.emplace_back(common::axis_name(i)); + } + return axis; +} +} // namespace detail -namespace ir {} // namespace ir +} // namespace lang } // namespace cinn diff --git a/cinn/lang/compute.h b/cinn/lang/compute.h index 04f4b302a7a3a7..b8b2da90ca237d 100644 --- a/cinn/lang/compute.h +++ b/cinn/lang/compute.h @@ -10,14 +10,20 @@ namespace cinn { namespace lang { -using ir::Var; -using compute_handle_1_t = std::function; -using compute_handle_2_t = std::function; -using compute_handle_3_t = std::function; -using compute_handle_4_t = std::function; +//! Compute methods for one to five Vars as arguments. +// @{ +ir::Tensor Compute(const std::vector& dims, std::function fn); +ir::Tensor Compute(const std::vector& dims, std::function fn); +ir::Tensor Compute(const std::vector& dims, std::function fn); +ir::Tensor Compute(const std::vector& dims, std::function fn); +ir::Tensor Compute(const std::vector& dims, std::function fn); +ir::Tensor Compute(const std::vector& dims, std::function&)> fn); +// @} -template -ir::Tensor Compute(const std::vector& dims, Fn handle); +namespace detail { +//! Generate `naxis` axis using the global names (i,j,k...). +std::vector GenDefaultAxis(int naxis); +} // namespace detail } // namespace lang } // namespace cinn diff --git a/cinn/lang/compute_test.cc b/cinn/lang/compute_test.cc index 2ce0e4d057f0b8..ee72367738c49c 100644 --- a/cinn/lang/compute_test.cc +++ b/cinn/lang/compute_test.cc @@ -15,14 +15,14 @@ TEST(Compute, basic) { Placeholder x("x", {M, N}); - ir::Tensor y = Compute({10, 20}, [&](Var i, Var j) -> Expr { + ir::Tensor y = Compute({10, 20}, [&](Var i, Var j) -> Expr { LOG(INFO) << "type: " << x(i, j).type(); return x(i, j) + 1.f; }); LOG(INFO) << "compute: " << y->operaion->As()->body[0]; LOG(INFO) << "y.element: " << y->poly_element->domain(); - ir::Tensor z = Compute({10, 20}, [&](Var i, Var j) -> Expr { return y(i, j) * 2.f; }); + ir::Tensor z = Compute({10, 20}, [&](Var i, Var j) -> Expr { return y(i, j) * 2.f; }); LOG(INFO) << "z: " << z->operaion->As()->body[0]; LOG(INFO) << "z.element: " << z->poly_element->domain();