Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
add

me find type checker problem

save

save

lint

do

lint

reset ti

add some doc

add failed test case

add recursion for cps

add recursion for cps

fix pytest
  • Loading branch information
MarisaKirisame committed Jul 1, 2019
1 parent 6c81d78 commit a7ec531
Show file tree
Hide file tree
Showing 13 changed files with 777 additions and 107 deletions.
48 changes: 44 additions & 4 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);

/*!
* \brief Deduplicate the bound variables and type variables in the expression.
*
* \param e the expression.
*
* \return the deduplicated expression.
*/
TVM_DLL Expr DeDup(const Expr& e);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
Expand Down Expand Up @@ -437,24 +446,55 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression
* \param mod the module
* \param e the expression.
* \param mod the module.
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);

/*
/*!
* \brief Bind function parameters or free variables.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
* \param bind_map The map of arguments to Expr.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);

/*!
* \brief Turn an expression into continuation passing style(CPS).
*
* CPS mean that every function will, instead of returning the result directly,
* be passed down an extra function (called the continuation) as argument,
* and pass the result to the continuation instead.
*
* Thus, every function call has to be passed an extra argument
* that represent the rest of the computation (Hence the name of continuation).
*
* Similarly, all other compute will be wrapped and call the continuation as well.
*
* \param f the function.
* \param mod the module.
*
* \return the converted Function.
*/
TVM_DLL Function ToCPS(const Function& f, const Module& mod);

/*!
* \brief Remove the continuation argument of a CPS function.
*
* Note that this only transform the type back into un-CPS form
* when there is no higher order input/output.
*
* \param f the function.
*
* \return the converted Function.
*/
TVM_DLL Function ToCPS(const Function& f, const Module& mod);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,22 @@ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief Turn an expression into continuation passing style(CPS).
*
* CPS mean that every function will, instead of returning the result directly,
* be passed down an extra function (called the continuation) as argument,
* and pass the result to the continuation instead.
*
* Thus, every function call has to be passed an extra argument
* that represent the rest of the computation (Hence the name of continuation).
*
* Similarly, all other compute will be wrapped and call the continuation as well.
*
* \return the pass.
*/
TVM_DLL Pass ToCPS();

/*!
* \brief Remove let binding and directly share via pointer instead.
*
Expand Down
56 changes: 50 additions & 6 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,30 +546,74 @@ def to_a_normal_form(expr, mod=None):
Parameters
----------
expr : tvm.relay.Expr
expr: tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
result : tvm.relay.Expr
result: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)


def to_cps(func, mod=None):
"""
Turn expression into CPS expression.
Every intermediate compute will be passed to a continuation.
Parameters
----------
func: tvm.relay.Function
The input function.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
result: tvm.relay.Function
The output function.
"""
return _ir_pass.to_cps(func, mod)


def un_cps(func):
"""
Turn an cps function into a Function without the continuation argument.
Note that this will not give the exact same interface as before cps:
If the input/output is higher order, they will still be in cps form.
Parameters
----------
func: tvm.relay.Function
The input function
Returns
-------
result: tvm.relay.Function
The output function
"""
x = _ir_pass.un_cps(func)
return x


def to_graph_normal_form(expr):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
expr: tvm.relay.Expr
The input expression
Returns
-------
result : tvm.relay.Expr
The output expression
result: tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)

Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,21 @@ def ToANormalForm():
"""
return _transform.ToANormalForm()


def ToCPS(expr, mod=None):
"""
Turn expression into continuation passing style(CPS).
Every intermediate compute will be passed to a continuation.
Returns
-------
result: tvm.relay.Pass
The registered pass that transforms an expression into CPS.
"""
return _ir_pass.to_cps(expr, mod)


def EtaExpand():
"""Add abstraction over a function
Expand All @@ -416,6 +431,7 @@ def EtaExpand():
"""
return _transform.EtaExpand()


def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
Expand Down Expand Up @@ -454,6 +470,7 @@ def PartialEvaluate():
"""
return _transform.PartialEvaluate()


def CanonicalizeCast():
"""
Canonicalize cast expressions to make operator fusion more efficient.
Expand All @@ -465,6 +482,7 @@ def CanonicalizeCast():
"""
return _transform.CanonicalizeCast()


def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
Expand Down
3 changes: 2 additions & 1 deletion src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
}

void ModuleNode::Add(const GlobalVar& var,
const Function& func,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
Function checked_func = InferType(func, mod, var);
Expand Down
12 changes: 11 additions & 1 deletion src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -645,11 +645,21 @@ class PrettyPrinter :

Doc VisitType_(const FuncTypeNode* node) final {
Doc doc;
doc << "fn ";
if (node->type_params.size() != 0) {
doc << "<";
std::vector<Doc> type_params;
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintVec(type_params);
doc << ">";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type));
}
return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
}

Doc VisitType_(const RefTypeNode* node) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class TypeBinder : public TypeMutator {
};

Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return TypeBinder(args_map).VisitType(type);
return type.defined() ? TypeBinder(args_map).VisitType(type) : type;
}

} // namespace relay
Expand Down
122 changes: 122 additions & 0 deletions src/relay/pass/de_duplicate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
*
* \file de_duplicate.cc
* \brief Use a fresh Id for every Var to make the result well-formed.
*/

#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h"

namespace tvm {
namespace relay {

Expr DeDup(const Expr& e) {
class DeDupMutator : public TypeMutator,
public ExprMutator,
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}

Var Fresh(const Var& v) {
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}

Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
}

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
}

Expr VisitExpr_(const LetNode* op) final {
Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
}

Type VisitType(const Type& t) final {
return t.defined() ? TypeMutator::VisitType(t) : t;
}

Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
VisitType(op->ret_type),
type_params,
op->attrs);
}

Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}

Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var));
}

Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}

Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}

Var VisitVar(const Var& v) final {
return Fresh(v);
}

private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};

Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
return ret;
}

TVM_REGISTER_API("relay._ir_pass.dedup")
.set_body_typed(FreeVars);

} // namespace relay
} // namespace tvm
Loading

0 comments on commit a7ec531

Please sign in to comment.