From 3a929be76119354f861a7a0c6c456508e5991f7e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 13 Feb 2020 23:01:10 -0800 Subject: [PATCH 01/30] save --- python/tvm/relay/std/gradient.rly | 36 ++++++++++++++++++ src/relay/transforms/gradient.cc | 8 ++++ tests/python/contrib/test_nnpack.py | 2 +- .../integration/test_winograd_nnpack.py | 4 +- tests/python/relay/test_ir_parser.py | 38 ++----------------- tests/python/relay/test_op_grad_level3.py | 2 +- tests/python/relay/test_op_grad_level4.py | 2 +- tests/python/relay/test_pass_lambda_lift.py | 3 +- tests/python/relay/test_pass_manager.py | 2 +- .../test_pass_remove_unused_functions.py | 2 +- 10 files changed, 56 insertions(+), 43 deletions(-) create mode 100644 python/tvm/relay/std/gradient.rly diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly new file mode 100644 index 000000000000..c4a8a87057e9 --- /dev/null +++ b/python/tvm/relay/std/gradient.rly @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +v0.0.4 + +/* + * Store the Gradient Value of a Tensor of type T. + * Note that Gradient of T is stored inside a Ref(GradCell[T]) instead of GradCell[T]. + */ +type GradCell[T] { + Raw(T), + One(fn() -> T), + Zero(fn() -> T) +} + +def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] { + match ((%l, %r)) { + (Zero(_), _) => %r, + (_, Zero(_)) => %l + } +} diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index a3728e905922..aa6157f1acdb 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -92,6 +92,14 @@ Expr DeGlobal(const IRModule& mod, const Expr& e) { } } +std::string GradName(const Expr& e) { + if (const auto* x = e.as()) { + return x->name_hint + "_grad"; + } else { + return "temp_grad"; + } +} + /*! \brief A fragment of the program being built by the automatic differentation * pass. 
*/ diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py index 505199a55724..8c2197b94757 100644 --- a/tests/python/contrib/test_nnpack.py +++ b/tests/python/contrib/test_nnpack.py @@ -203,4 +203,4 @@ def verify(target="llvm", if __name__ == "__main__": - pytest.main() + pytest.main([__file__]) diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py index 7dad2ca586d7..536ca5d042ea 100644 --- a/tests/python/integration/test_winograd_nnpack.py +++ b/tests/python/integration/test_winograd_nnpack.py @@ -25,6 +25,7 @@ import topi.testing from topi.util import get_const_tuple from pytest import skip +import pytest def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, @@ -140,5 +141,4 @@ def test_conv2d_nchw(): if __name__ == "__main__": - import pytest - pytest.main() + pytest.main([__file__]) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index ba1f8d884adc..7ba0a6de2780 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -18,7 +18,6 @@ from tvm import te from tvm import relay from tvm.relay.analysis import graph_equal, assert_graph_equal -from tvm.relay.analysis import alpha_equal, assert_alpha_equal import pytest from numpy import isclose from typing import Union @@ -868,38 +867,9 @@ def test_extern_adt_defn(): mod ) +def test_import_grad(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") if __name__ == "__main__": - test_comments() - test_int_literal() - test_float_literal() - test_bool_literal() - test_negative() - test_bin_op() - test_parens() - test_op_assoc() - test_let() - test_seq() - test_graph() - test_tuple() - test_func() - test_defn() - test_recursive_call() - test_ifelse() - test_call() - test_incomplete_type() - test_builtin_types() - test_tensor_type() - test_function_type() - test_tuple_type() - test_adt_defn() - test_empty_adt_defn() - test_multiple_cons_defn() - test_multiple_type_param_defn() - test_match() - test_adt_cons_expr() - test_duplicate_adt_defn() - test_duplicate_adt_cons() - test_duplicate_adt_cons_defn() - test_duplicate_global_var() - test_extern_adt_defn() + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index d13687fbec72..cca730311751 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -65,4 +65,4 @@ def test_cast_grad(): check_grad(fwd_func) if __name__ == "__main__": - pytest.main() + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index f690a186ea41..7ec2c8609a97 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -46,4 +46,4 @@ def test_max_grad(): if __name__ == "__main__": - pytest.main() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py index e38887829551..ce7b597d07f6 100644 --- a/tests/python/relay/test_pass_lambda_lift.py +++ b/tests/python/relay/test_pass_lambda_lift.py @@ -75,5 +75,4 @@ def test_recursive(): if __name__ == "__main__": - pytest.main() - + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index aed026996a21..f39dfdc4dcb6 100644 --- 
a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -555,4 +555,4 @@ def test_print_debug_callback(): if __name__ == "__main__": - pytest.main() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 33816344f562..5774b93d0c5e 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -114,4 +114,4 @@ def get_mod(): if __name__ == '__main__': - pytest.main() + pytest.main([__file__]) From f68955e4d462f75a9f4dcc317cec74340caca144 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 27 Feb 2020 16:31:14 -0800 Subject: [PATCH 02/30] gradient.rly --- python/tvm/relay/std/gradient.rly | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly index c4a8a87057e9..72fcc5c2fcdd 100644 --- a/python/tvm/relay/std/gradient.rly +++ b/python/tvm/relay/std/gradient.rly @@ -28,9 +28,18 @@ type GradCell[T] { Zero(fn() -> T) } +def @FromGradCell[T](%g: GradCell[T]) { + match (%g) { + Raw(%x) => %x, + One(%x) => %x(), + Zero(%x) => %x() + } +} + def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] { match ((%l, %r)) { (Zero(_), _) => %r, - (_, Zero(_)) => %l + (_, Zero(_)) => %l, + _ => %add(@FromGradCell(%l), @FromGradCell(%r)) } } From dfb00eed2800ab682810e01de59d06d63aee3b95 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 27 Feb 2020 16:46:36 -0800 Subject: [PATCH 03/30] fix --- python/tvm/relay/std/gradient.rly | 4 ++-- src/ir/error.cc | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly index 72fcc5c2fcdd..4cb09bc0e2f8 100644 --- a/python/tvm/relay/std/gradient.rly +++ b/python/tvm/relay/std/gradient.rly @@ -28,7 +28,7 @@ type GradCell[T] { Zero(fn() -> T) } -def @FromGradCell[T](%g: GradCell[T]) { +def @FromGradCell[T](%g: GradCell[T]) -> T { match (%g) { Raw(%x) => %x, One(%x) => %x(), @@ -40,6 +40,6 @@ def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> Gr match ((%l, %r)) { (Zero(_), _) => %r, (_, Zero(_)) => %l, - _ => %add(@FromGradCell(%l), @FromGradCell(%r)) + _ => Raw(%add(@FromGradCell(%l), @FromGradCell(%r))) } } diff --git a/src/ir/error.cc b/src/ir/error.cc index 9d498288d2ba..96c953e2d767 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -115,7 +115,9 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { auto it = err_map.find(expr); if (it != err_map.end()) { CHECK_NE(it->second.size(), 0); - return it->second; + std::string ret = it->second; + err_map.erase(it); + return ret; } else { return std::string(""); } @@ -128,6 +130,12 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { rang::setControlMode(rang::control::Auto); } + for (const auto& err_map : error_maps) { + for (const auto& str : err_map.second) { + annotated_prog << str.second << std::endl; + } + } + // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. 
LOG(FATAL) << annotated_prog.str() << std::endl; From 75ca326da3b65a6aac541ee6df02f916848bba79 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 3 Mar 2020 13:29:39 -0800 Subject: [PATCH 04/30] NOT WORKING: gradient cell pass --- include/tvm/relay/transform.h | 2 + python/tvm/relay/std/gradient.rly | 10 ++ python/tvm/relay/transform/transform.py | 12 +++ src/relay/pass/gradient_cell.cc | 126 ++++++++++++++++++++++++ 4 files changed, 150 insertions(+) create mode 100644 src/relay/pass/gradient_cell.cc diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d5626c80a6de..23358f0e2b4c 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -77,6 +77,8 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< */ TVM_DLL Pass DeadCodeElimination(bool inline_once = false); +TVM_DLL Pass GradientCell(); + /*! * \brief Fold constant expressions. * diff --git a/python/tvm/relay/std/gradient.rly b/python/tvm/relay/std/gradient.rly index 4cb09bc0e2f8..ed81e4b2d454 100644 --- a/python/tvm/relay/std/gradient.rly +++ b/python/tvm/relay/std/gradient.rly @@ -36,6 +36,16 @@ def @FromGradCell[T](%g: GradCell[T]) -> T { } } +def @MultiplyGradCell[T](%multiply: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] { + match((%l, %r)) { + (Zero(_), _) => %l, + (_, Zero(_)) => %r, + (One(_), _) => %r, + (_, One(_)) => %l, + _ => Raw(%multiply(@FromGradCell(%l), @FromGradCell(%r))) + } +} + def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] { match ((%l, %r)) { (Zero(_), _) => %r, diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 43a116e64e5b..456df32531cc 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -219,6 +219,18 @@ def DeadCodeElimination(inline_once=False): """ return _ffi_api.DeadCodeElimination(inline_once) +def GradientCell(): + """Condense tensors with all 0s or 1s + + Parameters + ---------- + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that condenses tensors with all 0s or 1s + """ + return _transform.GradientCell() def FoldConstant(): """Fold the constant expressions in a Relay program. diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc new file mode 100644 index 000000000000..604a56dea009 --- /dev/null +++ b/src/relay/pass/gradient_cell.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * + * \file gradient_node.cc + * + * \brief Convert all tensors to a Gradient Cell + * + * The algorithm is implemented by two visitor: + * CalcDep turn an expr into a dependency graph of expr, + * GenLet turn the dependency graph into a let list, taking only the used value. + */ + +#include +#include +#include +#include +#include "let_list.h" + +namespace tvm { +namespace relay { + +class GradientCellTransform: public ExprMutator, public TypeMutator { + public: + explicit GradientCellTransform(IRModule module): + module_(module) + {} + + Expr VisitExpr_(const CallNode* call_node) final { + if (auto* op = (call_node->op).as()) { + if (op->name.compare("add") == 0) { + const BaseFunc addFunc = module_->Lookup("AddGradCell"); + tvm::Array args; + + args.push_back(Op::Get("add")); + for (Expr expr: call_node->args) { + args.push_back(expr); + } + + return CallNode::make(addFunc, args); + } else if (op->name.compare("multiply") == 0) { + const BaseFunc multFunc = module_->Lookup("MultiplyGradCell"); + tvm::Array args; + + args.push_back(Op::Get("multiply")); + for (Expr expr: call_node->args) { + args.push_back(expr); + } + + return CallNode::make(multFunc, args); + } + const BaseFunc fromFunc = module_->Lookup("FromGradCell"); + GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); + tvm::Array args; + // use FromGradCell to convert args to Tensor + for (Expr expr: call_node->args) { + tvm::Array fromGradArgs; + fromGradArgs.push_back(expr); + args.push_back(CallNode::make(fromFunc, fromGradArgs)); + } + + return CallNode::make(call_node->op, args); + } + + return GetRef(call_node); + } + + Type VisitType(const Type& t) final { + std::cout << "visittype called" << std::endl; + return TypeMutator::VisitType(t); + } + + Type VisitType_(const TensorTypeNode* op) { + std::cout << "TypeTensor" << std::endl; + GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); + tvm::Array args; + args.push_back(GetRef(op)); + + return TypeCall(gradCell, args); + } + + private: + // Module + IRModule module_; + + // memo which Expr visited + std::unordered_set visited_; +}; + +Expr GradientCell(const Expr& e, IRModule mod) { + return GradientCellTransform(mod).Mutate(e); +} + +namespace transform { +Pass GradientCell() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(GradientCell(f, m)); + }; + return CreateFunctionPass(pass_func, 2, "GradientCell", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.GradientCell") +.set_body_typed(GradientCell); + +} //namespace transform + +} //namespace relay +} //namespace tvm \ No newline at end of file From 7113720001c13c4d14256eb71805317a89b24581 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 3 Mar 2020 13:33:43 -0800 Subject: [PATCH 05/30] test gradient pass --- tests/python/relay/test_pass_gradient_cell.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/python/relay/test_pass_gradient_cell.py diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py new file mode 100644 index 000000000000..dbbe5bee7612 --- /dev/null +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np + +import tvm +from tvm import te +from tvm import relay +from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal +from tvm.relay import create_executor, transform +from tvm.relay.build_module import optimize +from tvm.relay.transform import GradientCell +from tvm.relay.testing import rand, run_infer_type +from tvm.relay.op import add, multiply +from tvm.relay.prelude import Prelude, TensorArrayOps + +def test_zero_tensor(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + + y = relay.Function([x], multiply(x,x)) + mod["main"] = y + + mod = transform.GradientCell()(mod) + + + # mod = transform.PrintIR(True)(mod) + + print("---------------------------") + + # gradcell = mod.get_global_type_var("GradCell")(t) + # x_np = np.zeros(shape, dtype) + + # gradcell = mod.get_global_type_var("GradCell") + # y = tvm.relay.TypeCall(gradcell, [t]) + # + # addcell = mod.get_global_var("AddGradCell") + # fromcell = mod.get_global_var("FromGradCell") + + # mod, params = optimize(mod, target="llvm", params={"x": x_nd}) + + #ex = create_executor(mod=mod) + #a = ex.evaluate(addFunc)(x_nd) + + + # mod = transform.InferType()(mod) + print("hi") + +if __name__ == "__main__": + test_zero_tensor() + From 16d63d7224c192b0be3d1d4b41fd9e2b6107b147 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 6 Mar 2020 00:07:12 -0800 Subject: [PATCH 06/30] fixed basic call ops --- src/ir/type_functor.cc | 1 + src/relay/pass/gradient_cell.cc | 59 +++++++++----- tests/python/relay/test_pass_gradient_cell.py | 80 ++++++++++++++----- 3 files changed, 100 insertions(+), 40 deletions(-) diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index cbd3538b066c..f60583ccc6f9 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -151,6 +151,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { Array new_args = MutateArray(op->arg_types); changed = changed || !new_args.same_as(op->arg_types); + CHECK(new_args.size() == op->arg_types.size()); Type new_ret_type = VisitType(op->ret_type); changed = changed || !new_ret_type.same_as(op->ret_type); diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 604a56dea009..c67199717fa8 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -43,52 +43,74 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { module_(module) {} + Expr VisitExpr_(const ConstantNode* op) final { + GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); + Constructor toGradCell = Constructor("Raw", {op->checked_type()}, gradCellType); + + return CallNode::make(toGradCell, {GetRef(op)}); + } + Expr VisitExpr_(const CallNode* call_node) final { if (auto* op = (call_node->op).as()) { - if (op->name.compare("add") == 0) { - const BaseFunc addFunc = module_->Lookup("AddGradCell"); + 
if (op->name.compare("add") == 0 && call_node->args.size() == 2 && + AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + const auto addFunc = module_->GetGlobalVar("AddGradCell"); tvm::Array args; - args.push_back(Op::Get("add")); + Type paramType = call_node->args[0]->checked_type(); + + tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; + Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); + + Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, Array(), Attrs()); + + args.push_back(addTensorsFunc); for (Expr expr: call_node->args) { - args.push_back(expr); + args.push_back(VisitExpr(expr)); } - return CallNode::make(addFunc, args); - } else if (op->name.compare("multiply") == 0) { - const BaseFunc multFunc = module_->Lookup("MultiplyGradCell"); + } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && + AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); tvm::Array args; - args.push_back(Op::Get("multiply")); + Type paramType = call_node->args[0]->checked_type(); + + tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; + Expr callMultiply = CallNode::make(Op::Get("multiply"), {params[0], params[1]}); + + Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, Array(), Attrs()); + + args.push_back(multTensorsFunc); for (Expr expr: call_node->args) { - args.push_back(expr); + args.push_back(VisitExpr(expr)); } - return CallNode::make(multFunc, args); } - const BaseFunc fromFunc = module_->Lookup("FromGradCell"); + + const auto fromFunc = module_->GetGlobalVar("FromGradCell"); GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr: call_node->args) { - tvm::Array fromGradArgs; - fromGradArgs.push_back(expr); - args.push_back(CallNode::make(fromFunc, fromGradArgs)); + args.push_back(CallNode::make(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } + + const Expr tensorRes = CallNode::make(call_node->op, args); + + Constructor toGradCell = Constructor("Raw", {call_node->checked_type()}, gradCellType); - return CallNode::make(call_node->op, args); + return CallNode::make(toGradCell, {tensorRes}); } return GetRef(call_node); } Type VisitType(const Type& t) final { - std::cout << "visittype called" << std::endl; return TypeMutator::VisitType(t); } Type VisitType_(const TensorTypeNode* op) { - std::cout << "TypeTensor" << std::endl; GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; args.push_back(GetRef(op)); @@ -99,9 +121,6 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { private: // Module IRModule module_; - - // memo which Expr visited - std::unordered_set visited_; }; Expr GradientCell(const Expr& e, IRModule mod) { diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index dbbe5bee7612..f455bb0d8f8f 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -21,51 +21,91 @@ from tvm import relay from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal from tvm.relay import create_executor, transform -from tvm.relay.build_module import optimize -from tvm.relay.transform import GradientCell from tvm.relay.testing import rand, 
run_infer_type from tvm.relay.op import add, multiply from tvm.relay.prelude import Prelude, TensorArrayOps +import pytest -def test_zero_tensor(): +def grad_cell_type(mod, shape, dtype): + grad_type = mod.get_global_type_var("GradCell") + type_arg = relay.TensorType(shape, dtype) + return grad_type(type_arg) + +def test_add(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") shape = (10, 10) dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.Function([x], x+x) + + mod["main"] = y + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) +def test_mult(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (15, 15) + dtype = 'float32' t = relay.TensorType(shape, dtype) x = relay.var("x", t) + y = relay.Function([x], x * x) - y = relay.Function([x], multiply(x,x)) mod["main"] = y + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) +def test_tc(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (20, 20) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x1 = relay.var("x1", t) + x2 = relay.var("x2", t) + + y = relay.Function([x1, x2], (x1 - x2) * x2) + + mod["main"] = y mod = transform.GradientCell()(mod) + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type, new_type], new_type) - # mod = transform.PrintIR(True)(mod) +def test_reverse_ad(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") - print("---------------------------") + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) - # gradcell = mod.get_global_type_var("GradCell")(t) - # x_np = np.zeros(shape, dtype) + x = relay.var("x", t) - # gradcell = mod.get_global_type_var("GradCell") - # y = tvm.relay.TypeCall(gradcell, [t]) - # - # addcell = mod.get_global_var("AddGradCell") - # fromcell = mod.get_global_var("FromGradCell") + func = relay.Function([x], x) + func = run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) - # mod, params = optimize(mod, target="llvm", params={"x": x_nd}) + mod["main"] = back_func - #ex = create_executor(mod=mod) - #a = ex.evaluate(addFunc)(x_nd) + mod = transform.GradientCell()(mod) + # new_type = grad_cell_type(mod, shape, dtype) + # assert mod["main"].checked_type == relay.FuncType([new_type],) - # mod = transform.InferType()(mod) - print("hi") if __name__ == "__main__": - test_zero_tensor() - + pytest.main([__file__]) \ No newline at end of file From a0834505f1474c7d7f7880a484f2f62ef65ad018 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 17:20:32 -0700 Subject: [PATCH 07/30] more tests --- src/ir/module.cc | 4 +- src/relay/pass/gradient_cell.cc | 20 +- tests/python/relay/test_pass_gradient_cell.py | 204 +++++++++++++++++- 3 files changed, 210 insertions(+), 18 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 45f39d5ade88..6c14b914105e 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -122,7 +122,7 @@ relay::Function RunTypeCheck(const IRModule& mod, auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { - LOG(WARNING) + CHECK(false) << "There are free variables: " << fv << " in function: " @@ -130,7 +130,7 @@ relay::Function RunTypeCheck(const IRModule& mod, << std::endl; } 
if (ftv.size() != 0) { - LOG(WARNING) + CHECK(false) << "There are free type variables: " << ftv << " in function: " diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index c67199717fa8..228340d00e85 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -19,13 +19,11 @@ /*! * - * \file gradient_node.cc + * \file gradient_cell.cc * * \brief Convert all tensors to a Gradient Cell - * - * The algorithm is implemented by two visitor: - * CalcDep turn an expr into a dependency graph of expr, - * GenLet turn the dependency graph into a let list, taking only the used value. + * + * This algorithm is implemented by one visitor */ #include @@ -47,7 +45,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); Constructor toGradCell = Constructor("Raw", {op->checked_type()}, gradCellType); - return CallNode::make(toGradCell, {GetRef(op)}); + return CallNode::make(toGradCell, {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -68,7 +66,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr: call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(addFunc, args); + return CallNode::make(addFunc, args, Attrs(), {paramType}); } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); @@ -85,7 +83,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr: call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(multFunc, args); + return CallNode::make(multFunc, args, Attrs(), {paramType}); } const auto fromFunc = module_->GetGlobalVar("FromGradCell"); @@ -100,10 +98,10 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { Constructor toGradCell = Constructor("Raw", {call_node->checked_type()}, gradCellType); - return CallNode::make(toGradCell, {tensorRes}); + return CallNode::make(toGradCell, {tensorRes}, Attrs(), {call_node->checked_type()}); } - return GetRef(call_node); + return ExprMutator::VisitExpr_(call_node); } Type VisitType(const Type& t) final { @@ -121,6 +119,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { private: // Module IRModule module_; + + }; Expr GradientCell(const Expr& e, IRModule mod) { diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index f455bb0d8f8f..bc936d43143e 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -19,11 +19,12 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal from tvm.relay import create_executor, transform -from tvm.relay.testing import rand, run_infer_type +from tvm.relay.testing import rand, run_infer_type, check_grad +from tvm.relay.analysis import assert_alpha_equal from tvm.relay.op import add, multiply from tvm.relay.prelude import Prelude, TensorArrayOps +from tvm.testing import assert_allclose import pytest def grad_cell_type(mod, shape, dtype): @@ -48,6 +49,25 @@ def test_add(): new_type = grad_cell_type(mod, shape, dtype) assert mod["main"].checked_type == relay.FuncType([new_type], new_type) +def test_add_tuple(): + mod = tvm.IRModule() 
+ mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x1 = relay.var("x1", t) + x2 = relay.var("x2", t) + t1 = relay.Tuple([x1, x2]) + y = relay.Function([x1, x2], relay.TupleGetItem(t1,0) + relay.TupleGetItem(t1,1)) + + mod["main"] = y + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type, new_type], new_type) + def test_mult(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") @@ -84,7 +104,62 @@ def test_tc(): new_type = grad_cell_type(mod, shape, dtype) assert mod["main"].checked_type == relay.FuncType([new_type, new_type], new_type) -def test_reverse_ad(): +def test_ret_tuple(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.RefCreate(x) + func = relay.Function([x], relay.Tuple([x,y])) + func = run_infer_type(func) + + mod["main"] = func + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], relay.TupleType([new_type, relay.RefType(new_type)])) + +def test_broadcast(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape1 = (3, 4, 1) + shape2 = (1, 5) + dtype = 'float32' + t1 = relay.TensorType(shape1, dtype) + t2 = relay.TensorType(shape2, dtype) + + x1 = relay.var("x1", t1) + x2 = relay.var("x2", t2) + func = relay.Function([x1,x2], x1 + x2) + func = run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) + + mod["main"] = back_func + mod = transform.GradientCell()(mod) + + x1_np = rand(dtype, *shape1).asnumpy() + x2_np = rand(dtype, *shape2).asnumpy() + expected_forward = x1_np + x2_np + x1_type = grad_cell_type(mod, shape1, dtype) + x2_type = grad_cell_type(mod, shape2, dtype) + expected_forward_type = grad_cell_type(mod, expected_forward.shape, dtype) + assert mod["main"].checked_type == relay.FuncType([x1_type, x2_type], + relay.TupleType([expected_forward_type, relay.TupleType([x1_type, x2_type])])) + + ex = create_executor() + (forward), (grad_x1, grad_x2, ) = ex.evaluate(back_func)(x1_np, x2_np) + + assert_allclose(forward.asnumpy(), expected_forward) + assert_allclose(grad_x1.asnumpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True)) + assert_allclose(grad_x2.asnumpy(), np.ones_like(expected_forward).sum(axis=(0,1), keepdims=True).squeeze(axis=0)) + +def test_reverse_ad_identity(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") @@ -103,9 +178,126 @@ def test_reverse_ad(): mod = transform.GradientCell()(mod) - # new_type = grad_cell_type(mod, shape, dtype) - # assert mod["main"].checked_type == relay.FuncType([new_type],) + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], + relay.TupleType([new_type, relay.TupleType([new_type])])) + + ex = create_executor() + x = rand(dtype, *shape) + (forward), (grad,) = ex.evaluate(back_func)(x) + assert_allclose(forward.asnumpy(), x.asnumpy()) + assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) + +def test_multivar_reverse_ad(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.var("y", t) + + func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype))) + func = 
run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) + + mod["main"] = back_func + + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type, new_type], + relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + + ex = create_executor() + x = rand(dtype, *shape) + y = rand(dtype, *shape) + (forward), (grad_x, grad_y, ) = ex.evaluate(back_func)(x, y) + assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy()) + assert_allclose(grad_x.asnumpy(), y.asnumpy()) + assert_allclose(grad_y.asnumpy(), x.asnumpy()) + +def test_partial_eval_before(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.var("y", t) + + func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype))) + func = run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) + mod["main"] = back_func + + seq = transform.Sequential([ + transform.PartialEvaluate(), + transform.GradientCell(), + transform.DeadCodeElimination() + ]) + + mod = seq(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type, new_type], + relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + + ex = create_executor() + x = rand(dtype, *shape) + y = rand(dtype, *shape) + (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) + assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy()) + assert_allclose(grad_x.asnumpy(), y.asnumpy()) + assert_allclose(grad_y.asnumpy(), x.asnumpy()) + +def test_partial_eval_after_multivar(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.var("y", t) + + # func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype))) + func = relay.Function([x, y], x + y) + func = run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) + + mod["main"] = back_func + + mod = transform.GradientCell()(mod) + mod = transform.PartialEvaluate()(mod) + + # seq = transform.Sequential([ + # transform.GradientCell(), + # transform.PartialEvaluate(), + # ]) + # + # mod = seq(mod) + + new_type = grad_cell_type(mod, shape, dtype) + # assert mod["main"].checked_type == relay.FuncType([new_type, new_type], + # relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + # + # ex = create_executor() + # x = rand(dtype, *shape) + # y = rand(dtype, *shape) + # (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) + # assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy()) + # assert_allclose(grad_x.asnumpy(), y.asnumpy()) + # assert_allclose(grad_y.asnumpy(), x.asnumpy()) if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) From 67a6b01f1b439fa16e52c608d31f605df95d295e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 10 Mar 2020 18:12:47 -0700 Subject: [PATCH 08/30] fix bug --- src/relay/pass/gradient_cell.cc | 38 ++++++++++++++++------------ src/relay/transforms/partial_eval.cc | 5 +++- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 228340d00e85..3b62e7a0e144 100644 --- 
a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -39,18 +39,25 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { public: explicit GradientCellTransform(IRModule module): module_(module) - {} + { + TypeData gradCell = module_->LookupTypeDef("GradCell"); + for (Constructor c: gradCell->constructors) { + if (c->name_hint.compare("Raw") == 0) { + rawConstructor = c; + return; + } + } - Expr VisitExpr_(const ConstantNode* op) final { - GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); - Constructor toGradCell = Constructor("Raw", {op->checked_type()}, gradCellType); + CHECK(false) << "Raw Constructor missing from GradCell datatype"; + } - return CallNode::make(toGradCell, {GetRef(op)}, Attrs(), {op->checked_type()}); + Expr VisitExpr_(const ConstantNode* op) final { + return CallNode::make(rawConstructor, {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { if (auto* op = (call_node->op).as()) { - if (op->name.compare("add") == 0 && call_node->args.size() == 2 && + if (op->name.compare("add") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { const auto addFunc = module_->GetGlobalVar("AddGradCell"); tvm::Array args; @@ -59,7 +66,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); - + Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, Array(), Attrs()); args.push_back(addTensorsFunc); @@ -67,7 +74,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(addFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && + } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); tvm::Array args; @@ -76,7 +83,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; Expr callMultiply = CallNode::make(Op::Get("multiply"), {params[0], params[1]}); - + Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, Array(), Attrs()); args.push_back(multTensorsFunc); @@ -94,11 +101,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(CallNode::make(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } - const Expr tensorRes = CallNode::make(call_node->op, args); + const Expr tensorRes = CallNode::make(call_node->op, args); - Constructor toGradCell = Constructor("Raw", {call_node->checked_type()}, gradCellType); - - return CallNode::make(toGradCell, {tensorRes}, Attrs(), {call_node->checked_type()}); + return CallNode::make(rawConstructor, {tensorRes}, Attrs(), {call_node->checked_type()}); } return ExprMutator::VisitExpr_(call_node); @@ -112,14 +117,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; args.push_back(GetRef(op)); - return TypeCall(gradCell, args); } - + private: // Module IRModule module_; + // Raw Constructor of GradCell datatype + Constructor 
rawConstructor; }; @@ -142,4 +148,4 @@ TVM_REGISTER_GLOBAL("relay._transform.GradientCell") } //namespace transform } //namespace relay -} //namespace tvm \ No newline at end of file +} //namespace tvm diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index cd1f40c28767..6c82bdc30882 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -204,7 +204,9 @@ struct SConstructorNode : StaticNode { Constructor constructor; std::vector fields; SConstructorNode(const Constructor& constructor, const std::vector& fields) : - constructor(constructor), fields(fields) { } + constructor(constructor), fields(fields) { + CHECK_NE(constructor->tag, -1); + } static constexpr const char* _type_key = "relay.SConstructor"; TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; @@ -1000,6 +1002,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { + CHECK_NE(op->tag, -1); Constructor c = GetRef(op); Func f = [=](const PStatic& self, const std::vector& pv, From 949aa2b5cd0d8b4fcb3ace124c2e7e4f23125d81 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 21:09:23 -0700 Subject: [PATCH 09/30] transform calls to one ones_like zero zero_like --- src/relay/pass/gradient_cell.cc | 47 ++++--- tests/python/relay/test_pass_gradient_cell.py | 123 +++++++++++++++--- 2 files changed, 133 insertions(+), 37 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 3b62e7a0e144..42fee1f3e991 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -39,20 +39,10 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { public: explicit GradientCellTransform(IRModule module): module_(module) - { - TypeData gradCell = module_->LookupTypeDef("GradCell"); - for (Constructor c: gradCell->constructors) { - if (c->name_hint.compare("Raw") == 0) { - rawConstructor = c; - return; - } - } - - CHECK(false) << "Raw Constructor missing from GradCell datatype"; - } + {} Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(rawConstructor, {GetRef(op)}, Attrs(), {op->checked_type()}); + return CallNode::make(getGradCellConstructor("Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -91,10 +81,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(multFunc, args, Attrs(), {paramType}); - } + } else if (op->name.compare("ones") == 0) { + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("One"), {func}, Attrs(), {call_node->checked_type()}); + } else if (op->name.compare("zeros") == 0) { + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), {func}, Attrs(), {call_node->checked_type()}); + } const auto fromFunc = module_->GetGlobalVar("FromGradCell"); - GlobalTypeVar gradCellType = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr: call_node->args) { @@ -103,7 +98,14 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { const Expr tensorRes = CallNode::make(call_node->op, args); - return CallNode::make(rawConstructor, {tensorRes}, 
Attrs(), {call_node->checked_type()}); + if (op->name.compare("ones_like") == 0) { + Expr onesFunction = FunctionNode::make({}, tensorRes, {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("One"), {onesFunction}, Attrs(), {call_node->checked_type()}); + } else if (op->name.compare("zeros_like") == 0) { + Expr zerosFunction = FunctionNode::make({}, tensorRes, {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); + } + return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } return ExprMutator::VisitExpr_(call_node); @@ -123,10 +125,19 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { private: // Module IRModule module_; + // Constructors of gradCell datatype + std::unordered_map gradCellConstructors; + + Constructor getGradCellConstructor(std::string name_hint) { + TypeData gradCell = module_->LookupTypeDef("GradCell"); + for (Constructor c: gradCell->constructors) { + if (name_hint.compare(c->name_hint) == 0) { + return c; + } + } - // Raw Constructor of GradCell datatype - Constructor rawConstructor; - + CHECK(false) << "Constructor " << name_hint << "not found in GradCell datatype."; + } }; Expr GradientCell(const Expr& e, IRModule mod) { diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index bc936d43143e..842fafa15289 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -269,35 +269,120 @@ def test_partial_eval_after_multivar(): x = relay.var("x", t) y = relay.var("y", t) - # func = relay.Function([x, y], (x * y) * relay.const(np.ones(shape, dtype))) - func = relay.Function([x, y], x + y) + func = relay.Function([x, y], x * y) func = run_infer_type(func) back_func = transform.gradient(func) back_func = run_infer_type(back_func) mod["main"] = back_func + seq = transform.Sequential([ + transform.GradientCell(), + transform.PartialEvaluate(), + transform.DeadCodeElimination() + ]) + + mod = seq(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type, new_type], + relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + + ex = create_executor() + x = rand(dtype, *shape) + y = rand(dtype, *shape) + (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) + assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy()) + assert_allclose(grad_x.asnumpy(), y.asnumpy()) + assert_allclose(grad_y.asnumpy(), x.asnumpy()) + +def test_zeros(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.Function([x], x + relay.zeros(shape, dtype)) + + mod["main"] = y mod = transform.GradientCell()(mod) - mod = transform.PartialEvaluate()(mod) - # seq = transform.Sequential([ - # transform.GradientCell(), - # transform.PartialEvaluate(), - # ]) - # - # mod = seq(mod) + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + + ex = create_executor() + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy()) + +def test_ones(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = 
relay.var("x", t) + y = relay.Function([x], x + relay.ones(shape, dtype)) + + mod["main"] = y + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + + ex = create_executor() + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) + +def test_zeros(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.Function([x], x + relay.zeros_like(x)) + + mod["main"] = y + mod = transform.GradientCell()(mod) new_type = grad_cell_type(mod, shape, dtype) - # assert mod["main"].checked_type == relay.FuncType([new_type, new_type], - # relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) - # - # ex = create_executor() - # x = rand(dtype, *shape) - # y = rand(dtype, *shape) - # (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) - # assert_allclose(forward.asnumpy(), x.asnumpy() * y.asnumpy()) - # assert_allclose(grad_x.asnumpy(), y.asnumpy()) - # assert_allclose(grad_y.asnumpy(), x.asnumpy()) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + + ex = create_executor() + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy()) + +def test_ones_like(): + mod = tvm.IRModule() + mod.import_from_std("gradient.rly") + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + x = relay.var("x", t) + y = relay.Function([x], x + relay.ones_like(x)) + + mod["main"] = y + mod = transform.GradientCell()(mod) + + new_type = grad_cell_type(mod, shape, dtype) + assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + + ex = create_executor() + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) if __name__ == "__main__": pytest.main([__file__]) From ca15729e6152d41aeb2945e83ca4907d5e898c43 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 21:49:04 -0700 Subject: [PATCH 10/30] maintenance stuff --- include/tvm/relay/transform.h | 12 ++++++++++++ src/relay/pass/gradient_cell.cc | 16 +++++++++++++--- tests/python/relay/test_pass_gradient_cell.py | 6 +----- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 23358f0e2b4c..f977cf924c93 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -77,6 +77,18 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< */ TVM_DLL Pass DeadCodeElimination(bool inline_once = false); +/*! +* \brief Convert all expressions of TensorType into GradCell, +* an algebraic data type defined in gradient.rly. +* +* This will delay or decrease memory usage. All calls to +* ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, +* rather only instantiate if needed. It also defines + and * operation +* between GradCell types which can increase performance when using +* zero-filled or one-filled tensors, which is the case in gradient descent. +* +* \return the pass +*/ TVM_DLL Pass GradientCell(); /*! 
diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 42fee1f3e991..16885a096925 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -23,7 +23,16 @@ * * \brief Convert all tensors to a Gradient Cell * - * This algorithm is implemented by one visitor + * This pass delays or removes memory allocation by converting tensors into + * GradCell, an algebraic data type defined in gradient.rly + * + * This will delay or decrease memory usage. All calls to + * ones, ones_like, zeros, zeros_like will call the One or Zero constructor + * of GradCell, which will not instantiate in memory until needed. All other cases result + * in using the Raw constructor which means the tensor is instantiated in memory. + * + * It also overloads + and * operation which can increase performance when doing + * operations involving zero-filled or one-filled tensors. */ #include @@ -49,6 +58,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { if (auto* op = (call_node->op).as()) { if (op->name.compare("add") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + // case: "add" between two tensors of the same size const auto addFunc = module_->GetGlobalVar("AddGradCell"); tvm::Array args; @@ -66,6 +76,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { return CallNode::make(addFunc, args, Attrs(), {paramType}); } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + // case: "multiply" between two tensors of the same size const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); tvm::Array args; @@ -125,9 +136,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { private: // Module IRModule module_; - // Constructors of gradCell datatype - std::unordered_map gradCellConstructors; + // get constructor of GradCell with name Constructor getGradCellConstructor(std::string name_hint) { TypeData gradCell = module_->LookupTypeDef("GradCell"); for (Constructor c: gradCell->constructors) { diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index 842fafa15289..326505273663 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -17,13 +17,9 @@ import numpy as np import tvm -from tvm import te from tvm import relay from tvm.relay import create_executor, transform -from tvm.relay.testing import rand, run_infer_type, check_grad -from tvm.relay.analysis import assert_alpha_equal -from tvm.relay.op import add, multiply -from tvm.relay.prelude import Prelude, TensorArrayOps +from tvm.relay.testing import rand, run_infer_type from tvm.testing import assert_allclose import pytest From 182fbdcc5bb698153713123e6b71fa970fd42d2d Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 22:17:05 -0700 Subject: [PATCH 11/30] fix linting --- src/relay/pass/gradient_cell.cc | 80 +++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 16885a096925..2934a7913d1b 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -51,7 +51,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { {} Expr VisitExpr_(const ConstantNode* op) final { - return 
CallNode::make(getGradCellConstructor("Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); + return CallNode::make(getGradCellConstructor("Raw"), + {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -61,16 +62,17 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { // case: "add" between two tensors of the same size const auto addFunc = module_->GetGlobalVar("AddGradCell"); tvm::Array args; - + // create add function Type paramType = call_node->args[0]->checked_type(); - - tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; + tvm::Array params = {VarNode::make("lhs", paramType), + VarNode::make("rhs", paramType)}; Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); + Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, + Array(), Attrs()); - Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, Array(), Attrs()); - + // pass add function and tensors into arguments args.push_back(addTensorsFunc); - for (Expr expr: call_node->args) { + for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } return CallNode::make(addFunc, args, Attrs(), {paramType}); @@ -78,45 +80,57 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "multiply" between two tensors of the same size const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); + // create multiply function tvm::Array args; - Type paramType = call_node->args[0]->checked_type(); - - tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; - Expr callMultiply = CallNode::make(Op::Get("multiply"), {params[0], params[1]}); - - Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, Array(), Attrs()); - + tvm::Array params = {VarNode::make("lhs", paramType), + VarNode::make("rhs", paramType)}; + Expr callMultiply = CallNode::make(Op::Get("multiply"), + {params[0], params[1]}); + Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, + Array(), Attrs()); + + // pass multiply function and tensors into arguments args.push_back(multTensorsFunc); - for (Expr expr: call_node->args) { + for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } return CallNode::make(multFunc, args, Attrs(), {paramType}); } else if (op->name.compare("ones") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}, Attrs()); - return CallNode::make(getGradCellConstructor("One"), {func}, Attrs(), {call_node->checked_type()}); + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("One"), + {func}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}, Attrs()); - return CallNode::make(getGradCellConstructor("Zero"), {func}, Attrs(), {call_node->checked_type()}); - } + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), + {func}, Attrs(), {call_node->checked_type()}); + } const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; // use FromGradCell to convert args to 
Tensor - for (Expr expr: call_node->args) { - args.push_back(CallNode::make(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + for (Expr expr : call_node->args) { + args.push_back(CallNode::make(fromFunc, + {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } const Expr tensorRes = CallNode::make(call_node->op, args); if (op->name.compare("ones_like") == 0) { - Expr onesFunction = FunctionNode::make({}, tensorRes, {call_node->checked_type()}, Array(), Attrs()); - return CallNode::make(getGradCellConstructor("One"), {onesFunction}, Attrs(), {call_node->checked_type()}); + Expr onesFunction = FunctionNode::make({}, tensorRes, + {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("One"), + {onesFunction}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros_like") == 0) { - Expr zerosFunction = FunctionNode::make({}, tensorRes, {call_node->checked_type()}, Array(), Attrs()); - return CallNode::make(getGradCellConstructor("Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); + Expr zerosFunction = FunctionNode::make({}, tensorRes, + {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), + {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); + return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, + Attrs(), {call_node->checked_type()}); } return ExprMutator::VisitExpr_(call_node); @@ -133,14 +147,14 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { return TypeCall(gradCell, args); } - private: + private: // Module IRModule module_; - // get constructor of GradCell with name + // get constructor of GradCell with name Constructor getGradCellConstructor(std::string name_hint) { TypeData gradCell = module_->LookupTypeDef("GradCell"); - for (Constructor c: gradCell->constructors) { + for (Constructor c : gradCell->constructors) { if (name_hint.compare(c->name_hint) == 0) { return c; } @@ -166,7 +180,7 @@ Pass GradientCell() { TVM_REGISTER_GLOBAL("relay._transform.GradientCell") .set_body_typed(GradientCell); -} //namespace transform +} //namespace transform -} //namespace relay -} //namespace tvm +} //namespace relay +} //namespace tvm From 109b288ae5b9a2ff36e94f0de78fcb2662ffc496 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 22:26:11 -0700 Subject: [PATCH 12/30] linting --- src/relay/pass/gradient_cell.cc | 216 ++++++++++++++++---------------- 1 file changed, 108 insertions(+), 108 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 2934a7913d1b..fe0c79bceb84 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -45,123 +45,123 @@ namespace tvm { namespace relay { class GradientCellTransform: public ExprMutator, public TypeMutator { - public: - explicit GradientCellTransform(IRModule module): - module_(module) - {} - - Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(getGradCellConstructor("Raw"), - {GetRef(op)}, Attrs(), {op->checked_type()}); - } - - Expr VisitExpr_(const CallNode* call_node) final { - if (auto* op = (call_node->op).as()) { - if (op->name.compare("add") == 0 && call_node->args.size() == 2 && - AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { - // case: "add" between two tensors of the same size - const auto addFunc = 
module_->GetGlobalVar("AddGradCell"); - tvm::Array args; - // create add function - Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); - Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, - Array(), Attrs()); - - // pass add function and tensors into arguments - args.push_back(addTensorsFunc); - for (Expr expr : call_node->args) { - args.push_back(VisitExpr(expr)); - } - return CallNode::make(addFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && - AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { - // case: "multiply" between two tensors of the same size - const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); - // create multiply function - tvm::Array args; - Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callMultiply = CallNode::make(Op::Get("multiply"), - {params[0], params[1]}); - Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, - Array(), Attrs()); - - // pass multiply function and tensors into arguments - args.push_back(multTensorsFunc); - for (Expr expr : call_node->args) { - args.push_back(VisitExpr(expr)); - } - return CallNode::make(multFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("ones") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}, Attrs()); - return CallNode::make(getGradCellConstructor("One"), - {func}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}, Attrs()); - return CallNode::make(getGradCellConstructor("Zero"), - {func}, Attrs(), {call_node->checked_type()}); - } - - const auto fromFunc = module_->GetGlobalVar("FromGradCell"); + public: + explicit GradientCellTransform(IRModule module): + module_(module) + {} + + Expr VisitExpr_(const ConstantNode* op) final { + return CallNode::make(getGradCellConstructor("Raw"), + {GetRef(op)}, Attrs(), {op->checked_type()}); + } + + Expr VisitExpr_(const CallNode* call_node) final { + if (auto* op = (call_node->op).as()) { + if (op->name.compare("add") == 0 && call_node->args.size() == 2 && + AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + // case: "add" between two tensors of the same size + const auto addFunc = module_->GetGlobalVar("AddGradCell"); tvm::Array args; - // use FromGradCell to convert args to Tensor + // create add function + Type paramType = call_node->args[0]->checked_type(); + tvm::Array params = {VarNode::make("lhs", paramType), + VarNode::make("rhs", paramType)}; + Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); + Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, + Array(), Attrs()); + + // pass add function and tensors into arguments + args.push_back(addTensorsFunc); for (Expr expr : call_node->args) { - args.push_back(CallNode::make(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(VisitExpr(expr)); } + return CallNode::make(addFunc, args, Attrs(), {paramType}); + } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && + 
AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { + // case: "multiply" between two tensors of the same size + const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); + // create multiply function + tvm::Array args; + Type paramType = call_node->args[0]->checked_type(); + tvm::Array params = {VarNode::make("lhs", paramType), + VarNode::make("rhs", paramType)}; + Expr callMultiply = CallNode::make(Op::Get("multiply"), + {params[0], params[1]}); + Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, + Array(), Attrs()); - const Expr tensorRes = CallNode::make(call_node->op, args); - - if (op->name.compare("ones_like") == 0) { - Expr onesFunction = FunctionNode::make({}, tensorRes, - {call_node->checked_type()}, Array(), Attrs()); - return CallNode::make(getGradCellConstructor("One"), - {onesFunction}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros_like") == 0) { - Expr zerosFunction = FunctionNode::make({}, tensorRes, - {call_node->checked_type()}, Array(), Attrs()); - return CallNode::make(getGradCellConstructor("Zero"), - {zerosFunction}, Attrs(), {call_node->checked_type()}); + // pass multiply function and tensors into arguments + args.push_back(multTensorsFunc); + for (Expr expr : call_node->args) { + args.push_back(VisitExpr(expr)); } - return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, - Attrs(), {call_node->checked_type()}); + return CallNode::make(multFunc, args, Attrs(), {paramType}); + } else if (op->name.compare("ones") == 0) { + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("One"), + {func}, Attrs(), {call_node->checked_type()}); + } else if (op->name.compare("zeros") == 0) { + Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}, Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), + {func}, Attrs(), {call_node->checked_type()}); } - return ExprMutator::VisitExpr_(call_node); - } - - Type VisitType(const Type& t) final { - return TypeMutator::VisitType(t); - } + const auto fromFunc = module_->GetGlobalVar("FromGradCell"); + tvm::Array args; + // use FromGradCell to convert args to Tensor + for (Expr expr : call_node->args) { + args.push_back(CallNode::make(fromFunc, + {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + } - Type VisitType_(const TensorTypeNode* op) { - GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); - tvm::Array args; - args.push_back(GetRef(op)); - return TypeCall(gradCell, args); + const Expr tensorRes = CallNode::make(call_node->op, args); + + if (op->name.compare("ones_like") == 0) { + Expr onesFunction = FunctionNode::make({}, tensorRes, + {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("One"), + {onesFunction}, Attrs(), {call_node->checked_type()}); + } else if (op->name.compare("zeros_like") == 0) { + Expr zerosFunction = FunctionNode::make({}, tensorRes, + {call_node->checked_type()}, Array(), Attrs()); + return CallNode::make(getGradCellConstructor("Zero"), + {zerosFunction}, Attrs(), {call_node->checked_type()}); + } + return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, + Attrs(), {call_node->checked_type()}); } - private: - // Module - IRModule module_; - - // get constructor of GradCell with name - Constructor getGradCellConstructor(std::string name_hint) { - TypeData gradCell = 
module_->LookupTypeDef("GradCell"); - for (Constructor c : gradCell->constructors) { - if (name_hint.compare(c->name_hint) == 0) { - return c; - } + return ExprMutator::VisitExpr_(call_node); + } + + Type VisitType(const Type& t) final { + return TypeMutator::VisitType(t); + } + + Type VisitType_(const TensorTypeNode* op) { + GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); + tvm::Array args; + args.push_back(GetRef(op)); + return TypeCall(gradCell, args); + } + + private: + // Module + IRModule module_; + + // get constructor of GradCell with name + Constructor getGradCellConstructor(std::string name_hint) { + TypeData gradCell = module_->LookupTypeDef("GradCell"); + for (Constructor c : gradCell->constructors) { + if (name_hint.compare(c->name_hint) == 0) { + return c; } - - CHECK(false) << "Constructor " << name_hint << "not found in GradCell datatype."; } + + CHECK(false) << "Constructor " << name_hint << "not found in GradCell datatype."; + } }; Expr GradientCell(const Expr& e, IRModule mod) { @@ -180,7 +180,7 @@ Pass GradientCell() { TVM_REGISTER_GLOBAL("relay._transform.GradientCell") .set_body_typed(GradientCell); -} //namespace transform +} // namespace transform -} //namespace relay -} //namespace tvm +} // namespace relay +} // namespace tvm From 91157100fd6215983d81400219587baba6ef7cdc Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 22:41:13 -0700 Subject: [PATCH 13/30] linting --- src/relay/pass/gradient_cell.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index fe0c79bceb84..126cd305fc33 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -180,7 +180,7 @@ Pass GradientCell() { TVM_REGISTER_GLOBAL("relay._transform.GradientCell") .set_body_typed(GradientCell); -} // namespace transform +} // namespace transform -} // namespace relay -} // namespace tvm +} // namespace relay +} // namespace tvm From 8c7f4e8a45667aec4eda37dc2c420ce3ef26df2d Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 10 Mar 2020 23:40:27 -0700 Subject: [PATCH 14/30] throw default --- src/relay/pass/gradient_cell.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 126cd305fc33..5b4ec2159f07 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -160,7 +160,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } } - CHECK(false) << "Constructor " << name_hint << "not found in GradCell datatype."; + CHECK(false) << "Constructor " << name_hint << "not found in GradCell typedata."; + throw std::runtime_error("Constructor not found in GradCell typedata"); } }; From bbd0a45e9f462e7f9fe5ef4fb342e0c5ab712fda Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 21:28:38 -0700 Subject: [PATCH 15/30] remove unrelated changes --- src/ir/error.cc | 10 +--------- src/ir/module.cc | 4 ++-- src/ir/type_functor.cc | 1 - src/relay/transforms/gradient.cc | 10 +--------- src/relay/transforms/partial_eval.cc | 5 +---- tests/python/contrib/test_nnpack.py | 2 +- tests/python/integration/test_winograd_nnpack.py | 4 ++-- tests/python/relay/test_op_grad_level3.py | 2 +- tests/python/relay/test_op_grad_level4.py | 2 +- tests/python/relay/test_pass_lambda_lift.py | 3 ++- tests/python/relay/test_pass_manager.py | 2 +- .../python/relay/test_pass_remove_unused_functions.py | 2 +- 12 files changed, 14 insertions(+), 33 
deletions(-) diff --git a/src/ir/error.cc b/src/ir/error.cc index 96c953e2d767..9d498288d2ba 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -115,9 +115,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { auto it = err_map.find(expr); if (it != err_map.end()) { CHECK_NE(it->second.size(), 0); - std::string ret = it->second; - err_map.erase(it); - return ret; + return it->second; } else { return std::string(""); } @@ -130,12 +128,6 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { rang::setControlMode(rang::control::Auto); } - for (const auto& err_map : error_maps) { - for (const auto& str : err_map.second) { - annotated_prog << str.second << std::endl; - } - } - // Finally we report the error, currently we do so to LOG(FATAL), // it may be good to instead report it to std::cout. LOG(FATAL) << annotated_prog.str() << std::endl; diff --git a/src/ir/module.cc b/src/ir/module.cc index 6c14b914105e..45f39d5ade88 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -122,7 +122,7 @@ relay::Function RunTypeCheck(const IRModule& mod, auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { - CHECK(false) + LOG(WARNING) << "There are free variables: " << fv << " in function: " @@ -130,7 +130,7 @@ relay::Function RunTypeCheck(const IRModule& mod, << std::endl; } if (ftv.size() != 0) { - CHECK(false) + LOG(WARNING) << "There are free type variables: " << ftv << " in function: " diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index f60583ccc6f9..cbd3538b066c 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -151,7 +151,6 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { Array new_args = MutateArray(op->arg_types); changed = changed || !new_args.same_as(op->arg_types); - CHECK(new_args.size() == op->arg_types.size()); Type new_ret_type = VisitType(op->ret_type); changed = changed || !new_ret_type.same_as(op->ret_type); diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index aa6157f1acdb..eca85a3b9181 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -18,7 +18,7 @@ */ /*! - * \file gradient.cc + * \file ad.cc * \brief API for Automatic Differentiation for the Relay IR. */ #include @@ -92,14 +92,6 @@ Expr DeGlobal(const IRModule& mod, const Expr& e) { } } -std::string GradName(const Expr& e) { - if (const auto* x = e.as()) { - return x->name_hint + "_grad"; - } else { - return "temp_grad"; - } -} - /*! \brief A fragment of the program being built by the automatic differentation * pass. 
*/ diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 6c82bdc30882..cd1f40c28767 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -204,9 +204,7 @@ struct SConstructorNode : StaticNode { Constructor constructor; std::vector fields; SConstructorNode(const Constructor& constructor, const std::vector& fields) : - constructor(constructor), fields(fields) { - CHECK_NE(constructor->tag, -1); - } + constructor(constructor), fields(fields) { } static constexpr const char* _type_key = "relay.SConstructor"; TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; @@ -1002,7 +1000,6 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { - CHECK_NE(op->tag, -1); Constructor c = GetRef(op); Func f = [=](const PStatic& self, const std::vector& pv, diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py index 8c2197b94757..505199a55724 100644 --- a/tests/python/contrib/test_nnpack.py +++ b/tests/python/contrib/test_nnpack.py @@ -203,4 +203,4 @@ def verify(target="llvm", if __name__ == "__main__": - pytest.main([__file__]) + pytest.main() diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py index 536ca5d042ea..7dad2ca586d7 100644 --- a/tests/python/integration/test_winograd_nnpack.py +++ b/tests/python/integration/test_winograd_nnpack.py @@ -25,7 +25,6 @@ import topi.testing from topi.util import get_const_tuple from pytest import skip -import pytest def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, @@ -141,4 +140,5 @@ def test_conv2d_nchw(): if __name__ == "__main__": - pytest.main([__file__]) + import pytest + pytest.main() diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index cca730311751..d13687fbec72 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -65,4 +65,4 @@ def test_cast_grad(): check_grad(fwd_func) if __name__ == "__main__": - pytest.main([__file__]) + pytest.main() diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 7ec2c8609a97..f690a186ea41 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -46,4 +46,4 @@ def test_max_grad(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main() diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py index ce7b597d07f6..e38887829551 100644 --- a/tests/python/relay/test_pass_lambda_lift.py +++ b/tests/python/relay/test_pass_lambda_lift.py @@ -75,4 +75,5 @@ def test_recursive(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main() + diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index f39dfdc4dcb6..aed026996a21 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -555,4 +555,4 @@ def test_print_debug_callback(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main() diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 5774b93d0c5e..33816344f562 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ 
b/tests/python/relay/test_pass_remove_unused_functions.py @@ -114,4 +114,4 @@ def get_mod(): if __name__ == '__main__': - pytest.main([__file__]) + pytest.main() From 921a03cb42ae4eef20cb0fd17355fe1b98bcd3e8 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 21:38:03 -0700 Subject: [PATCH 16/30] import gradent.rly in pass --- python/tvm/relay/transform/transform.py | 4 +-- src/relay/pass/gradient_cell.cc | 4 ++- tests/python/relay/test_pass_gradient_cell.py | 34 ++++++------------- 3 files changed, 15 insertions(+), 27 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 456df32531cc..bd0f33d09928 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -220,7 +220,7 @@ def DeadCodeElimination(inline_once=False): return _ffi_api.DeadCodeElimination(inline_once) def GradientCell(): - """Condense tensors with all 0s or 1s + """Reduces memory usage of tensors with all 0s or 1s Parameters ---------- @@ -228,7 +228,7 @@ def GradientCell(): Returns ------- ret: tvm.relay.Pass - The registered pass that condenses tensors with all 0s or 1s + The registered pass that delays or reduces memory allocation """ return _transform.GradientCell() diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 5b4ec2159f07..7a2bc9cc95fa 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -48,7 +48,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { public: explicit GradientCellTransform(IRModule module): module_(module) - {} + { + module_->ImportFromStd("gradient.rly"); + } Expr VisitExpr_(const ConstantNode* op) final { return CallNode::make(getGradCellConstructor("Raw"), diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index 326505273663..46afd2457b20 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -30,7 +30,6 @@ def grad_cell_type(mod, shape, dtype): def test_add(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") shape = (10, 10) dtype = 'float32' @@ -47,7 +46,6 @@ def test_add(): def test_add_tuple(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") shape = (10, 10) dtype = 'float32' @@ -66,7 +64,6 @@ def test_add_tuple(): def test_mult(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") shape = (15, 15) dtype = 'float32' @@ -83,7 +80,6 @@ def test_mult(): def test_tc(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") shape = (20, 20) dtype = 'float32' @@ -102,8 +98,7 @@ def test_tc(): def test_ret_tuple(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -121,8 +116,7 @@ def test_ret_tuple(): def test_broadcast(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape1 = (3, 4, 1) shape2 = (1, 5) dtype = 'float32' @@ -157,8 +151,7 @@ def test_broadcast(): def test_reverse_ad_identity(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -186,8 +179,7 @@ def test_reverse_ad_identity(): def test_multivar_reverse_ad(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -218,8 +210,7 @@ def test_multivar_reverse_ad(): def test_partial_eval_before(): mod = tvm.IRModule() - 
mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -256,8 +247,7 @@ def test_partial_eval_before(): def test_partial_eval_after_multivar(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -294,8 +284,7 @@ def test_partial_eval_after_multivar(): def test_zeros(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -316,8 +305,7 @@ def test_zeros(): def test_ones(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -338,8 +326,7 @@ def test_ones(): def test_zeros(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) @@ -360,8 +347,7 @@ def test_zeros(): def test_ones_like(): mod = tvm.IRModule() - mod.import_from_std("gradient.rly") - + shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) From 02563fbf5f4274b221ef2a5089d5035f635bd815 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 21:38:55 -0700 Subject: [PATCH 17/30] comment --- include/tvm/relay/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index f977cf924c93..3f5be9694649 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -85,7 +85,7 @@ TVM_DLL Pass DeadCodeElimination(bool inline_once = false); * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, * rather only instantiate if needed. It also defines + and * operation * between GradCell types which can increase performance when using -* zero-filled or one-filled tensors, which is the case in gradient descent. +* zero-filled or one-filled tensors, which is the case in reverse mode ad. 
* * \return the pass */ From e81d0bd69d69123d66bda31efeb12bb680b8f92a Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 21:41:40 -0700 Subject: [PATCH 18/30] linting --- src/relay/pass/gradient_cell.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/pass/gradient_cell.cc index 7a2bc9cc95fa..3b5bf53fb1ea 100644 --- a/src/relay/pass/gradient_cell.cc +++ b/src/relay/pass/gradient_cell.cc @@ -47,8 +47,7 @@ namespace relay { class GradientCellTransform: public ExprMutator, public TypeMutator { public: explicit GradientCellTransform(IRModule module): - module_(module) - { + module_(module) { module_->ImportFromStd("gradient.rly"); } From 2a7968c4f420fa9291a97b1f52af9364a7215511 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 21:51:49 -0700 Subject: [PATCH 19/30] remove changes to test files --- tests/python/relay/test_ir_parser.py | 37 ++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 7ba0a6de2780..fbe521340930 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -18,6 +18,7 @@ from tvm import te from tvm import relay from tvm.relay.analysis import graph_equal, assert_graph_equal +from tvm.relay.analysis import alpha_equal, assert_alpha_equal import pytest from numpy import isclose from typing import Union @@ -866,10 +867,42 @@ def test_extern_adt_defn(): """, mod ) - def test_import_grad(): mod = tvm.IRModule() mod.import_from_std("gradient.rly") if __name__ == "__main__": - pytest.main([__file__]) + test_comments() + test_int_literal() + test_float_literal() + test_bool_literal() + test_negative() + test_bin_op() + test_parens() + test_op_assoc() + test_let() + test_seq() + test_graph() + test_tuple() + test_func() + test_defn() + test_recursive_call() + test_ifelse() + test_call() + test_incomplete_type() + test_builtin_types() + test_tensor_type() + test_function_type() + test_tuple_type() + test_adt_defn() + test_empty_adt_defn() + test_multiple_cons_defn() + test_multiple_type_param_defn() + test_match() + test_adt_cons_expr() + test_duplicate_adt_defn() + test_duplicate_adt_cons() + test_duplicate_adt_cons_defn() + test_duplicate_global_var() + test_extern_adt_defn() + test_import_grad() \ No newline at end of file From 4f504c1ff2cf419e413efcad47e3afe0b115b653 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 11 Mar 2020 23:56:59 -0700 Subject: [PATCH 20/30] move gradient_cell.cc to transforms --- src/relay/{pass => transforms}/gradient_cell.cc | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/relay/{pass => transforms}/gradient_cell.cc (100%) diff --git a/src/relay/pass/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc similarity index 100% rename from src/relay/pass/gradient_cell.cc rename to src/relay/transforms/gradient_cell.cc From 0f148611a1d9b5980fe819bdfa4dc86577559c7d Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 20 Mar 2020 19:29:19 -0700 Subject: [PATCH 21/30] revert change --- src/relay/transforms/gradient.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index eca85a3b9181..a3728e905922 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -18,7 +18,7 @@ */ /*! - * \file ad.cc + * \file gradient.cc * \brief API for Automatic Differentiation for the Relay IR. 
*/ #include From fda4fdfbb7f4a132e8ebddab4847e1385b16ab66 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 20 Mar 2020 20:04:46 -0700 Subject: [PATCH 22/30] update files with new commits --- python/tvm/relay/transform/transform.py | 2 +- src/relay/transforms/gradient_cell.cc | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index bd0f33d09928..3edc1bf39aa0 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -230,7 +230,7 @@ def GradientCell(): ret: tvm.relay.Pass The registered pass that delays or reduces memory allocation """ - return _transform.GradientCell() + return _ffi_api.GradientCell() def FoldConstant(): """Fold the constant expressions in a Relay program. diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index 3b5bf53fb1ea..d981ca9c113f 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -68,8 +68,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array params = {VarNode::make("lhs", paramType), VarNode::make("rhs", paramType)}; Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); - Expr addTensorsFunc = FunctionNode::make(params, callAdd, paramType, - Array(), Attrs()); + Expr addTensorsFunc = Function(params, callAdd, paramType, + Array()); // pass add function and tensors into arguments args.push_back(addTensorsFunc); @@ -88,8 +88,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { VarNode::make("rhs", paramType)}; Expr callMultiply = CallNode::make(Op::Get("multiply"), {params[0], params[1]}); - Expr multTensorsFunc = FunctionNode::make(params, callMultiply, paramType, - Array(), Attrs()); + Expr multTensorsFunc = Function(params, callMultiply, paramType, + Array()); // pass multiply function and tensors into arguments args.push_back(multTensorsFunc); @@ -98,13 +98,13 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } return CallNode::make(multFunc, args, Attrs(), {paramType}); } else if (op->name.compare("ones") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}, Attrs()); + Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}); return CallNode::make(getGradCellConstructor("One"), {func}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros") == 0) { - Expr func = FunctionNode::make({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}, Attrs()); + Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, + {call_node->checked_type()}, {}); return CallNode::make(getGradCellConstructor("Zero"), {func}, Attrs(), {call_node->checked_type()}); } @@ -120,13 +120,13 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { const Expr tensorRes = CallNode::make(call_node->op, args); if (op->name.compare("ones_like") == 0) { - Expr onesFunction = FunctionNode::make({}, tensorRes, - {call_node->checked_type()}, Array(), Attrs()); + Expr onesFunction = Function({}, tensorRes, + {call_node->checked_type()}, Array()); return CallNode::make(getGradCellConstructor("One"), {onesFunction}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros_like") == 0) { - Expr zerosFunction = FunctionNode::make({}, tensorRes, - {call_node->checked_type()}, 
Array(), Attrs()); + Expr zerosFunction = Function({}, tensorRes, + {call_node->checked_type()}, Array()); return CallNode::make(getGradCellConstructor("Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } From 88e6744c7828a55e6a95cc51f43862ed391cc63e Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 20 Mar 2020 20:44:06 -0700 Subject: [PATCH 23/30] type --- tests/python/relay/test_pass_gradient_cell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index 46afd2457b20..5fd1ac71fa11 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -324,7 +324,7 @@ def test_ones(): y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) -def test_zeros(): +def test_zeros_like(): mod = tvm.IRModule() shape = (10, 10) From c61485796fb1db943dc2523b32f8686f04f96ca9 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 22 Mar 2020 16:27:15 -0700 Subject: [PATCH 24/30] wrapper function to main outermost function type --- src/relay/transforms/gradient_cell.cc | 180 ++++++++++++++--- tests/python/relay/test_pass_gradient_cell.py | 187 ++++++++++-------- 2 files changed, 267 insertions(+), 100 deletions(-) diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index d981ca9c113f..d8c3aeaa103d 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -32,7 +32,29 @@ * in using the Raw constructor which means the tensor is instantiated in memory. * * It also overloads + and * operation which can increase performance when doing - * operations involving zero-filled or one-filled tensors. + * operations involving tensors with values of only 0 or 1. + * + * Note: this pass can only be used with functions where the input/output types are + * a combination of TupleTypes and TensorTypes + * + * Specify optimize 6 ops: + * - add + * - multiply + * - ones + * - ones_like + * - zeros + * - zeros_like + * + * This pass makes use of three visitor. The most important one visits the entire function, + * one is used for wrap inputs and one to unwrap outputs. + * + * For example: + * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32] + * + * After this pass + * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]] + * + * Thus, it is necessary to wrap this outer function so that the input/output types remain the same */ #include @@ -44,6 +66,98 @@ namespace tvm { namespace relay { +namespace GradientCellPass { // avoid polluting namespace + +/*! +* \brief Get constructor of GradCell TypeDef with name_hint +* +* module must have TypeDefinition of GradCell (defined in gradient.rly) +*/ +Constructor getGradCellConstructor(IRModule module, std::string name_hint) { + TypeData gradCell = module->LookupTypeDef("GradCell"); + for (Constructor c : gradCell->constructors) { + if (name_hint.compare(c->name_hint) == 0) { + return c; + } + } + + LOG(FATAL) << "Constructor " << name_hint << "not found in GradCell typedata."; + throw std::runtime_error("Constructor not found in GradCell typedata"); +} + +/*! 
+* \brief Visitor to wrap inputs +*/ +class InputVisitor: public ExprFunctor { + public: + explicit InputVisitor(IRModule module): module_(module) {} + + Expr wrapExpr(const Expr expr, const Type& type) { + if (type.as()) { + return CallNode::make(getGradCellConstructor(module_, "Raw"), + {expr}, Attrs(), {type}); + } else if (auto* type_anno = type.as()) { + tvm::Array fields; + for (int i = 0; i < type_anno->fields.size(); i++) { + const Type& t = type_anno->fields[i]; + fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + } + Expr tuple = TupleNode::make(fields); + return tuple; + } + + return expr; + } + + Expr VisitExpr_(const VarNode* op, const Type& t) final { + std::cout << op->type_annotation << std::endl; + return wrapExpr(GetRef(op), op->type_annotation); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return wrapExpr(GetRef(op), t); + } + private: + IRModule module_; +}; + +/*! +* \brief Visitor to unwrap output +*/ +class OutputVisitor: public ExprFunctor { + public: + explicit OutputVisitor(IRModule module): module_(module) {} + + Expr unwrapExpr(const Expr expr, const Type& type) { + if (auto* type_call = type.as()) { + if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { + return CallNode::make(module_->GetGlobalVar("FromGradCell"), {expr}); + } + return expr; + } else if (auto* type_anno = type.as()) { + tvm::Array fields; + for (int i = 0; i < type_anno->fields.size(); i++) { + const Type& t = type_anno->fields[i]; + fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + } + Expr tuple = TupleNode::make(fields); + return tuple; + } + + return expr; + } + + Expr VisitExpr_(const CallNode* op, const Type& t) final { + return unwrapExpr(GetRef(op), t); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return unwrapExpr(GetRef(op), t); + } + private: + IRModule module_; +}; + class GradientCellTransform: public ExprMutator, public TypeMutator { public: explicit GradientCellTransform(IRModule module): @@ -51,12 +165,40 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { module_->ImportFromStd("gradient.rly"); } + /*! 
+ * \brief apply GradientCell transformation and wrap function + * so that function type stays the same + * + * input/output types should only be a combination of TupleTypes and TensorTypes + */ + Expr transform(const Expr& e) { + auto* f = (e).as(); + auto* transformed = this->Mutate(e).as(); + + if (e.same_as(GetRef(transformed))) { + return GetRef(transformed); + } + + // wrap inputs of Tensor type using InputVisitor class + tvm::Array args; + for (Var var: f->params) { + Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); + args.push_back(wrappedInput); + } + Expr transformedExpr = CallNode::make(GetRef(transformed), args); + + // unwrap outputs of GradCell type into Tensor type using OutputVisitor class + Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); + return Function(f->params, tensorOutput, f->ret_type, Array()); + } + Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(getGradCellConstructor("Raw"), + return CallNode::make(getGradCellConstructor(module_, "Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { + // optimize operators if (auto* op = (call_node->op).as()) { if (op->name.compare("add") == 0 && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { @@ -98,17 +240,22 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } return CallNode::make(multFunc, args, Attrs(), {paramType}); } else if (op->name.compare("ones") == 0) { + // ones operator, use One constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor("One"), + return CallNode::make(getGradCellConstructor(module_, "One"), {func}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros") == 0) { + // zeros operator, use Zero constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor("Zero"), + return CallNode::make(getGradCellConstructor(module_, "Zero"), {func}, Attrs(), {call_node->checked_type()}); } + // handle other ops + zeros_like + ones_like + // we put zeros_like and ones_like here to make use of + // code converting the arguments of CallNode into Tensor const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor @@ -122,18 +269,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { if (op->name.compare("ones_like") == 0) { Expr onesFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor("One"), + return CallNode::make(getGradCellConstructor(module_, "One"), {onesFunction}, Attrs(), {call_node->checked_type()}); } else if (op->name.compare("zeros_like") == 0) { Expr zerosFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor("Zero"), + return CallNode::make(getGradCellConstructor(module_, "Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(getGradCellConstructor("Raw"), {tensorRes}, + return CallNode::make(getGradCellConstructor(module_, "Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } - + // call-> op is not a relay op return ExprMutator::VisitExpr_(call_node); } @@ -151,23 +298,12 @@ 
class GradientCellTransform: public ExprMutator, public TypeMutator { private: // Module IRModule module_; - - // get constructor of GradCell with name - Constructor getGradCellConstructor(std::string name_hint) { - TypeData gradCell = module_->LookupTypeDef("GradCell"); - for (Constructor c : gradCell->constructors) { - if (name_hint.compare(c->name_hint) == 0) { - return c; - } - } - - CHECK(false) << "Constructor " << name_hint << "not found in GradCell typedata."; - throw std::runtime_error("Constructor not found in GradCell typedata"); - } }; +} // namespace GradientCellPass + Expr GradientCell(const Expr& e, IRModule mod) { - return GradientCellTransform(mod).Mutate(e); + return GradientCellPass::GradientCellTransform(mod).transform(e); } namespace transform { diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_gradient_cell.py index 5fd1ac71fa11..2055771ba9eb 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_gradient_cell.py @@ -23,80 +23,98 @@ from tvm.testing import assert_allclose import pytest -def grad_cell_type(mod, shape, dtype): - grad_type = mod.get_global_type_var("GradCell") - type_arg = relay.TensorType(shape, dtype) - return grad_type(type_arg) - -def test_add(): +def test_tc(): + # test typechecks mod = tvm.IRModule() - shape = (10, 10) + shape = (20, 20) dtype = 'float32' t = relay.TensorType(shape, dtype) - x = relay.var("x", t) - y = relay.Function([x], x+x) + x1 = relay.var("x1", t) + x2 = relay.var("x2", t) + # f(x1,x2) = (x1-x2)*x2 + y = relay.Function([x1, x2], (x1 - x2) * x2) mod["main"] = y mod = transform.GradientCell()(mod) - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + # function input/output types should remain the same + assert mod["main"].checked_type == relay.FuncType([t, t], t) -def test_add_tuple(): +def test_add(): + # test simple add mod = tvm.IRModule() shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) - x1 = relay.var("x1", t) - x2 = relay.var("x2", t) - t1 = relay.Tuple([x1, x2]) - y = relay.Function([x1, x2], relay.TupleGetItem(t1,0) + relay.TupleGetItem(t1,1)) + x = relay.var("x", t) + # f(x) = x+x + y = relay.Function([x], x+x) mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type, new_type], new_type) + assert mod["main"].checked_type == relay.FuncType([t], t) -def test_mult(): + ex = create_executor(mod=mod) + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy()) + +def test_add_tuple(): + # test input tuple and add items mod = tvm.IRModule() - shape = (15, 15) + shape = (10, 10) dtype = 'float32' - t = relay.TensorType(shape, dtype) + tensor_type = relay.TensorType(shape, dtype) + t = relay.TupleType([tensor_type, tensor_type]) x = relay.var("x", t) - y = relay.Function([x], x * x) + # f((x1,x2)) = x1 + x2 + y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1)) mod["main"] = y mod = transform.GradientCell()(mod) + mod = transform.PrintIR(show_meta_data=True)(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + assert mod["main"].checked_type == relay.FuncType([t], tensor_type) -def test_tc(): + ex = create_executor(mod=mod) + x = (rand(dtype, *shape), rand(dtype, *shape)) + y = 
ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy()) + +def test_mult(): + # test simple add mod = tvm.IRModule() - shape = (20, 20) + shape = (15, 15) dtype = 'float32' t = relay.TensorType(shape, dtype) - x1 = relay.var("x1", t) - x2 = relay.var("x2", t) - - y = relay.Function([x1, x2], (x1 - x2) * x2) + x = relay.var("x", t) + # f(x) = x*x + y = relay.Function([x], x * x) mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] + + assert mod["main"].checked_type == relay.FuncType([t], t) - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type, new_type], new_type) + ex = create_executor(mod=mod) + x = rand(dtype, *shape) + y = ex.evaluate(y)(x) + assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy()) def test_ret_tuple(): + # test return tuple mod = tvm.IRModule() shape = (10, 10) @@ -104,17 +122,24 @@ def test_ret_tuple(): t = relay.TensorType(shape, dtype) x = relay.var("x", t) - y = relay.RefCreate(x) - func = relay.Function([x], relay.Tuple([x,y])) + # f(x) = (x,x) + func = relay.Function([x], relay.Tuple([x,x * relay.const(2.0)])) func = run_infer_type(func) mod["main"] = func mod = transform.GradientCell()(mod) + func = mod["main"] + + assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t])) - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], relay.TupleType([new_type, relay.RefType(new_type)])) + ex = create_executor(mod=mod) + x = rand(dtype, *shape) + y = ex.evaluate(func)(x) + assert_allclose(y[0].asnumpy(), x.asnumpy()) + assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0) def test_broadcast(): + # test broadcast add mod = tvm.IRModule() shape1 = (3, 4, 1) @@ -132,17 +157,17 @@ def test_broadcast(): mod["main"] = back_func mod = transform.GradientCell()(mod) + back_func = mod["main"] x1_np = rand(dtype, *shape1).asnumpy() x2_np = rand(dtype, *shape2).asnumpy() - expected_forward = x1_np + x2_np - x1_type = grad_cell_type(mod, shape1, dtype) - x2_type = grad_cell_type(mod, shape2, dtype) - expected_forward_type = grad_cell_type(mod, expected_forward.shape, dtype) - assert mod["main"].checked_type == relay.FuncType([x1_type, x2_type], - relay.TupleType([expected_forward_type, relay.TupleType([x1_type, x2_type])])) - - ex = create_executor() + expected_forward = x1_np + x2_np + + expected_forward_type = relay.TensorType(expected_forward.shape, dtype) + assert mod["main"].checked_type == relay.FuncType([t1, t2], + relay.TupleType([expected_forward_type, relay.TupleType([t1, t2])])) + + ex = create_executor(mod=mod) (forward), (grad_x1, grad_x2, ) = ex.evaluate(back_func)(x1_np, x2_np) assert_allclose(forward.asnumpy(), expected_forward) @@ -150,6 +175,8 @@ def test_broadcast(): assert_allclose(grad_x2.asnumpy(), np.ones_like(expected_forward).sum(axis=(0,1), keepdims=True).squeeze(axis=0)) def test_reverse_ad_identity(): + # test correctness after reverse mode ad + # of f(x) = x mod = tvm.IRModule() shape = (10, 10) @@ -164,20 +191,21 @@ def test_reverse_ad_identity(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + back_func = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], - relay.TupleType([new_type, relay.TupleType([new_type])])) + assert mod["main"].checked_type == relay.FuncType([t], + relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() + ex = 
create_executor(mod=mod) x = rand(dtype, *shape) (forward), (grad,) = ex.evaluate(back_func)(x) assert_allclose(forward.asnumpy(), x.asnumpy()) assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) def test_multivar_reverse_ad(): + # test correctness after reverse mode ad + # of multivariate function mod = tvm.IRModule() shape = (10, 10) @@ -193,14 +221,13 @@ def test_multivar_reverse_ad(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + back_func = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type, new_type], - relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + assert mod["main"].checked_type == relay.FuncType([t, t], + relay.TupleType([t, relay.TupleType([t, t])])) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) (forward), (grad_x, grad_y, ) = ex.evaluate(back_func)(x, y) @@ -208,7 +235,8 @@ def test_multivar_reverse_ad(): assert_allclose(grad_x.asnumpy(), y.asnumpy()) assert_allclose(grad_y.asnumpy(), x.asnumpy()) -def test_partial_eval_before(): +def test_after_partial_eval(): + # test GradientCell transformation after PartialEval mod = tvm.IRModule() shape = (10, 10) @@ -224,6 +252,7 @@ def test_partial_eval_before(): back_func = run_infer_type(back_func) mod["main"] = back_func + back_func = mod["main"] seq = transform.Sequential([ transform.PartialEvaluate(), @@ -233,11 +262,10 @@ def test_partial_eval_before(): mod = seq(mod) - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type, new_type], - relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + assert mod["main"].checked_type == relay.FuncType([t, t], + relay.TupleType([t, relay.TupleType([t, t])])) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) @@ -245,7 +273,8 @@ def test_partial_eval_before(): assert_allclose(grad_x.asnumpy(), y.asnumpy()) assert_allclose(grad_y.asnumpy(), x.asnumpy()) -def test_partial_eval_after_multivar(): +def test_before_partial_eval(): + # test GradientCell transformation before PartialEval mod = tvm.IRModule() shape = (10, 10) @@ -261,20 +290,18 @@ def test_partial_eval_after_multivar(): back_func = run_infer_type(back_func) mod["main"] = back_func - seq = transform.Sequential([ transform.GradientCell(), transform.PartialEvaluate(), transform.DeadCodeElimination() ]) - mod = seq(mod) + back_func = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type, new_type], - relay.TupleType([new_type, relay.TupleType([new_type, new_type])])) + assert mod["main"].checked_type == relay.FuncType([t, t], + relay.TupleType([t, relay.TupleType([t, t])])) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) (forward), (grad_x, grad_y,) = ex.evaluate(back_func)(x, y) @@ -283,6 +310,7 @@ def test_partial_eval_after_multivar(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_zeros(): + # test with zeros operator mod = tvm.IRModule() shape = (10, 10) @@ -294,16 +322,17 @@ def test_zeros(): mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + assert mod["main"].checked_type == 
relay.FuncType([t], t) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones(): + # test with ones operator mod = tvm.IRModule() shape = (10, 10) @@ -315,16 +344,17 @@ def test_ones(): mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) def test_zeros_like(): + # test with zeros_like operator mod = tvm.IRModule() shape = (10, 10) @@ -336,16 +366,17 @@ def test_zeros_like(): mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones_like(): + # test with ones_like operator mod = tvm.IRModule() shape = (10, 10) @@ -357,11 +388,11 @@ def test_ones_like(): mod["main"] = y mod = transform.GradientCell()(mod) + y = mod["main"] - new_type = grad_cell_type(mod, shape, dtype) - assert mod["main"].checked_type == relay.FuncType([new_type], new_type) + assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor() + ex = create_executor(mod=mod) x = rand(dtype, *shape) y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) From 09556812b8ed75f25536a6fc8a7e63c553be4d2d Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 22 Mar 2020 16:34:11 -0700 Subject: [PATCH 25/30] fix linting --- src/relay/transforms/gradient_cell.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index d8c3aeaa103d..b891e68c9e3c 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -66,7 +66,7 @@ namespace tvm { namespace relay { -namespace GradientCellPass { // avoid polluting namespace +namespace GradientCellPass { // avoid polluting namespace /*! 
* \brief Get constructor of GradCell TypeDef with name_hint @@ -178,10 +178,10 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { if (e.same_as(GetRef(transformed))) { return GetRef(transformed); } - + // wrap inputs of Tensor type using InputVisitor class tvm::Array args; - for (Var var: f->params) { + for (Var var : f->params) { Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); args.push_back(wrappedInput); } @@ -189,7 +189,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { // unwrap outputs of GradCell type into Tensor type using OutputVisitor class Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); - return Function(f->params, tensorOutput, f->ret_type, Array()); + return Function(f->params, tensorOutput, f->ret_type, Array()); } Expr VisitExpr_(const ConstantNode* op) final { @@ -254,7 +254,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } // handle other ops + zeros_like + ones_like - // we put zeros_like and ones_like here to make use of + // we put zeros_like and ones_like here to make use of // code converting the arguments of CallNode into Tensor const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; From 40e629c17a38882426b9ee643ddd2a9ef1e43e13 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 22 Mar 2020 16:44:20 -0700 Subject: [PATCH 26/30] fix unsigned and signed int comparison --- src/relay/transforms/gradient_cell.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index b891e68c9e3c..0a224908fd74 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -98,7 +98,7 @@ class InputVisitor: public ExprFunctor { {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; - for (int i = 0; i < type_anno->fields.size(); i++) { + for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); } @@ -136,7 +136,7 @@ class OutputVisitor: public ExprFunctor { return expr; } else if (auto* type_anno = type.as()) { tvm::Array fields; - for (int i = 0; i < type_anno->fields.size(); i++) { + for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); } From 4d4b350695235380eba3e07500d11f2a85326570 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 22 Mar 2020 18:15:49 -0700 Subject: [PATCH 27/30] review --- src/relay/transforms/gradient_cell.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index 0a224908fd74..b5504b3ae7ee 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -37,7 +37,7 @@ * Note: this pass can only be used with functions where the input/output types are * a combination of TupleTypes and TensorTypes * - * Specify optimize 6 ops: + * This pass optimizes 6 ops: * - add * - multiply * - ones @@ -66,8 +66,6 @@ namespace tvm { namespace relay { -namespace GradientCellPass { // avoid polluting namespace - /*! 
* \brief Get constructor of GradCell TypeDef with name_hint * @@ -300,10 +298,8 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { IRModule module_; }; -} // namespace GradientCellPass - Expr GradientCell(const Expr& e, IRModule mod) { - return GradientCellPass::GradientCellTransform(mod).transform(e); + return GradientCellTransform(mod).transform(e); } namespace transform { From 953791b2d22c9471e20f7eb143efba8d554b0852 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 23 Mar 2020 11:31:11 -0700 Subject: [PATCH 28/30] GetConstructor definition in module and change op comparison --- include/tvm/ir/module.h | 8 +++++ src/ir/module.cc | 12 ++++++++ src/relay/transforms/gradient_cell.cc | 44 +++++++++------------------ 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 1ee7c323336d..4613bec70633 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -162,6 +162,14 @@ class IRModuleNode : public Object { */ TVM_DLL Array GetGlobalTypeVars() const; + /*! + * \brief Find constructor of ADT using name + * \param adt name of the ADT the constructor belongs to + * \param cons name of the constructor + * \returns Constructor of ADT, error if not found + */ + TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + /*! * \brief Look up a global function by its variable. * \param var The global var to lookup. diff --git a/src/ir/module.cc b/src/ir/module.cc index 45f39d5ade88..a78a7525425c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } +Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { + TypeData typeDef = this->LookupTypeDef(adt); + for (Constructor c : typeDef->constructors) { + if (cons.compare(c->name_hint) == 0) { + return c; + } + } + + LOG(FATAL) << adt << " does not contain constructor " << cons; + throw std::runtime_error("Constructor Not Found."); +} + tvm::Array IRModuleNode::GetGlobalTypeVars() const { std::vector global_type_vars; for (const auto& pair : global_type_var_map_) { diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index b5504b3ae7ee..eb60176c1225 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -66,23 +66,6 @@ namespace tvm { namespace relay { -/*! -* \brief Get constructor of GradCell TypeDef with name_hint -* -* module must have TypeDefinition of GradCell (defined in gradient.rly) -*/ -Constructor getGradCellConstructor(IRModule module, std::string name_hint) { - TypeData gradCell = module->LookupTypeDef("GradCell"); - for (Constructor c : gradCell->constructors) { - if (name_hint.compare(c->name_hint) == 0) { - return c; - } - } - - LOG(FATAL) << "Constructor " << name_hint << "not found in GradCell typedata."; - throw std::runtime_error("Constructor not found in GradCell typedata"); -} - /*! 
* \brief Visitor to wrap inputs */ @@ -92,7 +75,7 @@ class InputVisitor: public ExprFunctor { Expr wrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return CallNode::make(getGradCellConstructor(module_, "Raw"), + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; @@ -191,14 +174,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(getGradCellConstructor(module_, "Raw"), + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { // optimize operators if (auto* op = (call_node->op).as()) { - if (op->name.compare("add") == 0 && call_node->args.size() == 2 && + Expr op_expr = GetRef(op); + if (op_expr == Op::Get("add") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "add" between two tensors of the same size const auto addFunc = module_->GetGlobalVar("AddGradCell"); @@ -217,7 +201,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(addFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && + } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "multiply" between two tensors of the same size const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); @@ -237,17 +221,17 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(multFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("ones") == 0) { + } else if (op_expr == Op::Get("ones")) { // ones operator, use One constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor(module_, "One"), + return CallNode::make(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros") == 0) { + } else if (op_expr == Op::Get("zeros")) { // zeros operator, use Zero constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor(module_, "Zero"), + return CallNode::make(module_->GetConstructor("GradCell", "Zero"), {func}, Attrs(), {call_node->checked_type()}); } @@ -264,18 +248,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { const Expr tensorRes = CallNode::make(call_node->op, args); - if (op->name.compare("ones_like") == 0) { + if (op_expr == Op::Get("ones_like")) { Expr onesFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor(module_, "One"), + return CallNode::make(module_->GetConstructor("GradCell", "One"), {onesFunction}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros_like") == 0) { + } else if (op_expr == Op::Get("zeros_like")) { Expr zerosFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor(module_, "Zero"), + return 
CallNode::make(module_->GetConstructor("GradCell", "Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(getGradCellConstructor(module_, "Raw"), {tensorRes}, + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } // call-> op is not a relay op From 3e18cdecefbf34a9422362b060a338df1f531f9b Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 23 Mar 2020 11:41:51 -0700 Subject: [PATCH 29/30] update node instantiations --- src/relay/transforms/gradient_cell.cc | 48 +++++++++++++-------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index eb60176c1225..2c21c751a0e7 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -75,15 +75,15 @@ class InputVisitor: public ExprFunctor { Expr wrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); } - Expr tuple = TupleNode::make(fields); + Expr tuple = Tuple(fields); return tuple; } @@ -112,16 +112,16 @@ class OutputVisitor: public ExprFunctor { Expr unwrapExpr(const Expr expr, const Type& type) { if (auto* type_call = type.as()) { if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - return CallNode::make(module_->GetGlobalVar("FromGradCell"), {expr}); + return Call(module_->GetGlobalVar("FromGradCell"), {expr}); } return expr; } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItemNode::make(expr, i), t)); + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); } - Expr tuple = TupleNode::make(fields); + Expr tuple = Tuple(fields); return tuple; } @@ -166,7 +166,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); args.push_back(wrappedInput); } - Expr transformedExpr = CallNode::make(GetRef(transformed), args); + Expr transformedExpr = Call(GetRef(transformed), args); // unwrap outputs of GradCell type into Tensor type using OutputVisitor class Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); @@ -174,7 +174,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), + return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } @@ -189,9 +189,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array args; // create add function Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callAdd = CallNode::make(Op::Get("add"), {params[0], params[1]}); + tvm::Array params = {Var("lhs", paramType), + Var("rhs", paramType)}; + Expr callAdd = Call(Op::Get("add"), 
{params[0], params[1]}); Expr addTensorsFunc = Function(params, callAdd, paramType, Array()); @@ -200,7 +200,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(addFunc, args, Attrs(), {paramType}); + return Call(addFunc, args, Attrs(), {paramType}); } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "multiply" between two tensors of the same size @@ -208,9 +208,9 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { // create multiply function tvm::Array args; Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {VarNode::make("lhs", paramType), - VarNode::make("rhs", paramType)}; - Expr callMultiply = CallNode::make(Op::Get("multiply"), + tvm::Array params = {Var("lhs", paramType), + Var("rhs", paramType)}; + Expr callMultiply = Call(Op::Get("multiply"), {params[0], params[1]}); Expr multTensorsFunc = Function(params, callMultiply, paramType, Array()); @@ -220,18 +220,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { for (Expr expr : call_node->args) { args.push_back(VisitExpr(expr)); } - return CallNode::make(multFunc, args, Attrs(), {paramType}); + return Call(multFunc, args, Attrs(), {paramType}); } else if (op_expr == Op::Get("ones")) { // ones operator, use One constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(module_->GetConstructor("GradCell", "One"), + return Call(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), {call_node->checked_type()}); } else if (op_expr == Op::Get("zeros")) { // zeros operator, use Zero constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(module_->GetConstructor("GradCell", "Zero"), + return Call(module_->GetConstructor("GradCell", "Zero"), {func}, Attrs(), {call_node->checked_type()}); } @@ -242,24 +242,24 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back(CallNode::make(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(Call(fromFunc, + {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } - const Expr tensorRes = CallNode::make(call_node->op, args); + const Expr tensorRes = Call(call_node->op, args); if (op_expr == Op::Get("ones_like")) { Expr onesFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(module_->GetConstructor("GradCell", "One"), + return Call(module_->GetConstructor("GradCell", "One"), {onesFunction}, Attrs(), {call_node->checked_type()}); } else if (op_expr == Op::Get("zeros_like")) { Expr zerosFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(module_->GetConstructor("GradCell", "Zero"), + return Call(module_->GetConstructor("GradCell", "Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, + return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } // call-> op is not a relay op From 4c15056f67eebe3ef06ec4361d20484626b971d5 Mon Sep 17 
00:00:00 2001 From: Andrew Liu Date: Mon, 23 Mar 2020 20:09:20 -0700 Subject: [PATCH 30/30] increase code readability --- include/tvm/relay/transform.h | 2 +- python/tvm/relay/transform/transform.py | 9 +- ...gradient_cell.cc => lazy_gradient_init.cc} | 235 +++++++++--------- ...ell.py => test_pass_lazy_gradient_init.py} | 72 +++--- 4 files changed, 163 insertions(+), 155 deletions(-) rename src/relay/transforms/{gradient_cell.cc => lazy_gradient_init.cc} (57%) rename tests/python/relay/{test_pass_gradient_cell.py => test_pass_lazy_gradient_init.py} (85%) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 3f5be9694649..deb084c65d54 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -89,7 +89,7 @@ TVM_DLL Pass DeadCodeElimination(bool inline_once = false); * * \return the pass */ -TVM_DLL Pass GradientCell(); +TVM_DLL Pass LazyGradientInit(); /*! * \brief Fold constant expressions. diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 3edc1bf39aa0..aa17c7f3de1c 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -219,8 +219,8 @@ def DeadCodeElimination(inline_once=False): """ return _ffi_api.DeadCodeElimination(inline_once) -def GradientCell(): - """Reduces memory usage of tensors with all 0s or 1s +def LazyGradientInit(): + """Reduces memory usage of gradient tensors Parameters ---------- @@ -228,9 +228,10 @@ def GradientCell(): Returns ------- ret: tvm.relay.Pass - The registered pass that delays or reduces memory allocation + A pass which delays and/or reduces memory allocation, + by lazily allocating 0 or one filled tensors. """ - return _ffi_api.GradientCell() + return _ffi_api.LazyGradientInit() def FoldConstant(): """Fold the constant expressions in a Relay program. diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/lazy_gradient_init.cc similarity index 57% rename from src/relay/transforms/gradient_cell.cc rename to src/relay/transforms/lazy_gradient_init.cc index 2c21c751a0e7..ba6ca05663bb 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -19,12 +19,14 @@ /*! * - * \file gradient_cell.cc + * \file lazy_gradient_init.cc * - * \brief Convert all tensors to a Gradient Cell + * \brief Lazily instantiate 0-filled or 1-filled tensors. + * This pass should be used after reverse-mode ad so that gradient tensors + * are not instantiated until after the forward pass. * * This pass delays or removes memory allocation by converting tensors into - * GradCell, an algebraic data type defined in gradient.rly + * GradCell, an algebraic data type defined in gradient.rly. * * This will delay or decrease memory usage. All calls to * ones, ones_like, zeros, zeros_like will call the One or Zero constructor @@ -67,13 +69,28 @@ namespace tvm { namespace relay { /*! 
-* \brief Visitor to wrap inputs +* \brief Visitor appropriately wraps tensors with Raw constructor +* +* Recursively looks at the type of the expression (TensorType or TupleType are only supported for now) +* and either call the GradCell constructor if TensorType +* or unfold and recursively visit if TupleType */ class InputVisitor: public ExprFunctor { public: explicit InputVisitor(IRModule module): module_(module) {} - Expr wrapExpr(const Expr expr, const Type& type) { + Expr VisitExpr_(const VarNode* op, const Type& t) final { + std::cout << op->type_annotation << std::endl; + return WrapExpr(GetRef(op), op->type_annotation); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return WrapExpr(GetRef(op), t); + } + private: + IRModule module_; + + Expr WrapExpr(const Expr expr, const Type& type) { if (type.as()) { return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); @@ -89,27 +106,30 @@ class InputVisitor: public ExprFunctor { return expr; } - - Expr VisitExpr_(const VarNode* op, const Type& t) final { - std::cout << op->type_annotation << std::endl; - return wrapExpr(GetRef(op), op->type_annotation); - } - - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return wrapExpr(GetRef(op), t); - } - private: - IRModule module_; }; /*! -* \brief Visitor to unwrap output +* \brief Visitor appropriately unwraps expressions with GradCell type into Tensors +* +* Recursively looks at the type of the expression +* and either use the FromGradCell function if TypeCall to GradCell +* or unfold and recursively visit if TupleType */ class OutputVisitor: public ExprFunctor { public: explicit OutputVisitor(IRModule module): module_(module) {} - Expr unwrapExpr(const Expr expr, const Type& type) { + Expr VisitExpr_(const CallNode* op, const Type& t) final { + return UnwrapExpr(GetRef(op), t); + } + + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return UnwrapExpr(GetRef(op), t); + } + private: + IRModule module_; + + Expr UnwrapExpr(const Expr expr, const Type& type) { if (auto* type_call = type.as()) { if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { return Call(module_->GetGlobalVar("FromGradCell"), {expr}); @@ -127,32 +147,22 @@ class OutputVisitor: public ExprFunctor { return expr; } - - Expr VisitExpr_(const CallNode* op, const Type& t) final { - return unwrapExpr(GetRef(op), t); - } - - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return unwrapExpr(GetRef(op), t); - } - private: - IRModule module_; }; -class GradientCellTransform: public ExprMutator, public TypeMutator { +class LazyGradientInitializer: public ExprMutator, public TypeMutator { public: - explicit GradientCellTransform(IRModule module): + explicit LazyGradientInitializer(IRModule module): module_(module) { module_->ImportFromStd("gradient.rly"); } /*! 
- * \brief apply GradientCell transformation and wrap function + * \brief apply LazyGradientInit transformation and wrap function * so that function type stays the same * * input/output types should only be a combination of TupleTypes and TensorTypes */ - Expr transform(const Expr& e) { + Expr Transform(const Expr& e) { auto* f = (e).as(); auto* transformed = this->Mutate(e).as(); @@ -179,90 +189,46 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const CallNode* call_node) final { - // optimize operators if (auto* op = (call_node->op).as()) { Expr op_expr = GetRef(op); - if (op_expr == Op::Get("add") && call_node->args.size() == 2 && - AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { - // case: "add" between two tensors of the same size - const auto addFunc = module_->GetGlobalVar("AddGradCell"); - tvm::Array args; - // create add function - Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {Var("lhs", paramType), - Var("rhs", paramType)}; - Expr callAdd = Call(Op::Get("add"), {params[0], params[1]}); - Expr addTensorsFunc = Function(params, callAdd, paramType, - Array()); - - // pass add function and tensors into arguments - args.push_back(addTensorsFunc); - for (Expr expr : call_node->args) { - args.push_back(VisitExpr(expr)); - } - return Call(addFunc, args, Attrs(), {paramType}); - } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 && - AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { - // case: "multiply" between two tensors of the same size - const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); - // create multiply function - tvm::Array args; - Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {Var("lhs", paramType), - Var("rhs", paramType)}; - Expr callMultiply = Call(Op::Get("multiply"), - {params[0], params[1]}); - Expr multTensorsFunc = Function(params, callMultiply, paramType, - Array()); - - // pass multiply function and tensors into arguments - args.push_back(multTensorsFunc); - for (Expr expr : call_node->args) { - args.push_back(VisitExpr(expr)); - } - return Call(multFunc, args, Attrs(), {paramType}); - } else if (op_expr == Op::Get("ones")) { - // ones operator, use One constructor of GradCell - Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}); - return Call(module_->GetConstructor("GradCell", "One"), - {func}, Attrs(), {call_node->checked_type()}); - } else if (op_expr == Op::Get("zeros")) { - // zeros operator, use Zero constructor of GradCell - Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}); - return Call(module_->GetConstructor("GradCell", "Zero"), - {func}, Attrs(), {call_node->checked_type()}); + + if (op_expr == Op::Get("add")) { + return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell")); } - // handle other ops + zeros_like + ones_like - // we put zeros_like and ones_like here to make use of - // code converting the arguments of CallNode into Tensor - const auto fromFunc = module_->GetGlobalVar("FromGradCell"); - tvm::Array args; - // use FromGradCell to convert args to Tensor - for (Expr expr : call_node->args) { - args.push_back(Call(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + if (op_expr == Op::Get("multiply")) { + return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell")); } - const Expr tensorRes = 
Call(call_node->op, args);
+
+      if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
+        // fn() -> T, function returns result of the operation
+        Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
+                                        {call_node->checked_type()}, {});
+        // call appropriate GradCell constructor
+        std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
+        return Call(module_->GetConstructor("GradCell", constructor_name),
+                              {func}, Attrs(), {call_node->checked_type()});
+      }
 
-      if (op_expr == Op::Get("ones_like")) {
-        Expr onesFunction = Function({}, tensorRes,
+      if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
+        // ones_like and zeros_like need TensorType input
+        Expr result = CallPrimitiveOp(call_node);
+        // fn() -> T, function returns result of operation
+        Expr func = Function({}, result,
                               {call_node->checked_type()}, Array());
+        // call appropriate GradCell constructor
+        std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
-        return Call(module_->GetConstructor("GradCell", "One"),
-                              {onesFunction}, Attrs(), {call_node->checked_type()});
-      } else if (op_expr == Op::Get("zeros_like")) {
-        Expr zerosFunction = Function({}, tensorRes,
-                              {call_node->checked_type()}, Array());
-        return Call(module_->GetConstructor("GradCell", "Zero"),
-                              {zerosFunction}, Attrs(), {call_node->checked_type()});
+        return Call(module_->GetConstructor("GradCell", constructor_name),
+              {func}, Attrs(), {call_node->checked_type()});
       }
-      return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes},
+
+      // handle all other ops
+      Expr result = CallPrimitiveOp(call_node);
+      // wrap result with Raw constructor
+      return Call(module_->GetConstructor("GradCell", "Raw"), {result},
                             Attrs(), {call_node->checked_type()});
     }
-    // call-> op is not a relay op
+    // not an op
     return ExprMutator::VisitExpr_(call_node);
   }
 
@@ -280,23 +246,70 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
  private:
   // Module
   IRModule module_;
+
+  /*!
+   * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
+   */
+  Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
+    // can only use overloaded functions if 2 arguments of same type
+    if (call_node->args.size() != 2 ||
+        !AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
+      Expr result = CallPrimitiveOp(call_node);
+      return Call(module_->GetConstructor("GradCell", "Raw"), {result},
+                            Attrs(), {call_node->checked_type()});
+    }
+
+    tvm::Array args;
+    // create "fallback" function for overloaded function
+    Type paramType = call_node->args[0]->checked_type();
+    tvm::Array params = {Var("lhs", paramType),
+                                    Var("rhs", paramType)};
+    // use primitive op in this case
+    Expr callOp = Call(call_node->op, {params[0], params[1]});
+    Expr func = Function(params, callOp, paramType,
+                                        Array());
+
+    // pass "fallback" function and tensors as arguments
+    args.push_back(func);
+    for (Expr expr : call_node->args) {
+      args.push_back(VisitExpr(expr));
+    }
+    // return new call to overloaded function
+    return Call(overloaded_op, args, Attrs(), {paramType});
+  }
+
+  /*!
+ * \brief Convert calls to other ops by converting args into TensorType + * \return call expr returning result of op + */ + Expr CallPrimitiveOp(const CallNode* call_node) { + const auto fromFunc = module_->GetGlobalVar("FromGradCell"); + tvm::Array args; + // use FromGradCell to convert args to Tensor + for (Expr expr : call_node->args) { + args.push_back(Call(fromFunc, + {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + } + // result of operation + return Call(call_node->op, args); + } }; -Expr GradientCell(const Expr& e, IRModule mod) { - return GradientCellTransform(mod).transform(e); +Expr LazyGradientInit(const Expr& e, IRModule mod) { + return LazyGradientInitializer(mod).Transform(e); } namespace transform { -Pass GradientCell() { +Pass LazyGradientInit() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(GradientCell(f, m)); + return Downcast(LazyGradientInit(f, m)); }; - return CreateFunctionPass(pass_func, 2, "GradientCell", {}); + return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); } -TVM_REGISTER_GLOBAL("relay._transform.GradientCell") -.set_body_typed(GradientCell); +TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit") +.set_body_typed(LazyGradientInit); } // namespace transform diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_lazy_gradient_init.py similarity index 85% rename from tests/python/relay/test_pass_gradient_cell.py rename to tests/python/relay/test_pass_lazy_gradient_init.py index 2055771ba9eb..f9c762e5f905 100644 --- a/tests/python/relay/test_pass_gradient_cell.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -24,7 +24,7 @@ import pytest def test_tc(): - # test typechecks + """Simple testcase, check that transformation typechecks.""" mod = tvm.IRModule() shape = (20, 20) @@ -37,13 +37,13 @@ def test_tc(): y = relay.Function([x1, x2], (x1 - x2) * x2) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) # function input/output types should remain the same assert mod["main"].checked_type == relay.FuncType([t, t], t) def test_add(): - # test simple add + """Simple add testcase. Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (10, 10) @@ -55,7 +55,7 @@ def test_add(): y = relay.Function([x], x+x) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -66,7 +66,7 @@ def test_add(): assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy()) def test_add_tuple(): - # test input tuple and add items + """Add elements of tuple. Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (10, 10) @@ -79,7 +79,7 @@ def test_add_tuple(): y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) mod = transform.PrintIR(show_meta_data=True)(mod) y = mod["main"] @@ -91,7 +91,7 @@ def test_add_tuple(): assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy()) def test_mult(): - # test simple add + """Simple multiplication testcase. 
Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (15, 15) @@ -103,7 +103,7 @@ def test_mult(): y = relay.Function([x], x * x) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -114,7 +114,7 @@ def test_mult(): assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy()) def test_ret_tuple(): - # test return tuple + """Test tuple return type. Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (10, 10) @@ -127,7 +127,7 @@ def test_ret_tuple(): func = run_infer_type(func) mod["main"] = func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t])) @@ -138,8 +138,8 @@ def test_ret_tuple(): assert_allclose(y[0].asnumpy(), x.asnumpy()) assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0) -def test_broadcast(): - # test broadcast add +def test_add_broadcast(): + """Test adding matrices of different size. Check types and semantic equivalence.""" mod = tvm.IRModule() shape1 = (3, 4, 1) @@ -152,30 +152,25 @@ def test_broadcast(): x2 = relay.var("x2", t2) func = relay.Function([x1,x2], x1 + x2) func = run_infer_type(func) - back_func = transform.gradient(func) - back_func = run_infer_type(back_func) - mod["main"] = back_func - mod = transform.GradientCell()(mod) - back_func = mod["main"] + mod["main"] = func + mod = transform.LazyGradientInit()(mod) + func = mod["main"] x1_np = rand(dtype, *shape1).asnumpy() x2_np = rand(dtype, *shape2).asnumpy() expected_forward = x1_np + x2_np expected_forward_type = relay.TensorType(expected_forward.shape, dtype) - assert mod["main"].checked_type == relay.FuncType([t1, t2], - relay.TupleType([expected_forward_type, relay.TupleType([t1, t2])])) + assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type) ex = create_executor(mod=mod) - (forward), (grad_x1, grad_x2, ) = ex.evaluate(back_func)(x1_np, x2_np) + forward = ex.evaluate(func)(x1_np, x2_np) assert_allclose(forward.asnumpy(), expected_forward) - assert_allclose(grad_x1.asnumpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True)) - assert_allclose(grad_x2.asnumpy(), np.ones_like(expected_forward).sum(axis=(0,1), keepdims=True).squeeze(axis=0)) def test_reverse_ad_identity(): - # test correctness after reverse mode ad + """Simple test with reverse mode ad.""" # of f(x) = x mod = tvm.IRModule() @@ -191,7 +186,7 @@ def test_reverse_ad_identity(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) back_func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], @@ -204,8 +199,7 @@ def test_reverse_ad_identity(): assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) def test_multivar_reverse_ad(): - # test correctness after reverse mode ad - # of multivariate function + """Simple test with multivariate reverse mode ad.""" mod = tvm.IRModule() shape = (10, 10) @@ -221,7 +215,7 @@ def test_multivar_reverse_ad(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) back_func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t, t], @@ -236,7 +230,7 @@ def test_multivar_reverse_ad(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_after_partial_eval(): - # test GradientCell transformation after 
PartialEval + """Test transformation following reverse mode ad and PartialEval""" mod = tvm.IRModule() shape = (10, 10) @@ -256,7 +250,7 @@ def test_after_partial_eval(): seq = transform.Sequential([ transform.PartialEvaluate(), - transform.GradientCell(), + transform.LazyGradientInit(), transform.DeadCodeElimination() ]) @@ -274,7 +268,7 @@ def test_after_partial_eval(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_before_partial_eval(): - # test GradientCell transformation before PartialEval + """Test transformation before PartialEval""" mod = tvm.IRModule() shape = (10, 10) @@ -291,7 +285,7 @@ def test_before_partial_eval(): mod["main"] = back_func seq = transform.Sequential([ - transform.GradientCell(), + transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination() ]) @@ -310,7 +304,7 @@ def test_before_partial_eval(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_zeros(): - # test with zeros operator + """Simple test using "zeros" op""" mod = tvm.IRModule() shape = (10, 10) @@ -321,7 +315,7 @@ def test_zeros(): y = relay.Function([x], x + relay.zeros(shape, dtype)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -332,7 +326,7 @@ def test_zeros(): assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones(): - # test with ones operator + """Simple test using "ones" op""" mod = tvm.IRModule() shape = (10, 10) @@ -343,7 +337,7 @@ def test_ones(): y = relay.Function([x], x + relay.ones(shape, dtype)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -354,7 +348,7 @@ def test_ones(): assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) def test_zeros_like(): - # test with zeros_like operator + """Simple test using "zeros_like" op""" mod = tvm.IRModule() shape = (10, 10) @@ -365,7 +359,7 @@ def test_zeros_like(): y = relay.Function([x], x + relay.zeros_like(x)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -376,7 +370,7 @@ def test_zeros_like(): assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones_like(): - # test with ones_like operator + """Simple test using "ones_like" op""" mod = tvm.IRModule() shape = (10, 10) @@ -387,7 +381,7 @@ def test_ones_like(): y = relay.Function([x], x + relay.ones_like(x)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t)
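
A minimal end-to-end sketch of how the renamed pass might be invoked, mirroring the calls exercised in the tests above (the shape, dtype, and tiny function below are illustrative only, not part of the test suite):

import numpy as np
import tvm
from tvm import relay
from tvm.relay import create_executor, transform

# Build a function whose ones_like would normally allocate a filled tensor eagerly.
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
mod["main"] = relay.Function([x], x + relay.ones_like(x))

# LazyGradientInit wraps tensors in the GradCell ADT (Raw/One/Zero) internally,
# while the function's input/output types stay the same.
mod = transform.LazyGradientInit()(mod)
assert mod["main"].checked_type == relay.FuncType([t], t)

# Evaluate as usual; the One thunk is only forced when FromGradCell needs a real tensor.
ex = create_executor(mod=mod)
res = ex.evaluate(mod["main"])(np.ones(shape, dtype=dtype))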