Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#27 from Superjomn/fea/make-compute-mo…
Browse files Browse the repository at this point in the history
…re-friendly

make compute more use-friendly
  • Loading branch information
Superjomn authored Feb 12, 2020
2 parents d32672d + bc8f63c commit 60b917a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 48 deletions.
11 changes: 7 additions & 4 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ namespace poly {
class Element;
} // namespace poly

namespace lang {
class Tensor;
} // namespace lang

namespace ir {

using common::Object;
Expand Down Expand Up @@ -541,4 +537,11 @@ struct Builder {
};

} // namespace ir

// Expose the following to cinn namespace for easier usage.
// @{
using ir::Expr;
using ir::Var;
// @}

} // namespace cinn
79 changes: 44 additions & 35 deletions cinn/lang/compute.cc
Original file line number Diff line number Diff line change
@@ -1,59 +1,68 @@
#include "cinn/lang/compute.h"

#include "cinn/common/common.h"
#include "cinn/poly/dim.h"
#include "cinn/poly/domain.h"
#include "cinn/poly/element.h"

namespace cinn {
namespace lang {

using ir::Expr;

template <>
ir::Tensor Compute<compute_handle_1_t>(const std::vector<int> &dims, compute_handle_1_t handle) {
CHECK_EQ(dims.size(), 1);
Var i(common::axis_name(0), Int(32));
auto expr = handle(i);

std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var)> fn) {
return Compute(dims, [fn](const std::vector<Var> &axis) -> Expr {
CHECK_EQ(axis.size(), 1);
return fn(axis[0]);
});
}

ir::Tensor tensor(shape, {i}, expr.type(), expr);
return std::move(tensor);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var)> fn) {
return Compute(dims, [fn](const std::vector<Var> &axis) -> Expr {
CHECK_EQ(axis.size(), 2);
return fn(axis[0], axis[1]);
});
}

template <>
ir::Tensor Compute<compute_handle_2_t>(const std::vector<int> &dims, compute_handle_2_t handle) {
CHECK_EQ(dims.size(), 2);
poly::Dim dim("i", 0, dims[0] - 1);
Var i(common::axis_name(0), Int(32));
Var j(common::axis_name(1), Int(32));
auto expr = handle(i, j);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var)> fn) {
return Compute(dims, [fn](const std::vector<Var> &axis) -> Expr {
CHECK_EQ(axis.size(), 3);
return fn(axis[0], axis[1], axis[2]);
});
}

std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var, Var)> fn) {
return Compute(dims, [fn](const std::vector<Var> &axis) -> Expr {
CHECK_EQ(axis.size(), 4);
return fn(axis[0], axis[1], axis[2], axis[3]);
});
}

ir::Tensor tensor(shape, {i, j}, expr.type(), expr);
CHECK(tensor.get());
return std::move(tensor);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(Var, Var, Var, Var, Var)> fn) {
return Compute(dims, [fn](const std::vector<Var> &axis) -> Expr {
CHECK_EQ(axis.size(), 5);
return fn(axis[0], axis[1], axis[2], axis[3], axis[4]);
});
}

template <>
ir::Tensor Compute<compute_handle_3_t>(const std::vector<int> &dims, compute_handle_3_t handle) {
CHECK_EQ(dims.size(), 3);
Var i(common::axis_name(0), Int(32));
Var j(common::axis_name(1), Int(32));
Var k(common::axis_name(2), Int(32));
auto expr = handle(i, j, k);
ir::Tensor Compute(const std::vector<int> &dims, std::function<Expr(const std::vector<Var> &)> fn) {
auto axis = detail::GenDefaultAxis(dims.size());
Expr expr = fn(axis);

std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);

ir::Tensor tensor(shape, {i, j}, expr.type(), expr);
return std::move(tensor);
ir::Tensor tensor(shape, axis, expr.type(), expr);
return tensor;
}

} // namespace lang
namespace detail {
std::vector<Var> GenDefaultAxis(int naxis) {
std::vector<Var> axis;
for (int i = 0; i < naxis; i++) {
axis.emplace_back(common::axis_name(i));
}
return axis;
}
} // namespace detail

namespace ir {} // namespace ir
} // namespace lang
} // namespace cinn
20 changes: 13 additions & 7 deletions cinn/lang/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
namespace cinn {
namespace lang {

using ir::Var;
using compute_handle_1_t = std::function<ir::Expr(Var i)>;
using compute_handle_2_t = std::function<ir::Expr(Var i0, Var i1)>;
using compute_handle_3_t = std::function<ir::Expr(Var i0, Var i1, Var i2)>;
using compute_handle_4_t = std::function<ir::Expr(Var i0, Var i1, Var i2, Var i3)>;
//! Compute methods for one to five Vars as arguments.
// @{
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(Var)> fn);
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(Var, Var)> fn);
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(Var, Var, Var)> fn);
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(Var, Var, Var, Var)> fn);
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(Var, Var, Var, Var, Var)> fn);
ir::Tensor Compute(const std::vector<int>& dims, std::function<Expr(const std::vector<Var>&)> fn);
// @}

template <typename Fn>
ir::Tensor Compute(const std::vector<int>& dims, Fn handle);
namespace detail {
//! Generate `naxis` axis using the global names (i,j,k...).
std::vector<Var> GenDefaultAxis(int naxis);
} // namespace detail

} // namespace lang
} // namespace cinn
4 changes: 2 additions & 2 deletions cinn/lang/compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ TEST(Compute, basic) {

Placeholder<float> x("x", {M, N});

ir::Tensor y = Compute<compute_handle_2_t>({10, 20}, [&](Var i, Var j) -> Expr {
ir::Tensor y = Compute({10, 20}, [&](Var i, Var j) -> Expr {
LOG(INFO) << "type: " << x(i, j).type();
return x(i, j) + 1.f;
});
LOG(INFO) << "compute: " << y->operaion->As<ir::ComputeOp>()->body[0];
LOG(INFO) << "y.element: " << y->poly_element->domain();

ir::Tensor z = Compute<compute_handle_2_t>({10, 20}, [&](Var i, Var j) -> Expr { return y(i, j) * 2.f; });
ir::Tensor z = Compute({10, 20}, [&](Var i, Var j) -> Expr { return y(i, j) * 2.f; });

LOG(INFO) << "z: " << z->operaion->As<ir::ComputeOp>()->body[0];
LOG(INFO) << "z.element: " << z->poly_element->domain();
Expand Down

0 comments on commit 60b917a

Please sign in to comment.