From ceca458eabe5475bc90c1bc8d81636ca9add5ad8 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 27 Jun 2019 20:26:47 -0700 Subject: [PATCH] save add me find type checker problem save save lint do lint reset ti add some doc add failed test case add recursion for cps add recursion for cps fix pytest lint save fix test error lint save fix error --- include/tvm/relay/pass.h | 40 ++ include/tvm/relay/transform.h | 16 + python/tvm/relay/ir_pass.py | 42 ++ python/tvm/relay/transform.py | 15 + src/relay/ir/adt.cc | 6 +- src/relay/ir/module.cc | 3 +- src/relay/ir/pretty_printer.cc | 12 +- src/relay/ir/type_functor.cc | 2 +- src/relay/pass/de_duplicate.cc | 122 ++++++ src/relay/pass/dependency_graph.h | 2 +- src/relay/pass/let_list.h | 25 +- src/relay/pass/partial_eval.cc | 84 +--- src/relay/pass/to_a_normal_form.cc | 37 +- src/relay/pass/to_cps.cc | 397 ++++++++++++++++++ .../relay/test_pass_to_a_normal_form.py | 16 + tests/python/relay/test_pass_to_cps.py | 101 +++++ 16 files changed, 807 insertions(+), 113 deletions(-) create mode 100644 src/relay/pass/de_duplicate.cc create mode 100644 src/relay/pass/to_cps.cc create mode 100644 tests/python/relay/test_pass_to_cps.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 79172c3743167..47e029f063ae4 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -271,6 +271,15 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); +/*! + * \brief Deduplicate the bound variables and type variables in the expression. + * + * \param e the expression. + * + * \return the deduplicated expression. + */ +TVM_DLL Expr DeDup(const Expr& e); + /*! * \brief Fold constant expressions. * @@ -377,6 +386,37 @@ TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); */ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \param f the function. + * \param mod the module. + * + * \return the converted Function. + */ +TVM_DLL Function ToCPS(const Function& f, const Module& mod); + +/*! + * \brief Remove the continuation argument of a CPS function. + * + * Note that this only transform the type back into un-CPS form + * when there is no higher order input/output. + * + * \param f the function. + * + * \return the converted Function. + */ +TVM_DLL Function UnCPS(const Function& f); + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 9ae71d824f94e..0e7a6c71df8c8 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -434,6 +434,22 @@ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); */ TVM_DLL Pass ToANormalForm(); +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \return the pass. + */ +TVM_DLL Pass ToCPS(); + /*! * \brief Remove let binding and directly share via pointer instead. * diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 52dc34d7aac9d..bc8ef2c8707a7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -498,6 +498,48 @@ def collect_device_annotation_ops(expr): return _ir_pass.CollectDeviceAnnotationOps(expr) +def to_cps(func, mod=None): + """ + Turn expression into CPS expression. + + Every intermediate compute will be passed to a continuation. + + Parameters + ---------- + func: tvm.relay.Function + The input function. + + mod: Optional[tvm.relay.Module] + The global module. + + Returns + ------- + result: tvm.relay.Function + The output function. + """ + return _ir_pass.to_cps(func, mod) + + +def un_cps(func): + """ + Turn an cps function into a Function without the continuation argument. + + Note that this will not give the exact same interface as before cps: + If the input/output is higher order, they will still be in cps form. + + Parameters + ---------- + func: tvm.relay.Function + The input function + + Returns + ------- + result: tvm.relay.Function + The output function + """ + return _ir_pass.un_cps(func) + + def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index ba4857dc4d36e..53c8e864e4bf2 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -412,6 +412,20 @@ def ToANormalForm(): return _transform.ToANormalForm() +def ToCPS(expr, mod=None): + """ + Turn expression into continuation passing style(CPS). + + Every intermediate compute will be passed to a continuation. + + Returns + ------- + result: tvm.relay.Pass + The registered pass that transforms an expression into CPS. + """ + return _ir_pass.to_cps(expr, mod) + + def EtaExpand(): """Add abstraction over a function @@ -461,6 +475,7 @@ def PartialEvaluate(): """ return _transform.PartialEvaluate() + def CanonicalizeCast(): """ Canonicalize cast expressions to make operator fusion more efficient. diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index b59281a4f1fd9..3eb1d99f5a889 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). */ diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 58f614a3cc77c..e2da2c3bbbb6d 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -88,8 +88,9 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { } void ModuleNode::Add(const GlobalVar& var, - const Function& func, + const Function& f, bool update) { + Function func = Downcast(DeDup(f)); // Type check the item before we add it to the module. auto mod = GetRef(this); Function checked_func = InferType(func, mod, var); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 7a61079204edc..39fc36fba4baf 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -645,11 +645,21 @@ class PrettyPrinter : Doc VisitType_(const FuncTypeNode* node) final { Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "<"; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); + } + doc << PrintVec(type_params); + doc << ">"; + } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RefTypeNode* node) final { diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 9fca2e0326859..516f4c875b20c 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -221,7 +221,7 @@ class TypeBinder : public TypeMutator { }; Type Bind(const Type& type, const tvm::Map& args_map) { - return TypeBinder(args_map).VisitType(type); + return type.defined() ? TypeBinder(args_map).VisitType(type) : type; } } // namespace relay diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc new file mode 100644 index 0000000000000..866ac9e37c544 --- /dev/null +++ b/src/relay/pass/de_duplicate.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file de_duplicate.cc + * \brief Use a fresh Id for every Var to make the result well-formed. + */ + +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +Expr DeDup(const Expr& e) { + class DeDupMutator : public TypeMutator, + public ExprMutator, + public PatternMutator { + public: + TypeVar Fresh(const TypeVar& tv) { + TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); + type_rename_[tv] = ret; + return ret; + } + + Var Fresh(const Var& v) { + Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); + rename_[v] = ret; + return ret; + } + + Expr VisitExpr(const Expr& e) final { + return ExprMutator::VisitExpr(e); + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return rename_.count(v) != 0 ? rename_.at(v) : v; + } + + Expr VisitExpr_(const LetNode* op) final { + Var v = Fresh(op->var); + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } + + Type VisitType(const Type& t) final { + return t.defined() ? TypeMutator::VisitType(t) : t; + } + + Expr VisitExpr_(const FunctionNode* op) final { + tvm::Array type_params; + for (const TypeVar& type_param : op->type_params) { + type_params.push_back(Fresh(type_param)); + } + tvm::Array params; + for (const Var& param : op->params) { + params.push_back(Fresh(param)); + } + return FunctionNode::make(params, + VisitExpr(op->body), + VisitType(op->ret_type), + type_params, + op->attrs); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Pattern VisitPattern_(const PatternVarNode* op) final { + return PatternVarNode::make(Fresh(op->var)); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Type VisitType_(const TypeVarNode* op) final { + TypeVar v = GetRef(op); + return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; + } + + Var VisitVar(const Var& v) final { + return Fresh(v); + } + + private: + std::unordered_map rename_; + std::unordered_map type_rename_; + }; + + Expr ret = DeDupMutator().VisitExpr(e); + CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); + return ret; +} + +TVM_REGISTER_API("relay._ir_pass.dedup") +.set_body_typed(FreeVars); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h index 7f53918ebcb7f..5e2b08c352f09 100644 --- a/src/relay/pass/dependency_graph.h +++ b/src/relay/pass/dependency_graph.h @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors. * \file tvm/relay/pass/dependency_graph.h - * \brief + * \brief create a dependency graph. */ #ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 9f56b22fc13e9..1b422d2a878f0 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, @@ -46,6 +46,11 @@ namespace relay { */ class LetList { public: + ~LetList() { + if (lets_.size() > 0 && !used_) { + std::cout << "Warning: letlist not used" << std::endl; + } + } /*! * \brief insert a binding. * @@ -64,13 +69,13 @@ class LetList { /*! * \brief insert a binding. * - * \param ty the type of the binding. - * * \param expr the value of the binding. * + * \param ty the type of the binding. + * * \return a Var that hold the inserted expr. */ - Var Push(Type ty, Expr expr) { + Var Push(Expr expr, Type ty) { return Push(VarNode::make("x", ty), expr); } @@ -82,7 +87,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(Type(), expr); + return Push(expr, Type()); } /*! @@ -129,6 +134,12 @@ class LetList { return ll.Get(f(&ll)); } + static Expr Let(const Expr& e, const std::function& f) { + return With([&](LetList* ll) { + return f(ll->Push(e)); + }); + } + private: std::vector > lets_; bool used_ = false; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index e7edbb3153d85..fc356f473b94d 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file partial_eval.cc * @@ -425,8 +425,6 @@ TVM_ADD_FILELINE) Expr StripWithFuncId(const Expr& e); -Expr DeDup(const Expr& e); - Function AsFunc(const Expr& e) { if (e.as()) { return Downcast(e); @@ -957,86 +955,6 @@ class PartialEvaluator : public ExprFunctor FInterpreter executor_ = CPUInterpreter(); }; -/*! \brief Use a fresh Id for every Var to make the result well-formed. */ -Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, - public ExprMutator, - public PatternMutator { - public: - TypeVar Fresh(const TypeVar& tv) { - TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); - type_rename_[tv] = ret; - return ret; - } - - Var Fresh(const Var& v) { - Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); - rename_[v] = ret; - return ret; - } - - Expr VisitExpr(const Expr& e) final { - return ExprMutator::VisitExpr(e); - } - - Expr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); - return rename_.count(v) != 0 ? rename_.at(v) : v; - } - - Expr VisitExpr_(const LetNode* op) final { - Var v = Fresh(op->var); - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); - } - - Type VisitType(const Type& t) final { - return t.defined() ? TypeMutator::VisitType(t) : t; - } - - Expr VisitExpr_(const FunctionNode* op) final { - tvm::Array type_params; - for (const TypeVar& type_param : op->type_params) { - type_params.push_back(Fresh(type_param)); - } - tvm::Array params; - for (const Var& param : op->params) { - params.push_back(Fresh(param)); - } - return FunctionNode::make(params, - VisitExpr(op->body), - VisitType(op->ret_type), - type_params, - op->attrs); - } - - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } - - Clause VisitClause(const Clause& c) final { - Pattern pat = VisitPattern(c->lhs); - return ClauseNode::make(pat, VisitExpr(c->rhs)); - } - - Type VisitType_(const TypeVarNode* op) final { - TypeVar v = GetRef(op); - return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; - } - - Var VisitVar(const Var& v) final { - return Fresh(v); - } - - private: - std::unordered_map rename_; - std::unordered_map type_rename_; - }; - - Expr ret = DeDupMutator().VisitExpr(e); - CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); - return ret; -} - /*! \brief Remap multiple Var sharing the same Id into the same Var. */ Expr Remap(const Expr& e) { class RemapMutator : public ExprMutator, public PatternMutator { diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index b5a3f8552d8da..57395020b4ad2 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * - * \file to_anf.cc + * \file to_a_normal_form.cc * * \brief Turn implicit sharing into observable sharing. */ @@ -72,13 +72,16 @@ Scope LCA(Scope lhs, Scope rhs) { std::unordered_map CalcScope(const DependencyGraph& dg) { std::unordered_map expr_scope; + bool global_scope_used = false; Scope global_scope = std::make_shared(); for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it; auto iit = n->parents.head; Scope s; if (iit == nullptr) { + CHECK(!global_scope_used); s = global_scope; + global_scope_used = true; } else { s = expr_scope.at(iit->value); iit = iit->next; @@ -88,13 +91,10 @@ std::unordered_map CalcScope(const DependencyGrap } expr_scope.insert({n, n->new_scope ? ChildScope(s) : s}); } + CHECK(global_scope_used); return expr_scope; } -bool IsPrimitiveFunction(const Expr& e) { - return e.as() && Downcast(e)->IsPrimitive(); -} - /* Special care is needed to handle local recursion. * Fill additionally take a (possibly null) Var argument, * If it is not null, Fill is required to bind the transformed result to that var. @@ -137,22 +137,26 @@ class Fill : ExprFunctor { Expr VisitExpr(const Expr& e, const Var& v) final { if (memo.count(e) == 0) { memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } else if (v.defined()) { + GetScope(e)->ll->Push(v, memo.at(e)); } - return memo.at(e); + auto ret = memo.at(e); + CHECK(IsAtomic(ret)); + return ret; } Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& orig, const Expr& now, const Var& v) { - return v.defined() ? GetScope(orig)->ll->Push(v, now) : now; + Expr Atomic(const Expr& e, const Var& v) { + return v.defined() ? GetScope(e)->ll->Push(v, e) : e; } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Var var = v.defined() ? v : - VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); + VarNode::make(std::string("x"), Type()); return GetScope(orig)->ll->Push(var, now); } @@ -205,7 +209,7 @@ class Fill : ExprFunctor { Expr VisitExpr_(const FunctionNode* f, const Var& v) final { Expr e = GetRef(f); Expr ret; - if (IsPrimitiveFunction(e)) { + if (f->IsPrimitive()) { ret = e; } else { ret = FunctionNode::make(f->params, @@ -231,22 +235,22 @@ class Fill : ExprFunctor { Expr VisitExpr_(const VarNode* vn, const Var& v) final { Expr e = GetRef(vn); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { GlobalVar gv = GetRef(gvn); - return Atomic(gv, gv, v); + return Atomic(gv, v); } Expr VisitExpr_(const OpNode* op, const Var& v) final { Expr e = GetRef(op); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { Expr e = GetRef(c); - return Atomic(e, e, v); + return Atomic(e, v); } Expr VisitExpr_(const MatchNode* m, const Var& v) final { @@ -294,11 +298,12 @@ Module ToANormalForm(const Module& m) { tvm::Map updates; auto funcs = m->functions; for (const auto& it : funcs) { + CHECK_EQ(FreeVars(it.second).size(), 0); Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second); - CHECK_EQ(FreeVars(ret).size(), 0); + CHECK_EQ(FreeVars(ret).size(), 0) << AsText(ret) << "should not has free vars: " << FreeVars(ret); updates.Set(it.first, Downcast(ret)); } diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc new file mode 100644 index 0000000000000..bd2e306578c54 --- /dev/null +++ b/src/relay/pass/to_cps.cc @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file to_cps.cc + * + * \brief Turn a program to continuation passing style. + * + * Given a fresh type variable 'answer', + * continuation passing style(CPS) convert every function of a -> b to a -> (b -> anwer) -> answer. + * + * That is, instead of returning the result directly, + * function will now call another function (called the continuation) + * and return that value as a result instead. + * + * Continuation passing style turn all function call into tail call, + * which bound the stack size, prevent stack from overflowing during recursion, + * and allow tail call optimization. + * + * In relay, as tensor operation is the bottleneck, + * CPS is currently intended to transform the program before partial eval (PE), + * as it reify the control flow and enable PE to handle control flow join more agressively. + * + * For example, given 'let a = if b then c else d in e', it will transform the code into + * 'let f a = e in if b then f c else f d'. + * This allow f to be optimized individually in both branch. + * + * We implement CPS conversion by higher order transform + * (see http://matt.might.net/articles/cps-conversion/). + * The basic idea is that we will recursively traverse the AST. + * During the traversal, there is an extra parameter, mcont, of expr -> expr. + * It is basically a continuation at the metalevel. + * All cases in the transform must return via the mcont, + * wheter directly invoking it, or indirectly by recursion. + */ +#include +#include +#include +#include "../ir/type_functor.h" +#include "let_list.h" +#include "pass_util.h" + +namespace tvm { +namespace relay { + +// we assume the data type has no closure - no idea how to look into datatype right now. + +Type Arrow(const Type& l, const Type& r) { + return FuncTypeNode::make({l}, r, {}, {}); +} + +Type CPSType(const Type& t, const TypeVar& answer); + +FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { + tvm::Array new_arg_types; + for (const Type& t : f->arg_types) { + new_arg_types.push_back(CPSType(t, answer)); + } + new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer)); + return FuncTypeNode::make(new_arg_types, answer, f->type_params, f->type_constraints); +} + +Type CPSType(const Type& t, const TypeVar& answer) { + struct CPSTypeMutator : TypeMutator { + explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { } + TypeVar answer; + Type VisitType_(const FuncTypeNode* t) final { + return CPSFuncType(GetRef(t), answer); + } + } mut(answer); + return mut(t); +} + +// transform global functions into cps form. +using CPSMap = std::unordered_map; + +// transform vars from the original program into new vars, so their type will be correct. +using VarMap = std::unordered_map; + +/* + * The meta continuation. + * There is 3 rules on the metacontinuation: + * 0: It can only use the argument once. + * The argument is code, and using it twice will duplicate code. + * Bound the argument via let instead. + * 1: If the size of the metacontinuation is unbounded, it can only be called once. + * It contain code, so calling it twice duplicate code. + * Reify the continuation and bound it instead. + * See the function 'reify' and the if case for more detail. + * 2: The argument must be effect free. + * It might reorder or drop the argument. + * Again, bound the argument via let instead. + * See the call case for more detail. + */ +using MCont = std::function; + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm); + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { + std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; + auto function_type = Downcast(f->checked_type()); + // Each MCont can be used at most once. + struct CPSFunctor : ExprFunctor, PatternMutator { + CPSFunctor(const std::function& remap, + const TypeVar& answer, + const Module& m, + VarMap* vm, + CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } + const std::function& remap; + TypeVar answer; + Module m; + VarMap* vm; + CPSMap* cm; + + Expr VisitExpr_(const LetNode* op, const MCont& k) final { + return VisitExpr(op->value, [&](const Expr& v) { + return LetNode::make(remap(op->var), v, VisitExpr(op->body, k)); + }); + } + + Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { + CHECK(!op->IsPrimitive()) << "primitive func not supported yet."; + return k(ToCPS(GetRef(op), m, cm, vm, answer)); + } + + Expr VisitExpr_(const ConstantNode* op, const MCont& k) final { + return k(GetRef(op)); + } + + Expr VisitExpr_(const VarNode* op, const MCont& k) final { + return k(remap(GetRef(op))); + } + + Pattern VisitPattern_(const PatternVarNode* op) final { + return PatternVarNode::make(remap(op->var)); + } + + Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { + auto gv = GetRef(op); + if (cm->count(gv) == 0) { + auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps"); + cm->insert({gv, cps_gv}); + m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); + } + return k(cm->at(gv)); + } + + Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final { + return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreateNode::make(v)); }); + } + + Expr reify(const MCont& k) { + Var arg = VarNode::make("arg", Type()); + return FunctionNode::make({arg}, k(arg), Type(), {}, {}); + } + + Expr reify(const MCont& k, const std::function& cont) { + return LetList::Let(reify(k), + [&](const Var& f) { + return cont([&](const Expr& e) { return CallNode::make(f, {e}); }); + }); + } + + Expr VisitExpr_(const IfNode* op, const MCont& k) final { + return reify(k, [&](const MCont& kf) { + return VisitExpr(op->cond, + [&](const Expr& v) { + return IfNode::make(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); + }); + }); + } + + Expr VisitExpr_(const MatchNode* op, const MCont& k) final { + return reify(k, [&](const MCont& kf) { + return VisitExpr(op->data, [&](const Expr& v) { + tvm::Array clauses; + for (const auto& c : op->clauses) { + clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); + } + return MatchNode::make(v, clauses); + }); + }); + } + + Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { + return VisitExpr(op->ref, + [&](const Expr& r) { + return LetList::Let(RefReadNode::make(r), k); + }); + } + + Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { + return VisitExpr(op->ref, + [&](const Expr& r) { + return VisitExpr(op->value, + [&](const Expr& v) { + return LetList::Let(RefWriteNode::make(r, v), k); + }); + }); + } + + Expr VisitExpr_(const TupleNode* op, const MCont& k) final { + tvm::Array fields; + std::function next; + next = [&]() { + return (fields.size() == op->fields.size()) ? + k(TupleNode::make(fields)) : + VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + fields.push_back(v); + return next(); + }); + }; + return next(); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { + return VisitExpr(op->tuple, [&](const Expr& v) { + return k(TupleGetItemNode::make(v, op->index)); + }); + } + + Expr VisitExpr_(const CallNode* op, const MCont& k) final { + if (op->op.as() || op->op.as()) { + tvm::Array args; + std::function next; + next = [&]() { + if (args.size() == op->args.size()) { + return LetList::Let(CallNode::make(op->op, args, op->attrs, op->type_args), k); + } else { + return VisitExpr(op->args[args.size()], [&](const Expr& v) { + args.push_back(v); + return next(); + }); + } + }; + return next(); + } else { + Expr f; + tvm::Array args; + std::function next; + next = [&]() { + if (args.size() == op->args.size()) { + args.push_back(reify(k)); + return Expr(CallNode::make(f, args)); + } else { + return VisitExpr(op->args[args.size()], [&](const Expr& v) { + args.push_back(v); + return next(); + }); + } + }; + return VisitExpr(op->op, [&](const Expr& v) { + f = v; + return next(); + }); + } + } + } mut(remap, answer, m, vm, cm); + Var k = VarNode::make("k", Arrow(CPSType(function_type->ret_type, answer), answer)); + tvm::Array new_params; + for (const Var& v : f->params) { + new_params.push_back(remap(v)); + } + new_params.push_back(k); + return FunctionNode::make(new_params, + mut.VisitExpr(f->body, + [&](const Expr& e) { return CallNode::make(k, {e}); }), + answer, + f->type_params, + f->attrs); +} + +Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { + TypeVar answer = TypeVarNode::make("answer", kType); + VarMap var; + struct Remapper : ExprVisitor, PatternVisitor { + Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } + TypeVar answer; + VarMap* vm; + void VisitExpr_(const VarNode* vn) final { + Var v = GetRef(vn); + if (vm->count(v) == 0) { + auto ret = VarNode::make(v->name_hint(), CPSType(v->checked_type(), answer)); + vm->insert({v, ret}); + } + } + + void VisitPattern(const Pattern& p) final { + PatternVisitor::VisitPattern(p); + } + + void VisitPattern_(const PatternVarNode* op) final { + VisitExpr(op->var); + } + } remap(answer, &var); + remap.VisitExpr(f); + Function ret = ToCPS(f, m, cm, &var, answer); + auto new_type_params = ret->type_params; + new_type_params.push_back(answer); + return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); +} + +Function ToCPS(const Function& f, const Module& m) { + CPSMap cps; + return ToCPS(f, m, &cps); +} + +Function UnCPS(const Function& f) { + CHECK_GT(f->params.size(), 0); + std::vector new_params; + for (const auto& p : f->params) { + new_params.push_back(VarNode::make(p->name_hint(), p->checked_type())); + } + auto cont_type = Downcast(new_params.back()->type_annotation); + new_params.pop_back(); + CHECK_EQ(cont_type->arg_types.size(), 1); + auto new_ret_type = Type(cont_type->arg_types[0]); + std::vector new_type_params; + for (const auto& tp : f->type_params) { + new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind)); + } + auto answer_type = new_type_params.back(); + new_type_params.pop_back(); + // TODO(@M.K.): make alphaequal work on free term + // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type))); + auto x = VarNode::make("x", new_ret_type); + auto cont = FunctionNode::make({x}, x, new_ret_type, {}, {}); + tvm::Array args; + for (const auto& p : new_params) { + args.push_back(p); + } + args.push_back(cont); + tvm::Array type_args; + for (const auto& tp : new_type_params) { + type_args.push_back(tp); + } + type_args.push_back(new_ret_type); + return FunctionNode::make(new_params, + CallNode::make(f, args, {}, type_args), + new_ret_type, + new_type_params, + f->attrs); +} + +TVM_REGISTER_API("relay._ir_pass.to_cps") +.set_body_typed(static_cast(ToCPS)); + +TVM_REGISTER_API("relay._ir_pass.un_cps") +.set_body_typed(UnCPS); + +namespace transform { + +Pass ToCPS() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Function(ToCPS(f, m)); + }; + return CreateFunctionPass(pass_func, 1, "ToCPS", {}); +} + +TVM_REGISTER_API("relay._transform.ToCPS") +.set_body_typed(ToCPS); + + +Pass UnCPS() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Function(UnCPS(f)); + }; + return CreateFunctionPass(pass_func, 1, "UnCPS", {}); +} + +TVM_REGISTER_API("relay._transform.UnCPS") +.set_body_typed(ToCPS); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index e74168141e63c..769c41e035f15 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -22,6 +22,7 @@ from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.feature import Feature +import tvm.relay.testing def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -180,6 +181,20 @@ def test_function(): check_eval(anf_f(d), 8) +def test_gradient_if(): + x = relay.var("a", shape=(1, 16)) + y = relay.var("y", shape=(1, 16)) + cond = relay.var("cond", shape=(), dtype='uint1') + net = relay.If(cond, x, x) + net = relay.add(x, net) + net = relay.Function([cond,x,y], net) + net = relay.ir_pass.infer_type(net) + mod = relay.Module.from_expr(net) + mod = relay.transform.ToANormalForm()(mod) + mod[mod.entry_func] = relay.ir_pass.gradient(mod[mod.entry_func], mode='higher_order') + mod = relay.transform.ToANormalForm()(mod) + + if __name__ == '__main__': test_explicit_bound() test_order() @@ -189,3 +204,4 @@ def test_function(): test_let() test_nat_add() test_function() + test_gradient_if() diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py new file mode 100644 index 0000000000000..384fb7b168803 --- /dev/null +++ b/tests/python/relay/test_pass_to_cps.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relay +from tvm.relay.ir_pass import alpha_equal, infer_type, detect_feature +from tvm.relay.ir_pass import to_cps, un_cps +from tvm.relay.feature import Feature +from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, make_nat_expr +from tvm.relay import create_executor +from tvm.relay import Function, transform + + +def rand(dtype='float32', *shape): + return tvm.nd.array(np.random.rand(*shape).astype(dtype)) + + +# make sure cps work for recursion. +def test_recursion(): + mod = relay.Module() + p = Prelude(mod) + add_nat_definitions(p) + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.var("x", t) + double = relay.Function([x], x + x) + i = relay.var("i", t) + func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) + func = infer_type(func, mod=mod) + cps_func = infer_type(un_cps(infer_type(to_cps(func, mod=mod), mod=mod)), mod=mod) + print(mod) + print(cps_func) + ex = create_executor(mod=mod) + i_nd = rand(dtype, *shape) + forward = ex.evaluate(cps_func)(i_nd) + tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy()) + + +# This serve as an integration test. +# It test that, given a program with reference, +# cps and pe can completely eliminate the allocation of reference. +def test_cps_pe(): + def destroy_ref(x): + x = infer_type(x) + x = to_cps(x) + x = infer_type(x) + y = un_cps(x) + y = infer_type(y) + x = transform.OptimizeOnExpr(x, [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]) + assert Feature.fRefCreate not in detect_feature(x) + unit = relay.Function([], relay.const(0., dtype='float32')) + f_ref = relay.Var("f_ref") + + one = relay.const(1., dtype='float32') + two = relay.const(2., dtype='float32') + cond = relay.var(shape=(), dtype='uint1', name_hint='cond') + true_branch = relay.RefWrite(f_ref, relay.Function([], one)) + false_branch = relay.RefWrite(f_ref, relay.Function([], two)) + if_expr = relay.If(cond, true_branch, false_branch) + + stmt = relay.Let(f_ref, relay.RefCreate(unit), + relay.Let(relay.Var("x"), if_expr, + relay.Call(relay.RefRead(f_ref), []))) + + F = relay.Function([cond], stmt) + destroy_ref(F) + + G = relay.Function([cond], relay.If(cond, one, two)) + G = relay.ir_pass.gradient(G) + destroy_ref(G) + + x = relay.var("x", shape=(1, 16)) + y = relay.var("y", shape=(1, 16)) + z = relay.var("z", shape=(1, 16)) + cond = relay.var("cond", shape=(), dtype='uint1') + H = relay.If(cond, x, y) + H = relay.add(H, z) + H = relay.Function([cond,x,y,z], H) + H = relay.ir_pass.gradient(H) + destroy_ref(H) + + +if __name__ == '__main__': + test_recursion() + test_cps_pe()