Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#18 from Superjomn/refine-tile
Browse files Browse the repository at this point in the history
add tile transform
  • Loading branch information
Superjomn authored Feb 6, 2020
2 parents 7f64f9c + 3ec8a18 commit bedf2e4
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 70 deletions.
2 changes: 1 addition & 1 deletion cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cc_library(common
graph_utils.cc
DEPS boost utils)

cc_test(test_pod_value SRCS pod_value_test.cc DEPS common)
cc_test(test_pod_value SRCS pod_value_test.cc DEPS common ir)
cc_test(test_shared SRCS shared_test.cc DEPS common)
cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS common)
1 change: 1 addition & 0 deletions cinn/common/object.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once
#include "cinn/common/shared.h"

namespace cinn {
Expand Down
31 changes: 8 additions & 23 deletions cinn/common/pod_value.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
#include "cinn/common/pod_value.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/node.h"

namespace cinn {

namespace ir {

class Expr;
class Var;

} // namespace ir

namespace common {

//! Implement the type_code for all the supported types.
Expand Down Expand Up @@ -91,16 +98,6 @@ void PODValue::Set<char const *>(char const *v) {
type_code_ = TypeCode<char *>();
value_.v_str = const_cast<char *>(v);
}
template <>
void PODValue::Set<ir::Var>(ir::Var v) {
type_code_ = TypeCode<ir::Var>();
value_.v_handle = v.ptr();
}
template <>
void PODValue::Set<ir::Expr>(ir::Expr v) {
type_code_ = TypeCode<ir::Expr>();
value_.v_handle = v.ptr();
}
// @}

//! Implement ToValue.
Expand Down Expand Up @@ -141,18 +138,6 @@ Value ToValue<char const *>(char const *v) {
val.v_str = const_cast<char *>(v);
return val;
}
template <>
Value ToValue<ir::Expr>(ir::Expr v) {
Value val;
val.v_handle = v.ptr();
return val;
}
template <>
Value ToValue<ir::Var>(ir::Var v) {
Value val;
val.v_handle = v.ptr();
return val;
}
// @}

} // namespace common
Expand Down
28 changes: 28 additions & 0 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,32 @@ Expr Call::Make(Type type,
return Expr(node);
}
} // namespace ir

namespace common {

template <>
void PODValue::Set<ir::Var>(ir::Var v) {
type_code_ = TypeCode<ir::Var>();
value_.v_handle = v.ptr();
}
template <>
void PODValue::Set<ir::Expr>(ir::Expr v) {
type_code_ = TypeCode<ir::Expr>();
value_.v_handle = v.ptr();
}
template <>
Value ToValue<ir::Expr>(ir::Expr v) {
Value val;
val.v_handle = v.ptr();
return val;
}
template <>
Value ToValue<ir::Var>(ir::Var v) {
Value val;
val.v_handle = v.ptr();
return val;
}

} // namespace common

} // namespace cinn
24 changes: 1 addition & 23 deletions cinn/ir/operation.cc
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
#include "cinn/ir/operation.h"

namespace cinn {
namespace ir {

Operation ExternOp::Make(std::string name,
std::string tag,
std::map<std::string, IrNodeRef> attrs,
std::vector<Tensor> inputs,
std::vector<Buffer> input_placeholders,
std::vector<Buffer> output_placeholders,
Stmt body) {
auto n = common::make_shared<ExternOp>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
CHECK_EQ(inputs.size(), input_placeholders.size());

n->inputs = std::move(inputs);
n->input_placeholders = std::move(input_placeholders);
n->output_placeholders = std::move(output_placeholders);
n->body = std::move(body);
return Operation(n);
}

} // namespace ir
namespace ir {} // namespace ir
} // namespace cinn
17 changes: 2 additions & 15 deletions cinn/ir/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,7 @@ struct PlaceholderOp : public _Operation_ {
//! The data type of the input.
Type dtype;

static Operation Make(std::string name, std::vector<Expr> shape, Type dtype) {
auto n = common::make_shared<PlaceholderOp>();
n->name = name;
n->shape = shape;
n->dtype = dtype;
return Operation(n);
}
static Operation Make(std::string name, std::vector<Expr> shape, Type dtype);
};

/**
Expand All @@ -68,14 +62,7 @@ struct ComputeOp : public _Operation_ {
std::string tag,
std::map<std::string, IrNodeRef> attrs,
std::vector<Var> axis,
std::vector<Expr> body) {
auto n = common::make_shared<ComputeOp>();
n->name = std::move(name);
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->body = std::move(body);
return Operation(n);
}
std::vector<Expr> body);
};

} // namespace ir
Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ Tensor::Tensor(const std::vector<Expr> &shape, Type type) : IrNodeRef(common::ma
}

const _Tensor_ *Tensor::operator->() const {
auto *p = As<_Tensor_>();
auto *p = Object::As<_Tensor_>();
CHECK(p) << "type not match";
return p;
}
_Tensor_ *Tensor::operator->() {
auto *p = As<_Tensor_>();
auto *p = Object::As<_Tensor_>();
CHECK(p) << "type not match";
return p;
}
Expand Down
8 changes: 4 additions & 4 deletions cinn/poly/element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ std::tuple<Iterator, Iterator, Iterator, Iterator> Element::Tile(const Iterator
const Iterator &level1,
int factor0,
int factor1) {
Iterator level0_inner(InnerName(level0));
Iterator level0_outer(OuterName(level0));
Iterator level1_inner(InnerName(level1));
Iterator level1_outer(OuterName(level1));
Iterator level0_inner, level0_outer;
Iterator level1_inner, level1_outer;

std::tie(level0_outer, level0_inner) = Split(level0, factor0);
std::tie(level1_outer, level1_inner) = Split(level1, factor1);
return std::make_tuple(level0_outer, level0_inner, level1_outer, level1_inner);
}

Expand Down
30 changes: 28 additions & 2 deletions cinn/poly/element_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,38 @@
namespace cinn {
namespace poly {

TEST(Element, basic) {
TEST(Element, split) {
isl::ctx ctx(isl_ctx_alloc());
isl::set domain(ctx, "{ S[i,j]: 0<=i,j<=100 }");

Element ele(domain);
ele.Split(Iterator("i"), 4);
Iterator outer, inner;
std::tie(outer, inner) = ele.Split(Iterator("i"), 4);
LOG(INFO) << ele.schedule();
EXPECT_EQ(utils::GetStreamCnt(ele.schedule()),
"{ S[i, j] -> S[i_outer, i_inner, j' = j] : (-i + i_inner) mod 4 = 0 and -3 + i <= 4i_outer <= i and 0 <= "
"i_inner <= 3 }");

EXPECT_EQ(outer.id, "i_outer");
EXPECT_EQ(inner.id, "i_inner");
}

TEST(Element, tile) {
isl::ctx ctx(isl_ctx_alloc());
isl::set domain(ctx, "{ S[i,j,k]: 0<=i,j,k<=100 }");
Element ele(domain);

Iterator outer0, inner0, outer1, inner1;
std::tie(outer0, inner0, outer1, inner1) = ele.Tile(Iterator("i"), Iterator("j"), 4, 6);
LOG(INFO) << ele.schedule();
EXPECT_EQ(outer0.id, "i_outer");
EXPECT_EQ(outer1.id, "j_outer");
EXPECT_EQ(inner0.id, "i_inner");
EXPECT_EQ(outer1.id, "j_outer");
EXPECT_EQ(
utils::GetStreamCnt(ele.schedule()),
"{ S[i, j, k] -> S[i_outer, i_inner, j_outer, j_inner, k' = k] : (-i + i_inner) mod 4 = 0 and (-j + j_inner) mod "
"6 = 0 and -3 + i <= 4i_outer <= i and 0 <= i_inner <= 3 and -5 + j <= 6j_outer <= j and 0 <= j_inner <= 5 }");
}

} // namespace poly
Expand Down
3 changes: 3 additions & 0 deletions cinn/poly/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@ namespace poly {
struct Iterator {
std::string id;

Iterator() = default;
explicit Iterator(const std::string& id) : id(id) {}
explicit Iterator(const Iterator& x) : id(x.id) {}
explicit Iterator(Iterator&& x) : id(std::move(x.id)) {}

Iterator& operator=(const Iterator& other) { id = other.id; }

friend std::ostream& operator<<(std::ostream& os, const Iterator& x);
};

Expand Down
4 changes: 4 additions & 0 deletions cinn/utils/functional.cc
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
#include "cinn/utils/functional.h"

namespace cinn {
namespace utils {} // namespace utils
} // namespace cinn

0 comments on commit bedf2e4

Please sign in to comment.