Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCHEDULE] Improve bound inference, support reduce codegen. #30

Merged
merged 1 commit into from
Feb 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;

using Halide::Internal::make_const;
using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;


inline Type TVMType2Type(TVMType t) {
Expand Down Expand Up @@ -126,25 +129,25 @@ using Halide::abs;
using Halide::select;

/*!
* \brief sum of of source expression over rdom
* \brief sum of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr sum(Expr source, Array<IterVar> rdom);
Expr sum(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr max(Expr source, Array<IterVar> rdom);
Expr max(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over rdom
* \brief max of of source expression over axis
* \param source The source expression.
* \param rdom List of iteration variables that will be used for reduction.
* \param axis List of iteration variables that will be used for reduction.
*/
Expr min(Expr source, Array<IterVar> rdom);
Expr min(Expr source, Array<IterVar> axis);


// print functions for expr
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domains */
Array<IterVar> rdom;
/*! \brief The reduction axis */
Array<IterVar> axis;

/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
Expand All @@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
v->Visit("dtype", &type);
v->Visit("op", &op);
v->Visit("source", &source);
v->Visit("rdom", &rdom);
v->Visit("axis", &axis);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Reduce";
Expand Down
21 changes: 10 additions & 11 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file ir_pass.h
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
* When the pass functions in this file are for Stmt,
* we can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
Expand Down Expand Up @@ -37,15 +37,6 @@ inline Stmt Simplify(Stmt a) {
return Halide::Internal::simplify(a);
}

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand All @@ -69,6 +60,14 @@ bool HasSideEffect(const Expr& e);
*/
Stmt ConvertSSA(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);

/*!
* \brief inline all calls of f in stmt.
*
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
Expand All @@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class Stage : public NodeRef {
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
// declare container type
using ContainerType = StageNode;
};

/*!
Expand Down Expand Up @@ -152,11 +154,22 @@ class Schedule : public NodeRef {
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief Normalize the schedule.
* This is needed before bound inference.
* Insert necessary RebaseNode to make sure all leaf_iter_vars
* are in form [0, extent)
*
* \return A normalized schedule, can be same as current one.
*/
void normalize();
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
// declare container type
using ContainerType = ScheduleNode;
};

/*!
Expand Down Expand Up @@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};

/*!
* \brief Rebase the iteration to make min to be 0.
* This is useful to normalize the Schedule
* to make every leaf variable's min to be 0.
*/
class RebaseNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The inner domain */
IterVar rebased;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("rebased", &rebased);
}

static IterVarRelation make(IterVar parent, IterVar rebased);

static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ namespace schedule {
*/
Map<IterVar, Range> InferBound(Schedule sch);

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
36 changes: 18 additions & 18 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
return _api_internal._IterVar(dom, name, thread_tag)


def sum(expr, rdom):
"""Create a sum expression over rdom
def sum(expr, axis):
"""Create a sum expression over axis

Parameters
----------
expr : Expr
The source expression.

rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Add", expr, axis)
return x


def min(expr, rdom):
"""Create a min expression over rdom
def min(expr, axis):
"""Create a min expression over axis

Parameters
----------
expr : Expr
The source expression.

rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Min", expr, axis)
return x


def max(expr, rdom):
"""Create a min expression over rdom
def max(expr, axis):
"""Create a min expression over axis

Parameters
----------
expr : Expr
The source expression.

rdom : RDomain
The reduction domainx
axis : IterVar
The reduction IterVar axis
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
axis = axis if isinstance(axis, list) else [axis]
x = _make.Reduce("Max", expr, axis)
return x


Expand Down
6 changes: 4 additions & 2 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def build(sch,

# lowering
bounds = schedule.InferBound(sch)
stmt = ir_pass.ScheduleOps(sch, bounds)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)

Expand All @@ -73,7 +74,8 @@ def build(sch,
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))

for c in record_codes:
print(c)
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def __getitem__(self, k):
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]

def normalize(self):
"""Build a normalized schedule.

Insert necessary rebase to make certain iter var to start from 0.
This is needed before bound inference and followup step.
"""
_api_internal._ScheduleNormalize(self)

@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
Expand Down
6 changes: 6 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
.normalize();
});

} // namespace tvm
1 change: 0 additions & 1 deletion src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ TVM_REGISTER_API(_pass_Equal)
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
REGISTER_PASS2(StorageFlatten);

} // namespace ir
Expand Down
1 change: 1 addition & 0 deletions src/api/api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace schedule {
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS2(ScheduleOps);

} // namespace schedule
} // namespace tvm
5 changes: 3 additions & 2 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_c.cc
*/
#include <iomanip>
#include "./codegen_c.h"

namespace tvm {
Expand Down Expand Up @@ -216,7 +217,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
switch (op->type.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << op->value;
temp << std::scientific << op->value;
if (op->type.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
Expand All @@ -225,7 +226,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
case 16: {
os << '(';
p->PrintType(op->type, os);
os << ')' << op->value << 'f';
os << ')' << std::scientific <<op->value << 'f';
break;
}
default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
Expand Down
10 changes: 5 additions & 5 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< op->op
<< ", ";
p->print(op->source);
p->stream << ", rdom=" << op->rdom << ")";
p->stream << ", axis=" << op->axis << ")";
});

} // namespace Internal
Expand All @@ -35,16 +35,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {

Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> axis) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
n->type = source.type();
n->source = source;
n->op = op;
n->rdom = rdom;
n->axis = axis;
return Expr(n);
}

Expand Down
Loading