diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 1d2fa5472993f..18ec944f54f52 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -184,6 +184,26 @@ class VarNode : public ExprNode { RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); +/*! \brief Hash Var by it's id. + * Different VarNode might has same vid, and they are considered to be the same var in such case. + * Use VarHash to hash Var by id. + */ +struct VarHash { + size_t operator()(const Var& v) const { + return v->vid.hash(); + } +}; + +/*! \brief Compare Var by it's id. + * Different VarNode might has same vid, and they are considered to be the same var in such case. + * Use VarEqual to compare Var by id. + */ +struct VarEqual { + bool operator()(const Var& l, const Var& r) const { + return l->vid.get() == r->vid.get(); + } +}; + /*! * \brief Global variable that leaves in the top-level module. * This is used to enable recursive calls between function. @@ -521,7 +541,7 @@ RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); * rewriting pass such as layout or type transformation. * * Subclass TempExprNode allows us to pattern match on - * specific kind TempExpr and use them for expression rewriting. + * specific kind of TempExpr and use them for expression rewriting. * * TempExpr should only be used within a pass, */ @@ -539,6 +559,25 @@ class TempExprNode : public ExprNode { RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); +class Annotate; +class AnnotateNode : public ExprNode { + public: + Expr expr; + NodeRef annotation; + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("expr", &expr); + v->Visit("annotation", &annotation); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static Annotate make(Expr expr, NodeRef annotation); + + static constexpr const char* _type_key = "relay.AnnotateNode"; + TVM_DECLARE_NODE_TYPE_INFO(AnnotateNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Annotate, AnnotateNode, Expr); + // implementataions inline const Type& ExprNode::checked_type() const { CHECK(checked_type_.defined()) << "internal error: the type checker has " diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 3b179f8e53300..d3154d28bb272 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -116,6 +116,7 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AnnotateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -140,6 +141,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); + RELAY_EXPR_FUNCTOR_DISPATCH(AnnotateNode); return vtable; } }; @@ -170,6 +172,7 @@ class ExprVisitor void VisitExpr_(const RefWriteNode* op) override; void VisitExpr_(const ConstructorNode* op) override; void VisitExpr_(const MatchNode* op) override; + void VisitExpr_(const AnnotateNode* op) override; virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); @@ -212,6 +215,7 @@ class ExprMutator Expr VisitExpr_(const RefWriteNode* op) override; Expr VisitExpr_(const ConstructorNode* op) override; Expr VisitExpr_(const MatchNode* op) override; + Expr VisitExpr_(const AnnotateNode* op) override; /*! * \brief Used to visit the types inside of expressions. diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5a47b1d42ed31..f1e3942743dbb 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -74,7 +74,7 @@ def schedule_batch_matmul(attrs, outputs, target): with target: return topi.generic.schedule_batch_matmul(outputs) -reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_pattern("nn.batch_matmul", reg.OpPattern.OPAQUE) # conv2d diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3108bc2501fed..422163758a2f2 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -232,8 +232,7 @@ TVM_REGISTER_API("relay._make.Call") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; + p->stream << "CallNode(" << node->op << ")"; }); Let LetNode::make(Var var, Expr value, Expr body) { @@ -349,5 +348,17 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") *ret = temp->Realize(); }); +Annotate AnnotateNode::make(Expr expr, NodeRef annotation) { + NodePtr n = make_node(); + n->expr = std::move(expr); + n->annotation = std::move(annotation); + return Annotate(n); +} + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const AnnotateNode* node, tvm::IRPrinter* p) { + p->stream << "AnnotateNode(" << node->expr << ")"; + }); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index d0cd30adda29f..aaaf34d261a17 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -221,6 +221,10 @@ Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } Type ExprMutator::VisitType(const Type& t) { return t; } +Expr ExprMutator::VisitExpr_(const AnnotateNode* op) { + return AnnotateNode::make(VisitExpr(op->expr), op->annotation); +} + void ExprVisitor::VisitExpr(const Expr& expr) { auto it = visit_counter_.find(expr.get()); if (it != visit_counter_.end()) { @@ -315,6 +319,10 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) { } } +void ExprVisitor::VisitExpr_(const AnnotateNode* op) { + this->VisitExpr(op->expr); +} + void ExprVisitor::VisitClause(const Clause& op) { this->VisitPattern(op->lhs); this->VisitExpr(op->rhs); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index fb0d919b46c38..718bad63693b1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -113,6 +113,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeCall TypeCallNode::make(Type func, tvm::Array args) { + CHECK(func.as()); NodePtr n = make_node(); n->func = std::move(func); n->args = std::move(args); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d24431347f808..c2f24e2179bef 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -683,7 +683,7 @@ bool BatchMatmulRel(const Array& types, const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; - if (x->shape.size() != 3 || y->shape.size() != 3) return false; + CHECK (x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " << " x shape=" << x->shape diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f86156bdbddcd..0f83f2cf194f2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -729,9 +729,19 @@ bool TakeRel(const Array& types, // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + CHECK(types[0].as()) + << "must be tensor type or incomplete type"; + return false; + } + const auto* indices = types[1].as(); - CHECK(indices != nullptr); + if (indices == nullptr) { + CHECK(types[1].as()) + << "must be tensor type or incomplete type"; + return true; + } + const auto param = attrs.as(); CHECK(param != nullptr); diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index f6283d380176a..63ad11d2e1182 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1,22 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - /*! * Copyright (c) 2018 by Contributors * @@ -104,6 +85,7 @@ #include #include #include +#include "../ir/type_functor.h" #include "pass_util.h" #include "let_list.h" @@ -112,26 +94,7 @@ namespace relay { using namespace runtime; -/*! \brief Hash Var by it's id. - * Different VarNode might has same vid, and they are considered to be the same var in such case. - * Use VarHash to hash Var by id. - */ -struct VarHash { - size_t operator()(const Var& v) const { - return v->vid.hash(); - } -}; - -/*! \brief Compare Var by it's id. - * Different VarNode might has same vid, and they are considered to be the same var in such case. - * Use VarEqual to compare Var by id. - */ -struct VarEqual { - bool operator()(const Var& l, const Var& r) const { - return l->vid.get() == r->vid.get(); - } -}; - +Expr PostProcess(const Expr&); /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: @@ -150,10 +113,20 @@ class Static : public NodeRef { using ContainerType = StaticNode; }; +using Time = size_t; + struct PStaticNode : Node { + static Time time() { + static Time time_ = 0; + Time ret = time_; + time_++; + return ret; + } Static pstatic; // may be null Expr dynamic; - PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + Time created_time; + PStaticNode(const Static& pstatic, const Expr& dynamic) : + pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -261,7 +234,7 @@ class Environment { } ++rit; } - LOG(FATAL) << "Unknown Variable: " << v; + LOG(FATAL) << "Unknown Variable: " << v << v.as(); throw; } @@ -341,6 +314,7 @@ class Store { }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { + CHECK(stat.defined()); return PStatic(make_node(stat, dynamic)); } @@ -383,15 +357,61 @@ FInterpreter CPUInterpreter() { return CreateInterpreter(Module(nullptr), CPUContext(), target); } -class PartialEvaluator : public ExprFunctor, +bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + +using FuncId = size_t; + +struct WithFuncId; + +struct WithFuncIdNode : Node { + FuncId fid; + WithFuncIdNode(FuncId fid) : fid(fid) { } + static constexpr const char* _type_key = "relay.WithFuncId"; + TVM_DECLARE_NODE_TYPE_INFO(WithFuncIdNode, Node); +}; + +RELAY_DEFINE_NODE_REF(WithFuncId, WithFuncIdNode, NodeRef); + +Annotate MkWithFuncId(const Expr& expr, FuncId fid) { + return AnnotateNode::make(expr, WithFuncId(make_node(fid))); +} + +Expr StripWithFuncId(const Expr& e); + +Expr DeDup(const Expr& e); + +Function AsFunc(const Expr& e) { + if (e.as()) { + return Downcast(e); + } else if (const AnnotateNode* a = e.as()) { + CHECK(a->annotation.as()); + return AsFunc(a->expr); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } +} + +class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars) { + PartialEvaluator(const tvm::Array& free_vars, + const Module& mod) : + mod_(mod) { for (const Var& v : free_vars) { env_.Insert(v, NoStatic(v)); } } + size_t depth = 0; + PStatic VisitExpr(const Expr& e, LetList* ll) final { + PStatic ret = ExprFunctor::VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -421,7 +441,20 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - return NoStatic(GetRef(op)); + GlobalVar gv = GetRef(op); + if (gv_map_.count(gv) == 0) { + if (mod_.defined()) { + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); + } else { + gv_map_.insert({gv, NoStatic(gv)}); + } + } + return gv_map_.at(gv); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { @@ -501,19 +534,45 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - Function func = GetRef(op); + PStatic VisitExpr_(const AnnotateNode* op, LetList* ll) final { + CHECK(op->annotation.as()); + return VisitExpr(op->expr, ll); + } + + struct TimeFrame { + PartialEvaluator* pe_; + FuncId fid_; + std::vector