Skip to content

Commit

Permalink
Merge branch 'master' into pr
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon authored Apr 10, 2019
2 parents 678e765 + 57f47a1 commit 9d4a9d7
Show file tree
Hide file tree
Showing 30 changed files with 1,466 additions and 131 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ else(MSVC)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
add_compile_options(-O0 -g -Wall -fPIC -fvisibility=hidden -std=c++11)
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS} -rdynamic")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS} -rdynamic")
else()
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11 ${CMAKE_CXX_FLAGS}")
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,8 @@ inline const TTypeNode* ExprNode::type_as() const {
* \return The text representation.
*/
std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
1 change: 1 addition & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
Expand Down
30 changes: 24 additions & 6 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>

#include <tvm/relay/adt.h>
#include <string>
#include <vector>

Expand Down Expand Up @@ -344,6 +344,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
*
* \param pat the Pattern.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);

/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
Expand Down Expand Up @@ -431,12 +442,13 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param e the expression to optimize.
*
Expand Down Expand Up @@ -558,6 +570,12 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);

/*! \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
} // namespace relay
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ under the License.
<dependency>
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
<version>[8.18,)</version>
<version>8.18</version>
</dependency>
</dependencies>
<executions>
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/ndk.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ def create_shared(output,
msg = "Compilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)

# assign output format
create_shared.output_format = "so"
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator

# Parser
Expand Down
67 changes: 64 additions & 3 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ def __init__(self):
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if expr in self.memo_map:
return self.memo_map[expr]

if isinstance(expr, Function):
res = self.visit_function(expr)
Expand Down Expand Up @@ -126,6 +125,68 @@ def visit_match(self, _):
raise NotImplementedError()


class ExprVisitor(ExprFunctor):
"""
A visitor over Expr.
The default behavior recursively traverses the AST.
"""
def visit_tuple(self, t):
for x in t.fields:
self.visit(x)

def visit_call(self, c):
self.visit(c.op)
for a in c.args:
self.visit(a)

def visit_var(self, v):
pass

def visit_let(self, l):
self.visit(l.var)
self.visit(l.value)
self.visit(l.body)

def visit_function(self, f):
self.visit(f.body)

def visit_if(self, i):
self.visit(i.cond)
self.visit(i.true_branch)
self.visit(i.false_branch)

def visit_global_var(self, gv):
pass

def visit_constructor(self, c):
pass

def visit_op(self, op):
pass

def visit_constant(self, const):
pass

def visit_ref_create(self, r):
self.visit(r.value)

def visit_ref_read(self, r):
self.visit(r.ref)

def visit_ref_write(self, r):
self.visit(r.ref)
self.visit(r.value)

def visit_tuple_getitem(self, t):
self.visit(t.tuple_value)

def visit_match(self, m):
self.visit(m.data)
for c in m.clause:
self.visit(c.rhs)


class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
Expand Down
26 changes: 23 additions & 3 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
"""Fold multiple conv2d into one.
def combine_parallel_conv2d(expr, min_num_branches=3):
"""Combine multiple conv2d into one.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
min_num_branches : int
The minimum number of parallel branches when the transformation should be applied.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression
"""
return _ir_pass.CombineParallelConv2D(expr)
return _ir_pass.CombineParallelConv2D(expr, min_num_branches)


def alter_op_layout(expr):
Expand Down Expand Up @@ -953,3 +956,20 @@ def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)


def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ class Interpreter :
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size());
CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false;
Expand Down
3 changes: 0 additions & 3 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) {
}

Expr ExprMutator::VisitExpr_(const VarNode* op) {
// NOTE: var will only be mutated once
// Thanks to the memo and reused during rewriting if necessary.
// It is safe to assume that the
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
Expand Down
70 changes: 60 additions & 10 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,55 @@ class PrettyPrinter :
return Doc(unique_prefix);
}

Doc Print(Kind k) {
switch (k) {
case kType:
return Doc("Type");
case kShapeVar:
return Doc("Shape");
case kBaseType:
return Doc("BaseType");
case kConstraint:
return Doc("Constraint");
case kAdtHandle:
return Doc("AdtHandle");
case kTypeData:
return Doc("TypeData");
default:
LOG(ERROR) << "Unknown Kind";
throw;
}
}
/*!
* \brief Allocate name to a variable.
* \param var The input variable.
* \return The corresponding name.
*/
* \brief Allocate name to a type variable.
* \param var The input type variable.
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
std::string name = var->var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
if (memo_type_.count(var)) {
val << "-malformed-ir";
}
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
}
return val;
}

/*!
* \brief Allocate name to a variable.
* \param var The input variable.
* \return The corresponding name.
*/
Doc AllocVar(const Var& var) {
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) {
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
Expand Down Expand Up @@ -387,12 +427,18 @@ class PrettyPrinter :
}

Doc PrintFunc(const Doc& prefix, const Function& fn) {
// TODO(tqchen, M.K.) support generic function
// Possibly through meta data
CHECK_EQ(fn->type_params.size(), 0U)
<< "generic fn not yet supported";
Doc doc;
doc << prefix << "(";
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
Expand Down Expand Up @@ -516,6 +562,10 @@ class PrettyPrinter :
return Print(GetRef<NodeRef>(node), true);
}

Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
}

Doc VisitType_(const TensorTypeNode* node) final {
// scalar type
if (node->shape.size() == 0) {
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
* \return The result of the call
*/
virtual R VisitType(const Type& n, Args... args) {
CHECK(n.defined());
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
Expand Down
Loading

0 comments on commit 9d4a9d7

Please sign in to comment.