From 3797ab68b518026a7b2001a4c309c3ba3a75fba6 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 26 May 2022 14:14:53 -0700 Subject: [PATCH 1/5] [Relay] Odd's 'n ends changes to help Collage. - Complete the implementation of WithFields. (Unfortunately they appear to be without unit tests and I continue this tradition...) - InferTypeExpr for InferTypeLocal but return the expression rather than the type. - Remove python binding of InlineComposites since C++ impl was removed some time ago. - Make IndexedGraph more robust as stand-alone datastructure, and avoid unnecessary copies. This will become a fundamental datastructure in Collage rather than just a helper for DFPatternMatcher. - Extend IndexedGraph with a notion of 'basic block' on every dataflow node. Needed by Collage to avoid impossible partitions. --- include/tvm/relay/adt.h | 31 +- include/tvm/relay/expr.h | 173 ++---- include/tvm/relay/expr_functor.h | 2 + include/tvm/relay/function.h | 25 +- include/tvm/relay/transform.h | 5 + src/relay/ir/dataflow_matcher.cc | 84 +-- src/relay/ir/dataflow_matcher_impl.h | 20 +- src/relay/ir/expr.cc | 40 ++ src/relay/ir/indexed_graph.cc | 524 +++++++++++++------ src/relay/ir/indexed_graph.h | 284 ++++++++-- src/relay/op/dyn/tensor/transform.cc | 1 + src/relay/transforms/type_infer.cc | 5 + src/runtime/contrib/tensorrt/tensorrt_ops.cc | 2 +- tests/cpp/relay/ir/indexed_graph_test.cc | 184 +++++++ tests/python/relay/test_dataflow_pattern.py | 61 ++- 15 files changed, 1030 insertions(+), 411 deletions(-) create mode 100644 tests/cpp/relay/ir/indexed_graph_test.cc diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 31dec2204146..cdb8e52d2359 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -264,17 +264,9 @@ class Clause : public ObjectRef { }; /*! - * \brief Returns the clause with given properties. A null property denotes 'no change'. - * Returns clause if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param clause The clause to copy. - * \param opt_lhs The (optional) lhs for the copied clause. If none, ret_clause->lhs = clause->lhs. - * \param opt_rhs The (optional) rhs for the copied clause. If none, - * ret_clause->rhs = clause->rhs. - * \return If all - * properties are null or the same as the property in the input clause (i.e., opt_lhs is null or - * opt_lhs.value() == clause->lhs, etc.), then we return clause. Otherwise, we return a copy of - * clause with the different fields overwritten. (i.e., if opt_lhs.value() != clause->lhs, then - * ret_clause->lhs = opt_lhs.value()). + * \brief Returns \p clause with the given properties. A null property denotes 'no change'. + * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Clause WithFields(Clause clause, Optional opt_lhs = Optional(), Optional opt_rhs = Optional()); @@ -337,20 +329,9 @@ class Match : public Expr { }; /*! - * \brief Returns the match with given properties. A null property denotes 'no change'. - * Returns match if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param match The match to copy. - * \param opt_data The (optional) data for the copied match. If none, ret_match->data = match->data. - * \param opt_clauses The (optional) clauses for the copied match. If none, ret_match->clauses = - * match->clauses. - * \param opt_complete The (optional) complete for the copied match. If none, ret_match->complete = - * match->complete. 
- * \param opt_span The (optional) span for the copied match. If none, ret_match->span = match->span. - * \return If all properties are null or the same as the - * property in the input match (i.e., opt_clauses is null or opt_clauses.value() == match->clauses, - * etc.), then we return match. Otherwise, we return a copy of match with the different fields - * overwritten. (i.e., if opt_clauses.value() != match->clauses, then ret_match->clauses = - * opt_clauses.value()). + * \brief Returns \p match with the given properties. A null property denotes 'no change'. + * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Match WithFields(Match match, Optional opt_data = Optional(), Optional> opt_clauses = Optional>(), diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe570806922f..6b014c8478d8 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -39,6 +39,16 @@ #include "./type.h" namespace tvm { + +/*! + * \brief Returns \p global_var with the given properties. A null property denotes 'no change'. + * Returns \p global_var if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint = {}, + Optional opt_type = {}, Optional opt_virtual_device = {}, + Optional opt_span = {}); + namespace relay { using Expr = tvm::RelayExpr; @@ -97,8 +107,17 @@ class Constant : public Expr { TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; +/*! + * \brief Returns \p constant with the given properties. A null property denotes 'no change'. + * Returns \p constant if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Constant WithFields(Constant constant, Optional opt_data = {}, + Optional opt_virtual_device = {}, Optional opt_span = {}); + /*! \brief Tuple of multiple Exprs */ class Tuple; /*! \brief Tuple container */ @@ -149,15 +168,9 @@ class Tuple : public Expr { }; /*! - * \brief Returns the tuple with given properties. A null property denotes 'no change'. - * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param tuple The tuple to copy - * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = - * tuple->fields. - * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. - * \param opt_span The (optional) span for the copied tuple. If none, - * ret_tuple->span = tuple->span. + * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. + * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), Optional opt_virtual_device = Optional(), @@ -251,19 +264,9 @@ class Var : public Expr { }; /*! - * \brief Returns the var with given properties. A null property denotes 'no change'. - * Returns var if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param var The var to copy. - * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. - * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, - * ret_var->type_annotation = var->type_annotation. 
- * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. - * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. - * \return If all properties are null or the same as the property in the input var (i.e., opt_vid is - * null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of - * call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then - * ret_var->vid = opt_.value()). + * \brief Returns \p vor with the given properties. A null property denotes 'no change'. + * Returns \p var if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), @@ -374,22 +377,9 @@ class Call : public Expr { }; /*! - * \brief Returns the call with given properties. A null property denotes 'no change'. - * Returns call if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param call The call to copy. - * \param opt_op The (optional) op for the copied call. If none, ret_call->op = call->op. - * \param opt_args The (optional) args for the copied call. If none, ret_call->args = call->args. - * \param opt_attrs The (optional) attrs for the copied call. If none, ret_call->attrs = - * call->attrs. - * \param opt_type_args The (optional) type args for the copied call. If none, - * ret_call->type_args = call->type_args. - * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, - * ret_call->virtual_device = call->virtual_device. - * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. - * \return If all properties are null or the same as the property in the input call (i.e., opt_op is - * null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of - * call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then - * ret_call->op = opt_op.value()). + * \brief Returns \p call with the given properties. A null property denotes 'no change'. + * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), @@ -475,19 +465,9 @@ class Let : public Expr { }; /*! - * \brief Returns the let with given properties. A null property denotes 'no change'. - * Returns let if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param let The let to copy. - * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. - * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. - * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. - * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, - * ret_let->virtual_device = let->virtual_device. - * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. - * \return If all properties are null or the same as the property in the input let (i.e., opt_var is - * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of - * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then - * ret_let->var = opt_var.value()). 
+ * \brief Returns \p let with the given properties. A null property denotes 'no change'. + * Returns \p let if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), @@ -559,23 +539,9 @@ class If : public Expr { }; /*! - * \brief Returns the if_expr with given properties. A null property denotes 'no change'. - * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param if_expr The if expression to copy. - * \param opt_cond The (optional) cond for the copied if_expr. If none, ret_if->cond = - * if_expr->cond. - * \param opt_true_branch The (optional) true_branch for the copied if_expr. If none, - * ret_if->true_branch = ret_if->false_branch. - * \param opt_false_branch The (optional) false_branch - * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. - * \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none, - * ret_if->virtual_device = if_expr->virtual_device. - * \param opt_span The (optional) span for the copied if_expr. If none, - * ret_if->span = if_expr->span. - * \return If all properties are null or the same as the property in - * the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we - * return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten. - * (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()). + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), @@ -628,22 +594,9 @@ class TupleGetItem : public Expr { }; /*! - * \brief Returns the tuple_get_item with given properties. A null property denotes 'no change'. - * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param tuple_get_item The tuple_get_item to copy. - * \param opt_tuple The (optional) tuple for the copied tuple_get_item. If none, - * ret_tuple_get_item->tuple = tuple_get_item->tuple. - * \param opt_index The (optional) index for the copied tuple_get_item. If none, - * ret_tuple_get_item->index = tuple_get_item->index. - * \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item. - * If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device. - * \param opt_span The (optional) span for the copied tuple_get_item. If none, - * ret_tuple_get_item->span = tuple_get_item->span. - * \return If all properties are null or the same as the property in the input tuple_get_item - * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return - * tuple_get_item. Otherwise, we return a copy of tuple_get_item with the different fields - * overwritten. (i.e., if opt_tuple.value() != tuple_get_item->tuple, then - * ret_tuple_get_item->tuple = opt_tuple.value()). + * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. + * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + * fields. 
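 *
 * For example (illustrative only; 'new_tuple' stands for any replacement expression),
 * WithFields(tuple_get_item, new_tuple) returns \p tuple_get_item itself when 'new_tuple'
 * is the same object as tuple_get_item->tuple, and otherwise a copy in which only the
 * tuple field is replaced.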
*/ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), @@ -692,21 +645,9 @@ class RefCreate : public Expr { }; /*! - * \brief Returns the ref create with given properties. A null property denotes 'no change'. - * Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new + * \brief Returns \p ref_create with the given properties. A null property denotes 'no change'. + * Returns \p ref_crete if all properties are unchanged. Otherwise, returns a copy with the new * fields. - * \param ref_create The ref_create to copy. - * \param opt_value The (optional) value for the copied ref_create. If none, - * ret_ref_create->value = ref_create->value. - * \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none, - * ret_ref_create->virtual_device = ref_create->virtual_device. - * \param opt_span The (optional) span for the copied ref_create. If none, - * ret_ref_create->span = ref_create->span. - * \return If all properties are null or the same as the property in the input ref_create - * (i.e., opt_value is null or opt_value.value() == ref_create->value, etc.), then we return - * ref_create. Otherwise, we return a copy of ref_create with the different fields overwritten. - * (i.e., if opt_value.value() != ref_create->value, then - * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), Optional opt_virtual_device = Optional(), @@ -754,20 +695,9 @@ class RefRead : public Expr { }; /*! - * \brief Returns the ref read with given properties. A null property denotes 'no change'. - * Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param ref_read The ref_read to copy. - * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = - * ref_read->ref. - * \param opt_virtual_device - * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = - * ref_read->virtual_device. - * \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span = - * ref_read->span. - * \return If all properties are null or the same as the property in the input - * ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return - * ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e., - * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). + * \brief Returns \p ref_read with the given properties. A null property denotes 'no change'. + * Returns \p ref_read if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), Optional opt_virtual_device = Optional(), @@ -820,22 +750,9 @@ class RefWrite : public Expr { }; /*! - * \brief Returns the ref write with given properties. A null property denotes 'no change'. - * Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param ref_write The ref_write to copy. - * \param opt_ref The (optional) ref for the copied ref_write. If none, - * ret_ref_write->ref = ref_write->ref. - * \param opt_value The (optional) value for the copied ref_write. If none, - * ret_ref_write->value = ref_write->value. - * \param opt_virtual_device - * The (optional) virtual_device for the copied ref_write. 
If none, ret_ref_write->virtual_device = - * ref_write->virtual_device. - * \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span = - * ref_write->span. - * \return If all properties are null or the same as the property in the input ref_write (i.e., - * opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise, - * we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value() - * != ref_write->ref, then ret_ref_write->ref = opt_ref.value()). + * \brief Returns \p ref_write with the given properties. A null property denotes 'no change'. + * Returns \p ref_write if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index d8f575dfdf48..280a1f8a6c29 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -240,6 +240,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ explicit MixedModeVisitor(int visit_limit = 1); + using ExprVisitor::VisitExpr_; + /*! * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 052d04fe2411..874d4f233416 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -121,28 +121,9 @@ class Function : public BaseFunc { }; /*! - * \brief Returns the function with given properties. A null property denotes 'no change'. - * Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param function The function to copy. - * \param opt_params The (optional) params for the copied function. If none, - * ret_function->params = function->params. - * \param opt_body The (optional) body for the copied function. If none, - * ret_function->body = function->body. - * \param opt_ret_type The (optional) return type for the copied function. If none, - * ret_function->ret_type = function->ret_type. - * \param opt_ty_params The (optional) type params for the copied function. If none, - * ret_function->type_params = function->type_params. - * \param opt_attrs - * The (optional) attributes for the copied function. If none, - * ret_function->attrs = function->attrs. - * \param opt_virtual_device The (optional) virtual_device for the copied function. If none, - * ret_function->virtual_device = function->virtual_device. - * \param opt_span The (optional) span for the copied function. If none, - * ret_function->span = function->span. - * \return If all properties are null or the same as the property in the input function - * (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return - * function. Otherwise, we return a copy of function with the different fields overwritten. (i.e., - * if opt_params.value() != function->params, then ret_function->params = opt_params.value()). + * \brief Returns \p function with the given properties. A null property denotes 'no change'. + * Returns \p function if all properties are unchanged. Otherwise, returns a copy with the new + * fields. 
*/ Function WithFields(Function function, Optional> opt_params = Optional>(), Optional opt_body = Optional(), diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 6e3bddf9adf5..fad033ee3f1d 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -281,6 +281,11 @@ TVM_DLL Pass InferType(); */ TVM_DLL Type InferTypeLocal(const Expr& expr); +/*! + * \brief Infer the types of all sub-expression of expr. + */ +TVM_DLL Expr InferTypeExpr(const Expr& expr); + /*! * \brief Search and eliminate common subexpression. For example, if there are * two expressions evaluated to an identical value, a single variable is created diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 8d7ed163a197..c537756d0987 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -36,6 +36,7 @@ namespace relay { // Pattern Matcher bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr); memo_.clear(); matched_nodes_.clear(); return VisitDFPattern(pattern, expr); @@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr if (out) { memo_[pattern].push_back(expr); matched_nodes_.push_back(pattern); + VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr); } else { ClearMap(watermark); } @@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons if (!matches) { return matches; } - VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr); auto attributes = attr_pattern->attrs.as()->dict; if (const auto* op_node = expr.as()) { Op op = GetRef(op_node); @@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. 
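// A minimal sketch of the revised calling convention (mirroring the updated MatchPattern
// further below): the caller now builds and owns the IndexedGraph and hands the matcher a
// non-owning pointer, rather than the matcher constructing its own graph from the root:
//
//   std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
//   DFPatternMatcher matcher(expr_graph.get());
//   bool matched = matcher.Match(pattern, expr);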
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as(); - for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { - if (!(call_node && node->ref_ == call_node->op)) { + auto index_node = expr_to_node(expr); + for (auto node : index_node->inputs_) { + if (!(call_node && node->ref() == call_node->op)) { memoize_ = true; - if (VisitDFPattern(op->parent, node->ref_)) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { memoize_ = false; - if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { + if (!VisitDFPattern(op->path, node->ref())) { + return false; + } + if (!MatchesPath(op, node->ref())) { return false; } } @@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e // Iteratively ensure that the parent is dominated somewhere by the child or the path bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { std::stack stack; - std::unordered_set visited; + std::unordered_set visited; stack.push(expr); while (!stack.empty()) { Expr current = stack.top(); stack.pop(); - for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { - if (visited.count(node->ref_) == 0) { - if (VisitDFPattern(op->parent, node->ref_)) { + for (auto node : expr_to_node(current)->dominator_children_) { + if (visited.count(node->node_ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { - stack.push(node->ref_); + stack.push(node->ref()); } - visited.insert(node->ref_); + visited.insert(node->node_ref_); } } } @@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } bool MatchPattern(DFPattern pattern, Expr expr) { - return DFPatternMatcher(expr).Match(pattern, expr); + std::unique_ptr> expr_graph = CreateIndexedGraph(expr); + return DFPatternMatcher(expr_graph.get()).Match(pattern, expr); } TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); @@ -575,7 +581,8 @@ const std::unordered_map& PatternGrouper::GroupMatch pattern_ = pattern; pattern_graph_ = CreateIndexedGraph(pattern_); - auto matcher = DFPatternMatcher(pre); + std::unique_ptr> expr_graph = CreateIndexedGraph(pre); + DFPatternMatcher matcher(expr_graph.get()); matcher_ = &matcher; this->VisitExprs(); return this->groups_; @@ -583,9 +590,9 @@ const std::unordered_map& PatternGrouper::GroupMatch void PatternGrouper::VisitExprs() { std::unordered_set pre_partitioned; - for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + for (PostDfsIndex i = matcher_->size(); i != 0; --i) { + PostDfsIndex index = i - 1; + const auto current = matcher_->index_to_node(index)->ref(); if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped if (auto op = current.as()) { if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { @@ -607,9 +614,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto node_map = matcher_->GetMemo(); // Get fuzzy patterns std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { + if (auto op = 
node->ref().as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { fuzzy_matches.insert(match); @@ -617,12 +625,13 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { - auto matches = node_map[node->ref_]; + if (node->ref().as()) { + auto matches = node_map[node->ref()]; for (auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + auto sub_graph = CreateIndexedGraph(match.as()->body); + for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) { + auto sub_node = sub_graph->index_to_node(sub_index); + fuzzy_matches.insert(sub_node->ref()); } } } @@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { std::unordered_map inputs; Array params; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); auto make_input = [&](const Expr& input) { if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { + input.as() == nullptr && !EmbedConst(input, node->ref())) { inputs[input] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), NullValue()); @@ -648,11 +658,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { var_number++; } }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); + auto tuple = node->ref().as(); + auto call = node->ref().as(); if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->fields) { make_input(input); @@ -660,8 +670,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->args) { make_input(input); @@ -669,8 +679,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { make_input(match); } @@ -708,13 +718,15 @@ void PatternGrouper::CreateGroup(const Expr& expr) { return; } else if (kv.second != body) { // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); + auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { + if (memo.count(output->ref()) == 0) { + // TODO(mbs): The dominates relation is backwards here, and will always fail. + // So all we're doing is failing if an internal node connects to an outside node. 
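// (Per the Node docs in indexed_graph.h, a->Dominates(b) asks whether every dataflow
// path from b passes through a, by walking a's dominator_children_ transitively.)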
// Exit because nodes in this pattern's body are used outside the pattern // fusing it would be invalid + ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); return; } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index d993d4720e4e..ae5168a84d0a 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -27,7 +27,9 @@ #include #include #include +#include +#include #include #include #include @@ -39,10 +41,21 @@ namespace relay { class DFPatternMatcher : public DFPatternFunctor { public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + explicit DFPatternMatcher(const IndexedGraph* expr_graph) : expr_graph_(expr_graph) {} bool Match(const DFPattern& pattern, const Expr& expr); Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; + + const IndexedGraph::Node* expr_to_node(const Expr& expr) const { + return expr_graph_->item_to_node(expr); + } + const IndexedGraph::Node* index_to_node(size_t index) const { + return expr_graph_->index_to_node(index); + } + size_t size() const { return expr_graph_->size(); } + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& memo() const { + return memo_; + } + const IndexedGraph& dataflow_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -67,6 +80,7 @@ class DFPatternMatcher : public DFPatternFunctor* expr_graph_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; std::vector matched_nodes_; bool memoize_ = true; @@ -131,7 +145,7 @@ class PatternGrouper { std::unordered_map groups_; std::unordered_map gid_assignments_; DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; + std::unique_ptr> pattern_graph_; int gid_ = 0; int graph_number_ = 0; }; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index fc76577bd7c0..150226fae1a2 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -27,6 +27,26 @@ namespace tvm { +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint, Optional opt_type, + Optional opt_virtual_device, Optional opt_span) { + String name_hint = opt_name_hint.value_or(global_var->name_hint); + Type type = opt_type.value_or(global_var->checked_type()); + VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); + Span span = opt_span.value_or(global_var->span); + bool all_fields_unchanged = + name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && + virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); + if (!all_fields_unchanged) { + GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); + cow_global_var_node->name_hint = name_hint; + cow_global_var_node->checked_type_ = type; + cow_global_var_node->virtual_device_ = virtual_device; + cow_global_var_node->span = span; + } + + return global_var; +} + VirtualDevice RelayExprNode::virtual_device() const { if (!this->virtual_device_.defined()) { // virtual_device_ should always be defined, unless we imported this node from JSON using an old @@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } +Constant WithFields(Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + runtime::NDArray data = opt_data.value_or(constant->data); + VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); + Span span 
= opt_span.value_or(constant->span); + + bool all_fields_unchanged = data.same_as(constant->data) && + virtual_device.same_as(constant->virtual_device()) && + span.same_as(constant->span); + + if (!all_fields_unchanged) { + ConstantNode* cow_constant_node = constant.CopyOnWrite(); + cow_constant_node->data = data; + cow_constant_node->virtual_device_ = virtual_device; + cow_constant_node->span = span; + } + return constant; +} + Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -90,6 +129,7 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); + Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4efe57b491db..150cfca18b48 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -19,195 +19,391 @@ /*! * \file src/relay/ir/indexed_graph.cc - * \brief Utilties for Creating Indexed Graphs. + * \brief A graph representation of the dataflow in a Relay expression or Relay dataflow + * pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control + * flow scopes. */ #include "indexed_graph.h" #include #include #include -#include +#include + +#include namespace tvm { namespace relay { -// IndexedGraph +std::string RefToSummary(const Expr& expr) { + class Visitor : public ExprFunctor { + std::string VisitExpr_(const VarNode* op) final { return "%" + op->name_hint(); } + std::string VisitExpr_(const GlobalVarNode* op) final { return "@" + op->name_hint; } + std::string VisitExpr_(const ConstantNode* op) final { return "const"; } + std::string VisitExpr_(const TupleNode* op) final { + return "tuple(" + std::to_string(op->fields.size()) + ")"; + } + std::string VisitExpr_(const FunctionNode* op) final { return "fn"; } + std::string VisitExpr_(const CallNode* op) final { + return VisitExpr(op->op) + "(" + std::to_string(op->args.size()) + ")"; + } + std::string VisitExpr_(const LetNode* op) final { return "let"; } + std::string VisitExpr_(const IfNode* op) final { return "if"; } + std::string VisitExpr_(const OpNode* op) final { return op->name; } + std::string VisitExpr_(const TupleGetItemNode* op) final { + return "." + std::to_string(op->index); + } + std::string VisitExpr_(const RefCreateNode* op) final { return "ref_create"; } + std::string VisitExpr_(const RefReadNode* op) final { return "ref_read"; } + std::string VisitExpr_(const RefWriteNode* op) final { return "ref_write"; } + std::string VisitExpr_(const ConstructorNode* op) final { return "ctor"; } + std::string VisitExpr_(const MatchNode* op) final { return "match"; } + }; + return Visitor().VisitExpr(expr); +} -IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Topological order */ +std::string RefToSummary(const DFPattern& pattern) { return ""; } + +std::unique_ptr> CreateIndexedGraph(const Expr& expr) { + /*! + * \brief Adds indexed graph nodes in post-dfs order, and discovers which let-bound vars are to + * recursive functions. 
+ */ class Creator : public MixedModeVisitor { public: - IndexedGraph CreateGraph(const Expr& expr) { + std::pair>, + std::unique_ptr>> + CreateGraph(const Expr& expr) { VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; - return std::move(graph_); + // Last visited node is implicitly used 'externally'. + graph_->item_to_node(expr)->is_external_ = true; + return {std::move(graph_), std::move(rec_calls_)}; } protected: using MixedModeVisitor::VisitExpr_; + // By the default the MixedModeVisitor will place + // - callee and arguments before a call + // - tuple fields before a tuple + // - tuple before a tuple projection void VisitLeaf(const Expr& expr) override { + if (const auto* var_node = expr.as()) { + if (var_node == current_let_bound_var_) { + // Don't visit occurrences of let-rec bound vars in the recursive function body. + // Instead, wait for them to be visited at call sites outside of the function. + VLOG(1) << "Ignore let-rec var '" << var_node->name_hint() << "'"; + return; + } + } + MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(expr); + + if (const auto* call_node = expr.as()) { + if (const auto* var_node = call_node->op.as()) { + if (var_node == current_let_bound_var_) { + // Remember this is a recursive call to the let-rec bound function. + // The Annotator functor below will not record any dependency from the let-rec bound + // var to the expression so that the dataflow graph is always a DAG. + VLOG(1) << "Remembering recursive call to '" << var_node->name_hint() << "'"; + rec_calls_->emplace(call_node); + } + } + } } - void VisitExpr_(const LetNode* let) override { + void VisitExpr_(const LetNode* let_node) override { auto pre_visit = [&](const LetNode* op) { - this->VisitSpan(op->span); - this->VisitExpr(op->value); - this->VisitExpr(op->var); + // Let-bound values come before their let-bound variable. + const VarNode* prev_let_bound_var = current_let_bound_var_; + current_let_bound_var_ = op->var.get(); + VisitExpr(op->value); + current_let_bound_var_ = prev_let_bound_var; + VisitExpr(op->var); }; auto post_visit = [&](const LetNode* op) { - this->VisitExpr(op->body); - if (let != op) { - Expr expr = GetRef(op); + VisitExpr(op->body); + if (let_node != op) { + // Replicate VisitLeaf, which we are effectively bypassing. visit_counter_[op]++; - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(GetRef(op)); } }; - ExpandANormalForm(let, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); + } + + class PatternCreator : public PatternVisitor { + public: + explicit PatternCreator(Creator* creator) : creator_(creator) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + creator_->VisitLeaf(pattern_var_node->var); + } + + Creator* creator_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Matched data comes before match-bound vars then match rhs, in match order. + VisitExpr(match_node->data); + for (const Clause& c : match_node->clauses) { + PatternCreator pattern_creator(this); + pattern_creator.VisitPattern(c->lhs); + VisitExpr(c->rhs); + } } - IndexedGraph graph_; - size_t index_ = 0; + /*! \brief Graph we are accumulated nodes into. */ + std::unique_ptr> graph_ = std::make_unique>(); + /*! 
\brief Variable the currently visited expression is to be let-bound to, if any. */ + const VarNode* current_let_bound_var_ = nullptr; + /*! \brief Accumulated calls to recursive functions. */ + std::unique_ptr> rec_calls_ = + std::make_unique>(); }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree - * analysis. + + /*! + * \brief Fills in the inputs and outputs for all nodes, then does dominator analysis. * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined - * topological order instead of recursing. + * Thought we use the ExprFunctor to visit nodes, we never recurse and instead just inspect + * each sub-expression's immediate sub-sub-expressions to accumulate inputs and outputs. */ - class Annotator : public ExprFunctor { + class Annotator : public ExprFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + explicit Annotator(std::pair>, + std::unique_ptr>> + args) + : graph_(std::move(args.first)), rec_calls_(std::move(args.second)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - ExprFunctor::VisitExpr(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitExpr(graph_->index_to_node(index)->ref()); } // do the dominator analysis - graph_.PostDom(); + graph_->PostDom(); return std::move(graph_); } - /*! Default visitation pushes the parent to the child's outputs and the child to the parent's - * inputs*/ - void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } + /*! + * \brief Add \p parent as a possible output of the node corresponding to \p expr. + */ + void AddOutput(const Expr& expr, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(expr); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } protected: - IndexedGraph graph_; - void VisitExpr_(const VarNode* op, NodePtr parent) override { - if (op->type_annotation.defined()) { - this->VisitType(op->type_annotation); - } - } + void VisitExpr_(const VarNode* var_node) override {} - void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + void VisitExpr_(const GlobalVarNode* global_var_node) override {} - void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + void VisitExpr_(const ConstantNode* constant_node) override {} - void VisitExpr_(const TupleNode* op, NodePtr parent) override { - for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const TupleNode* tuple_node) override { + auto node = graph_->item_to_node(GetRef(tuple_node)); + for (auto field : tuple_node->fields) { + AddOutput(field, node); } } - void VisitExpr_(const FunctionNode* op, NodePtr parent) override { - for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + // Nothing to do for parameters -- each use of a parameter will contribute to its outputs. 
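// For example (hypothetical): in fn(%x) { add(%x, %x) }, each use of %x adds an output
// edge from %x to the add call, and the AddOutput below records the call (the body) as
// an input of the function node; the parameter itself gets no edge here.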
+ AddOutput(function_node->body, node); + } + + void VisitExpr_(const CallNode* call_node) override { + auto node = graph_->item_to_node(GetRef(call_node)); + if (rec_calls_->count(call_node)) { + // We want the dataflow graph to be a dag, so don't consider a call to a let-rec bound + // function from inside the function to depend on the let-rec bound var. + VLOG(1) << "Ignoring op in call " << RefToSummary(GetRef(call_node)); + } else { + AddOutput(call_node->op, node); + } + for (auto arg : call_node->args) { + AddOutput(arg, node); } + } + + void VisitExpr_(const LetNode* let_node) override { + auto node = graph_->item_to_node(GetRef(let_node)); + auto let_var_node = graph_->item_to_node(let_node->var); + AddOutput(let_node->value, let_var_node); + // Nothing to do for the let-bound variable -- each use of that variable in the let-body + // will contribute to its outputs. + AddOutput(let_node->body, node); + } + + void VisitExpr_(const IfNode* if_node) override { + auto node = graph_->item_to_node(GetRef(if_node)); + AddOutput(if_node->cond, node); + AddOutput(if_node->true_branch, node); + AddOutput(if_node->false_branch, node); + } + + void VisitExpr_(const OpNode* op_node) override {} - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override { + auto node = graph_->item_to_node(GetRef(tuple_get_item_node)); + AddOutput(tuple_get_item_node->tuple, node); } - void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const RefCreateNode* ref_create_node) override { + auto node = graph_->item_to_node(GetRef(ref_create_node)); + AddOutput(ref_create_node->value, node); + } + + void VisitExpr_(const RefReadNode* ref_read_node) override { + auto node = graph_->item_to_node(GetRef(ref_read_node)); + AddOutput(ref_read_node->ref, node); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) override { + auto node = graph_->item_to_node(GetRef(ref_write_node)); + AddOutput(ref_write_node->ref, node); + AddOutput(ref_write_node->value, node); + } + + void VisitExpr_(const ConstructorNode* constructor_node) override {} + + class PatternAnnotator : public PatternVisitor { + public: + PatternAnnotator(Annotator* annotator, const ExprNode* adt_node) + : annotator_(annotator), adt_node_(adt_node) {} - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto node = annotator_->graph_->item_to_node(pattern_var_node->var); + annotator_->AddOutput(GetRef(adt_node_), node); } - for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + Annotator* annotator_; + const ExprNode* adt_node_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Data flows from the match data to pattern vars into match arms and out into overall + // match. + auto node = graph_->item_to_node(GetRef(match_node)); + for (const Clause& c : match_node->clauses) { + PatternAnnotator pattern_annotator(this, match_node->data.get()); + pattern_annotator.VisitPattern(c->lhs); + AddOutput(c->rhs, node); } } - void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } + std::unique_ptr> graph_; + /*! \brief Accumulated calls to recursive functions. 
*/ + std::unique_ptr> rec_calls_; + }; + + /*! \brief Fills in the basic blocks for all nodes. */ + class Blocker : public MixedModeVisitor { + public: + explicit Blocker(std::unique_ptr> graph) : graph_(std::move(graph)) {} - void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + std::unique_ptr> Scope(const Expr& expr) { + VisitExpr(expr); + return std::move(graph_); } - void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + private: + using MixedModeVisitor::VisitExpr_; - void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + SetScope(expr); } - void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + basic_block_stack_.push_back(node); + ExprVisitor::VisitExpr_(function_node); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + VisitExpr(if_node->cond); + auto node = graph_->item_to_node(GetRef(if_node)); + basic_block_stack_.push_back(node); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const LetNode* let_node) override { + auto pre_visit = [&](const LetNode* op) { + VisitExpr(op->value); + VisitExpr(op->var); + }; + auto post_visit = [&](const LetNode* op) { + VisitExpr(op->body); + if (let_node != op) { + visit_counter_[op]++; + SetScope(GetRef(op)); + } + }; + ExpandANormalForm(let_node, pre_visit, post_visit); } - void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { - for (const Type& t : op->inputs) { - this->VisitType(t); + class PatternBlocker : public PatternVisitor { + public: + explicit PatternBlocker(Blocker* scoper) : scoper_(scoper) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + scoper_->SetScope(pattern_var_node->var); } - this->VisitType(op->belong_to); - } - void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); - for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); + Blocker* scoper_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + VisitExpr(match_node->data); + auto node = graph_->item_to_node(GetRef(match_node)); + basic_block_stack_.push_back(node); + for (const Clause& c : match_node->clauses) { + PatternBlocker pattern_scoper(this); + pattern_scoper.VisitPattern(c->lhs); + VisitExpr(c->rhs); } + basic_block_stack_.pop_back(); } - void VisitClause(const Clause& op, NodePtr parent) { - this->VisitPattern(op->lhs); - this->VisitExpr(op->rhs, parent); + void SetScope(const Expr& expr) { + auto node = graph_->item_to_node(expr); + if (!basic_block_stack_.empty()) { + node->basic_block_ = 
basic_block_stack_.back(); + } } - void VisitPattern(const Pattern& p) { return; } - - void VisitType(const Type& t) { return; } + std::unique_ptr> graph_; + std::vector::Node*> basic_block_stack_; }; - return Annotator(Creator().CreateGraph(expr)).Annotate(); + + VLOG(1) << "CreateIndexedGraph:" << std::endl << PrettyPrint(expr); + std::unique_ptr> graph = + Blocker(Annotator(Creator().CreateGraph(expr)).Annotate()).Scope(expr); + VLOG(1) << "graph:" << std::endl << graph->ToString(); +#if TVM_LOG_DEBUG + graph->CheckValid(); +#endif + return graph; } -IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern) { + /*! \brief Creates an IndexedGraph and determines topological order */ class Creator : public DFPatternVisitor { public: - IndexedGraph CreateGraph(const DFPattern& pattern) { + std::unique_ptr> CreateGraph(const DFPattern& pattern) { + graph_ = std::make_unique>(); VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; + graph_->item_to_node(pattern)->is_external_ = true; return std::move(graph_); } @@ -215,121 +411,135 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern(const DFPattern& pattern) override { if (this->visited_.count(pattern.get()) == 0) { DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(pattern); } } - IndexedGraph graph_; - size_t index_ = 0; + + std::unique_ptr> graph_; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree * analysis. * * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined * topological order instead of recursing. */ - class Annotator : public DFPatternFunctor { + class Annotator : public DFPatternFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + Annotator(std::unique_ptr> graph) : graph_(std::move(graph)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitDFPattern(graph_->index_to_node(index)->ref()); } - graph_.PostDom(); // do the dominator analysis + graph_->PostDom(); return std::move(graph_); } /*! 
Default visitation pushes the parent to the child's outputs */ - void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; + void AddOutput(const DFPattern& pattern, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(pattern); if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } } protected: - IndexedGraph graph_; - void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AltPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->left, node); + AddOutput(op->right, node); } - void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AttrPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const CallPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->op, node); if (op->args.defined()) { for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + AddOutput(arg, node); } } } - void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ConstantPatternNode* op) override {} - void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DataTypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DominatorPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->parent, node); + AddOutput(op->path, node); + AddOutput(op->child, node); } - void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ExprPatternNode* op) override {} - void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const FunctionPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->params.defined()) { for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + AddOutput(param, node); } } - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + AddOutput(op->body, node); } - void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const ShapePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TupleGetItemPatternNode* op) override { + auto node = 
graph_->item_to_node(GetRef(op)); + AddOutput(op->tuple, node); } - void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const TuplePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->fields.defined()) { for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + AddOutput(field, node); } } } - void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const IfPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->cond, node); + AddOutput(op->true_branch, node); + AddOutput(op->false_branch, node); } - void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const LetPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->var, node); + AddOutput(op->value, node); + AddOutput(op->body, node); } - void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const VarPatternNode* op) override {} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const WildcardPatternNode* op) override {} + + std::unique_ptr> graph_; }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); } diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d073bcaeea5c..d623029ffb70 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -19,7 +19,11 @@ /*! * \file src/relay/ir/indexed_graph.h - * \brief A pattern matcher for matching dataflow properties. + * \brief A graph representation of the dataflow in a Relay expression or Relay dataflow + * pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control + * flow scopes. + * + * TODO(mbs): Rename to 'DataflowGraph'. */ #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ @@ -28,6 +32,7 @@ #include #include +#include #include #include #include @@ -36,47 +41,109 @@ namespace tvm { namespace relay { +/*! \brief The index of a node in the post-dfs traversal of overall expression. */ +using PostDfsIndex = size_t; + +/*! + * \brief Returns a brief summary of the 'reference' expression or pattern. Only used for + * debugging by IndexedGraph::ToString(). + */ +std::string RefToSummary(const Expr& expr); +std::string RefToSummary(const DFPattern& pattern); + /*! - * \brief A Wrapper around a templated graph type - * Holds a forward-backward indexed representation of the graph and a dominator tree representation - * of the graph + * \brief Represents the dataflow of an expression (or dataflow pattern) as a graph which is + * overlaid on the underlying expression (or dataflow pattern) graph. 
+ * + * Each indexed graph node references the corresponding sub-expression (or dataflow sub-pattern) + * node, and captures: + * - dataflow inputs + * - dataflow outputs (or a flag indicating the node is an implied output) + * - dominator parent + * - dominator children + * - basic block * - * This class is templated and the implementaiton is in the header file so we can analyze both - * DFPattern and Expr with the same infrastructure. + * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. * - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. */ template class IndexedGraph { public: - /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + using TNode = typename T::ContainerType; + + /*! \brief A Node in the dataflow graph. */ struct Node { /*! \brief Node Constructor - * \param ref The input graph node - * \param index The index of the node in toplogical order + * \param ref The expression or dataflow pattern node this indexed graph node is augmenting. + * \param index The index of this node in the topological order */ - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + Node(const TNode* ref, PostDfsIndex index) : node_ref_(ref), index_(index) {} + + /*! \brief The underlying expression or pattern node. */ + const TNode* node_ref_; - /*! \brief The input node */ - const T ref_; - /*! \brief The topological order index */ - const size_t index_; + T ref() const { + ICHECK(node_ref_ != nullptr); + return GetRef(node_ref_); + } + + /*! + * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then + * left does not flow into right. If left.index_ = right.index_ then left and right are + * the same node. + */ + const PostDfsIndex index_; - /*! \brief A boolean to determine if this node is external to the graph */ + /*! \brief If true this node has implicit outputs, for example as the result of a function. */ bool is_external_ = false; - /*! \brief The forward inputs of the node */ + /*! \brief Immediate dataflow inputs to this node. */ std::vector inputs_; - /*! \brief The forward outputs/users of the node */ + /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */ std::vector outputs_; - /*! \brief The depth of the node in the dominator tree */ + /*! + * \brief The node representing the 'basic block' containing this node: + * - Function bodies start a new basic block for their bodies. + * - The true and false branches of an if start their own blocks. + * - The arms of a match each have their own blocks. + */ + Node* basic_block_ = nullptr; + + /*! \brief The depth of this node in the dominator tree */ size_t depth_ = 0; - /*! \brief The dominator parent/final user of the outputs of this node */ - Node* dominator_parent_; - /*! \brief The nodes this node dominates */ + /*! + * \brief The dominator parent of this node. This is the node N with least index such that + * all possible dataflows from this node pass through N. + */ + Node* dominator_parent_ = nullptr; + /*! \brief The nodes this node dominates. */ std::vector dominator_children_; - bool Dominates(const Node* other) { + /*! + * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be + * reached by following output paths. 
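+     * (\p this itself is not added.) The walk is a simple stack-based traversal of the
+     * outputs_ edges, so the result is the set of all nodes reachable from \p this by dataflow.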
+ */ + void AccumulateDownstreamNodes(std::unordered_set* nodes) const { + std::stack stack; + stack.push(this); + while (!stack.empty()) { + const Node* current = stack.top(); + stack.pop(); + for (auto node : current->outputs_) { + if (nodes->count(node) == 0) { + stack.push(node); + nodes->insert(node); + } + } + } + } + + /*! + * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p + * other pass through \p this. + */ + bool Dominates(const Node* other) const { std::stack stack; std::unordered_set visited; stack.push(this); @@ -97,10 +164,11 @@ class IndexedGraph { return false; } }; + /*! \brief Construct the domination tree inside IndexedGraph */ void PostDom() { - for (size_t i = topological_order_.size(); i != 0; --i) { - size_t index = i - 1; + for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { + PostDfsIndex index = i - 1; auto* current = topological_order_[index].get(); if (current->is_external_) { current->depth_ = 1; @@ -109,16 +177,127 @@ class IndexedGraph { auto parent = LeastCommonAncestor(current->outputs_); current->depth_ = parent ? parent->depth_ + 1 : 1; current->dominator_parent_ = parent; - parent->dominator_children_.push_back(current); + if (parent) { + parent->dominator_children_.push_back(current); + } } } } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; - /*! \brief Topological IndexedGraph Nodes */ - std::vector> topological_order_; - protected: + PostDfsIndex size() const { return topological_order_.size(); } + + Node* item_to_node(const T& item) { return item_to_node(item.get()); } + const Node* item_to_node(const T& item) const { return item_to_node(item.get()); } + + Node* item_to_node(const TNode* item) { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + const Node* item_to_node(const TNode* item) const { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + Node* index_to_node(PostDfsIndex index) { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + const Node* index_to_node(PostDfsIndex index) const { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + /*! + * \brief (For debugging only) Returns description of indexed graph with hints as to the + * sub-expressions or sub-patterns corresponding to each indexed graph node. 
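+   *
+   * Each node is printed on its own line, roughly of the form:
+   *   index (summary): inputs=[...], outputs=[...], basic_block=..., depth=..., dom_parent=..., dom_children=[...]
+   * where 'summary' is produced by RefToSummary and the optional fields are omitted when absent.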
+ */ + std::string ToString() const { + std::ostringstream os; + os << "IndexedGraph(size = " << topological_order_.size() << ") {" << std::endl; + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + ICHECK_EQ(index, node->index_); + os << " " << index << " (" << RefToSummary(node->ref()) << "): inputs=["; + for (const auto* sub_node : node->inputs_) { + os << sub_node->index_ << ","; + } + os << "], outputs=["; + for (const auto* sub_node : node->outputs_) { + os << sub_node->index_ << ","; + } + os << "]"; + if (node->is_external_) { + os << ", external"; + } + if (node->basic_block_) { + os << ", basic_block=" << node->basic_block_->index_; + } + if (node->depth_ > 0) { + os << ", depth=" << node->depth_; + } + if (node->dominator_parent_) { + os << ", dom_parent=" << node->dominator_parent_->index_; + } + os << ", dom_children=["; + for (const auto* sub_node : node->dominator_children_) { + os << sub_node->index_ << ","; + } + os << "]" << std::endl; + } + os << "}"; + return os.str(); + } + + /*! + * Check-fails if the graph is ill-formed. For debugging only. + */ + void CheckValid() const { + ICHECK_GT(topological_order_.size(), 0); + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + // We have a node. + ICHECK(node); + // Bijections with post-dfs indexes and expressions/patterns are correct. + ICHECK_EQ(node->index_, index); + ICHECK(node->node_ref_); + auto itr = node_map_.find(node->node_ref_); + ICHECK(itr != node_map_.end()); + ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString(); + // Inputs come before. + for (size_t i = 0; i < node->inputs_.size(); ++i) { + const Node* input = node->inputs_[i]; + ICHECK(input); + ICHECK_LT(input->index_, index); + ICHECK(std::find(input->outputs_.begin(), input->outputs_.end(), node) != + input->outputs_.end()); + } + // Outputs come after. + for (size_t i = 0; i < node->outputs_.size(); ++i) { + const Node* output = node->outputs_[i]; + ICHECK(output); + ICHECK_GT(output->index_, index); + ICHECK(std::find(output->inputs_.begin(), output->inputs_.end(), node) != + output->inputs_.end()); + } + ICHECK_GT(node->depth_, 0); + // Dominator children come before. + for (size_t i = 0; i < node->dominator_children_.size(); ++i) { + const Node* child = node->dominator_children_[i]; + ICHECK(child); + ICHECK_LT(child->index_, index); + } + if (node->dominator_parent_) { + // Dominator comes after. + ICHECK_GT(node->dominator_parent_->index_, index); + } + } + } + + private: /*! \brief Find the least common ancestor of all outputs of a node */ Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { @@ -136,9 +315,11 @@ class IndexedGraph { if (lhs == nullptr || rhs == nullptr) { return nullptr; } + PostDfsIndex lhs_index = lhs->index_; + PostDfsIndex rhs_index = rhs->index_; while (lhs != rhs) { - ICHECK(lhs); - ICHECK(rhs); + ICHECK(lhs && rhs) << "LCA(" << lhs_index << ", " << rhs_index << ") on graph:" << std::endl + << ToString(); if (lhs->depth_ < rhs->depth_) { rhs = rhs->dominator_parent_; } else if (lhs->depth_ > rhs->depth_) { @@ -150,13 +331,42 @@ class IndexedGraph { } return lhs; } + + /*! + * \brief Appends a node corresponding to \p ref, and maintains the sub-expression/sub-pattern to + * node bijection. The insertion index will be the node's PostDfsIndex. All other node properties + * are accumulated in-place. 
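+   * Callers are expected to add nodes in post-dfs order, so a node's inputs always have a
+   * smaller index than the node itself. The dataflow edges, dominator and basic block
+   * information are filled in separately by the creators in indexed_graph.cc.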
+ */ + void AddNode(const T& ref) { + PostDfsIndex index = topological_order_.size(); + auto node = std::make_unique(ref.get(), index); + node_map_[ref.get()] = node.get(); + topological_order_.emplace_back(std::move(node)); + } + + /*! + * \brief Map from underlying sub-expressions or dataflow sub-pattern graph nodes to their + * indexed graph nodes. + */ + std::unordered_map node_map_; + /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */ + std::vector> topological_order_; + + friend std::unique_ptr> CreateIndexedGraph(const Expr& expr); + friend std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); }; -/*! \brief Create an Indexed Graph based on an Expr */ -IndexedGraph CreateIndexedGraph(const Expr& expr); -/*! \brief Create an Indexed Graph based on an DFPattern */ -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); +/*! \brief Returns an Indexed Graph for \p expr, which much outlive the result. */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr); + +/*! + * \brief Returns an Indexed Graph for \p pattern, which must outlive the result. + * The dataflow for a pattern mimics the dataflow for the expression which would match + * that pattern. + */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm + #endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index f7045305e90d..d5cc6608662b 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -258,6 +258,7 @@ RELAY_REGISTER_OP("dyn.broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. )code" TVM_ADD_FILELINE) .set_num_inputs(2) + .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 9c01c40517f4..132bcaa82922 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -949,6 +949,11 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E return InferTypeLocal(expr); }); +Expr InferTypeExpr(const Expr& expr) { + InferTypeLocal(expr); + return expr; +} + Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 3971081bf8f8..83b0c7cffe86 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -67,7 +67,7 @@ nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par // Batch dimension cannot be modified. ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1); ICHECK_EQ(order[0], 0); - for (size_t i = 0; i < order.size(); ++i) { + for (size_t i = 0; i + 1 < order.size(); ++i) { perm.order[i] = order[i + 1] - 1; } } else { diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc new file mode 100644 index 000000000000..ce383c52909c --- /dev/null +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../src/relay/ir/indexed_graph.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace { + +// A module stolen from onnx/test_forward.py::test_loop which combines functions, recursion, +// control flow, tuples as well as the usual operator calls. +// We include the known post-dfs indexes in comments to help write the tests. +IRModule TestRecursiveIRModule() { + Device device = {kDLCPU, 0}; + Constant const0(runtime::NDArray::Empty(ShapeTuple({1}), DataType::Int(64), device)); + Constant const1(runtime::NDArray::Empty(ShapeTuple({0, 1}), DataType::Float(32), device)); + Map> metadata; + metadata.Set("relay.Constant", Array({const0, const1})); + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%trip_count: int64, // 0 + %cond: bool, // 1 + %y: Tensor[(1), float32]) // 2 + -> (Tensor[(1), float32], Tensor[(?, ?), float32]) { + %17 = ( + let %while_loop = fn (%iter_count: int64, // 3 + %max_count: int64, // 4 + %cond_in: bool, // 5 + %y_in: Tensor[(1), float32], // 6 + %scan_out: Tensor[(?, ?), float32]) // 7 + -> (int64, int64, bool, Tensor[(1), float32], Tensor[(?, ?), float32]) { + %0 = equal(%cond_in, True); // 11 + %1 = less(%iter_count, %max_count); // 13 + %2 = logical_and(%0, %1); // 14 + if (%2) { + %3 = cast(%iter_count, dtype="float32"); // 20 + %4 = add(%y_in, %3); // 21 + %5 = less(%4, 5f); // 23 + %6 = squeeze(%5); // 24 + %7 = reshape(%iter_count, newshape=[1]); // 29 + %8 = (%7, meta[relay.Constant][0]); // 31 + %9 = concatenate(%8); // 32 + %10 = copy(%4); // 36 + %11 = dyn.broadcast_to(%scan_out, %9, shape=None); // 33 + %12 = expand_dims(%10, axis=0); // 37 + %13 = (%11, %12); // 38 + %14 = add(%iter_count, 1i64); // 17 + %15 = cast(%6, dtype="bool"); // 25 + %16 = concatenate(%13); // 39 + %while_loop(%14, %max_count, %15, %4, %16) // 40 + } else { + (%iter_count, %max_count, %cond_in, %y_in, %scan_out) // 41 + } // 42 + }; // 43 + %while_loop // 44 + ); // 45 + %18 = %17(0i64, %trip_count, %cond, %y, meta[relay.Constant][1]); // 48 + %19 = %18.3; // 49 + %20 = %18.4; // 50 + (%19, %20) // 51 + } // 52 + )"; + return parser::ParseModule("string", kModel, /*init_module=*/{}, metadata); +} + +TEST(IndexedGraph, RecursiveExprRegression) { + IRModule ir_mod = TestRecursiveIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + // Dataflow node properties for %4 + auto node = graph->index_to_node(21); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* op_node = call_node->op.as(); + ASSERT_NE(op_node, nullptr); + ASSERT_EQ(op_node->name, "add"); + + // 3 inputs (the op itself is an input) + ASSERT_EQ(node->inputs_.size(), 3); + ASSERT_EQ(node->inputs_[0]->index_, 15); // the add op + ASSERT_EQ(node->inputs_[1]->index_, 6); // %y_in + ASSERT_EQ(node->inputs_[2]->index_, 20); // %3 + + // 3 outputs + 
ASSERT_EQ(node->outputs_.size(), 3); + ASSERT_EQ(node->outputs_[0]->index_, 23); // %5 + ASSERT_EQ(node->outputs_[1]->index_, 36); // %10 + ASSERT_EQ(node->outputs_[2]->index_, 40); // recursive %while_loop call + + // In the 'if' basic block + ASSERT_EQ(node->basic_block_->index_, 42); + + // Dominator 'parent' is recursive call + ASSERT_EQ(node->dominator_parent_->index_, 40); + + // One dominator child from %3 + ASSERT_EQ(node->dominator_children_.size(), 1); + ASSERT_EQ(node->dominator_children_[0]->index_, 20); + } + + { + // The recursive call to %while_loop does not depend on %while_loop + auto node = graph->index_to_node(40); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* var_node = call_node->op.as(); + ASSERT_NE(var_node, nullptr); + ASSERT_EQ(var_node->name_hint(), "while_loop"); + + ASSERT_EQ(node->inputs_.size(), 5); + ASSERT_EQ(node->inputs_[0]->index_, 17); // %14 + ASSERT_EQ(node->inputs_[1]->index_, 4); // %max_count + ASSERT_EQ(node->inputs_[2]->index_, 25); // %15 + ASSERT_EQ(node->inputs_[3]->index_, 21); // %4 + ASSERT_EQ(node->inputs_[4]->index_, 39); // %16 + } +} + +// A module with unused let-bound function. The 'add' operator should have no dominator +// since it is used both in the unused function and in the main body. +IRModule TestUnusedLetBoundIRModule() { + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%x: int64) -> int64 { // 0 + let %f = fn ( // 5 + %y: int64 // 1 + ) { + add(%x, %y) // 3 + }; + if (less(%x, 5i64)) { + add(%x, 3i64) // 10 + } else { + %x + } + } + )"; + return parser::ParseModule("string", kModel); +} + +TEST(IndexedGraph, UnusedLetVars) { + IRModule ir_mod = TestUnusedLetBoundIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + auto node = graph->index_to_node(2); + const auto* op_node = node->ref().as(); + ICHECK(op_node); + ICHECK_EQ(op_node->name, "add"); + ICHECK_EQ(node->outputs_.size(), 2); + ICHECK_EQ(node->outputs_[0]->index_, 3); + ICHECK_EQ(node->outputs_[1]->index_, 10); + ICHECK(node->dominator_parent_ == nullptr); + } +} + +} // namespace +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 74e03f6a9755..682ed143aaff 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=unused-wildcard-import import numpy as np -import pytest import tvm from tvm import relay @@ -601,6 +600,64 @@ def test_match_fake_diamond(): assert not diamond.match(out) +def test_at_most_one_parent(): + # Pattern + P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent' + I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("add")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # n6(P) + # / \ + # n7 \ + # / \ + # n8(P) n10(I) + # \ / + # n9(I) / + # \ / + # n11(C) + + x = relay.var("x") + w = relay.var("w") + n6 = relay.op.nn.conv2d(x, w) # matches P + n7 = relay.op.tanh(n6) # does not match I + n8 = relay.op.nn.conv2d(n7, w) # matches P + n9 = relay.op.nn.relu(n8) # matches I + n10 = relay.op.nn.relu(n6) # matches I + n11 = relay.add(n9, n10) # matches C + + # Does not match: Can't match the parent pattern P at both 8 and 6. + # Note that if we did allow P to be used twice the implementation would + # need to be changed to not 'jump over' n7. 
+ assert not pattern.match(n11) + + +def test_parallel_injective(): + # Pattern + P = is_op("add")(wildcard(), wildcard()) # 'parent' + I = is_op("squeeze")(wildcard()) | is_op("transpose")( + wildcard() + ) # 'intermediate' ('path' in the code) + C = is_op("left_shift")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # + # n5(P) + # / \ + # n6(I) n8(I) + # \ / + # n9(C) + # + + x = relay.var("x", shape=(10, 20)) + n5 = relay.add(x, relay.const(1, "float32")) + n6 = relay.squeeze(n5) + n8 = relay.transpose(n5, axes=[0, 1]) + n9 = relay.left_shift(n6, n8) + + assert pattern.match(n9) + + def test_match_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) @@ -1760,4 +1817,4 @@ def callback(self, pre, post, node_map): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() From b2bf71052e292c30dd0452cbd5423d85170d0a86 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 6 Jun 2022 12:28:18 -0700 Subject: [PATCH 2/5] - Revert non IndexedGraph changes. --- include/tvm/relay/adt.h | 31 +++- include/tvm/relay/expr.h | 173 ++++++++++++++----- include/tvm/relay/expr_functor.h | 2 - include/tvm/relay/function.h | 25 ++- include/tvm/relay/transform.h | 5 - src/relay/ir/expr.cc | 40 ----- src/relay/op/dyn/tensor/transform.cc | 1 - src/relay/transforms/type_infer.cc | 5 - src/runtime/contrib/tensorrt/tensorrt_ops.cc | 2 +- 9 files changed, 176 insertions(+), 108 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index cdb8e52d2359..31dec2204146 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -264,9 +264,17 @@ class Clause : public ObjectRef { }; /*! - * \brief Returns \p clause with the given properties. A null property denotes 'no change'. - * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the clause with given properties. A null property denotes 'no change'. + * Returns clause if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param clause The clause to copy. + * \param opt_lhs The (optional) lhs for the copied clause. If none, ret_clause->lhs = clause->lhs. + * \param opt_rhs The (optional) rhs for the copied clause. If none, + * ret_clause->rhs = clause->rhs. + * \return If all + * properties are null or the same as the property in the input clause (i.e., opt_lhs is null or + * opt_lhs.value() == clause->lhs, etc.), then we return clause. Otherwise, we return a copy of + * clause with the different fields overwritten. (i.e., if opt_lhs.value() != clause->lhs, then + * ret_clause->lhs = opt_lhs.value()). */ Clause WithFields(Clause clause, Optional opt_lhs = Optional(), Optional opt_rhs = Optional()); @@ -329,9 +337,20 @@ class Match : public Expr { }; /*! - * \brief Returns \p match with the given properties. A null property denotes 'no change'. - * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the match with given properties. A null property denotes 'no change'. + * Returns match if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param match The match to copy. + * \param opt_data The (optional) data for the copied match. If none, ret_match->data = match->data. + * \param opt_clauses The (optional) clauses for the copied match. If none, ret_match->clauses = + * match->clauses. + * \param opt_complete The (optional) complete for the copied match. 
If none, ret_match->complete = + * match->complete. + * \param opt_span The (optional) span for the copied match. If none, ret_match->span = match->span. + * \return If all properties are null or the same as the + * property in the input match (i.e., opt_clauses is null or opt_clauses.value() == match->clauses, + * etc.), then we return match. Otherwise, we return a copy of match with the different fields + * overwritten. (i.e., if opt_clauses.value() != match->clauses, then ret_match->clauses = + * opt_clauses.value()). */ Match WithFields(Match match, Optional opt_data = Optional(), Optional> opt_clauses = Optional>(), diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 6b014c8478d8..fe570806922f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -39,16 +39,6 @@ #include "./type.h" namespace tvm { - -/*! - * \brief Returns \p global_var with the given properties. A null property denotes 'no change'. - * Returns \p global_var if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint = {}, - Optional opt_type = {}, Optional opt_virtual_device = {}, - Optional opt_span = {}); - namespace relay { using Expr = tvm::RelayExpr; @@ -107,17 +97,8 @@ class Constant : public Expr { TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; -/*! - * \brief Returns \p constant with the given properties. A null property denotes 'no change'. - * Returns \p constant if all properties are unchanged. Otherwise, returns a copy with the new - * fields. - */ -Constant WithFields(Constant constant, Optional opt_data = {}, - Optional opt_virtual_device = {}, Optional opt_span = {}); - /*! \brief Tuple of multiple Exprs */ class Tuple; /*! \brief Tuple container */ @@ -168,9 +149,15 @@ class Tuple : public Expr { }; /*! - * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. - * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the tuple with given properties. A null property denotes 'no change'. + * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param tuple The tuple to copy + * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = + * tuple->fields. + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied tuple. If none, + * ret_tuple->span = tuple->span. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), Optional opt_virtual_device = Optional(), @@ -264,9 +251,19 @@ class Var : public Expr { }; /*! - * \brief Returns \p vor with the given properties. A null property denotes 'no change'. - * Returns \p var if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the var with given properties. A null property denotes 'no change'. + * Returns var if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param var The var to copy. + * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. + * \param opt_type_annotation The (optional) type_annotation for the copied var. 
If none, + * ret_var->type_annotation = var->type_annotation. + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. + * \return If all properties are null or the same as the property in the input var (i.e., opt_vid is + * null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then + * ret_var->vid = opt_.value()). */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), @@ -377,9 +374,22 @@ class Call : public Expr { }; /*! - * \brief Returns \p call with the given properties. A null property denotes 'no change'. - * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the call with given properties. A null property denotes 'no change'. + * Returns call if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param call The call to copy. + * \param opt_op The (optional) op for the copied call. If none, ret_call->op = call->op. + * \param opt_args The (optional) args for the copied call. If none, ret_call->args = call->args. + * \param opt_attrs The (optional) attrs for the copied call. If none, ret_call->attrs = + * call->attrs. + * \param opt_type_args The (optional) type args for the copied call. If none, + * ret_call->type_args = call->type_args. + * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, + * ret_call->virtual_device = call->virtual_device. + * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. + * \return If all properties are null or the same as the property in the input call (i.e., opt_op is + * null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then + * ret_call->op = opt_op.value()). */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), @@ -465,9 +475,19 @@ class Let : public Expr { }; /*! - * \brief Returns \p let with the given properties. A null property denotes 'no change'. - * Returns \p let if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the let with given properties. A null property denotes 'no change'. + * Returns let if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param let The let to copy. + * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. + * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. + * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. + * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, + * ret_let->virtual_device = let->virtual_device. + * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. + * \return If all properties are null or the same as the property in the input let (i.e., opt_var is + * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of + * let with the different fields overwritten. 
(i.e., if opt_var.value() != let->var, then + * ret_let->var = opt_var.value()). */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), @@ -539,9 +559,23 @@ class If : public Expr { }; /*! - * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. - * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the if_expr with given properties. A null property denotes 'no change'. + * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param if_expr The if expression to copy. + * \param opt_cond The (optional) cond for the copied if_expr. If none, ret_if->cond = + * if_expr->cond. + * \param opt_true_branch The (optional) true_branch for the copied if_expr. If none, + * ret_if->true_branch = ret_if->false_branch. + * \param opt_false_branch The (optional) false_branch + * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. + * \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none, + * ret_if->virtual_device = if_expr->virtual_device. + * \param opt_span The (optional) span for the copied if_expr. If none, + * ret_if->span = if_expr->span. + * \return If all properties are null or the same as the property in + * the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we + * return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten. + * (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()). */ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), @@ -594,9 +628,22 @@ class TupleGetItem : public Expr { }; /*! - * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. - * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the tuple_get_item with given properties. A null property denotes 'no change'. + * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param tuple_get_item The tuple_get_item to copy. + * \param opt_tuple The (optional) tuple for the copied tuple_get_item. If none, + * ret_tuple_get_item->tuple = tuple_get_item->tuple. + * \param opt_index The (optional) index for the copied tuple_get_item. If none, + * ret_tuple_get_item->index = tuple_get_item->index. + * \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item. + * If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device. + * \param opt_span The (optional) span for the copied tuple_get_item. If none, + * ret_tuple_get_item->span = tuple_get_item->span. + * \return If all properties are null or the same as the property in the input tuple_get_item + * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return + * tuple_get_item. Otherwise, we return a copy of tuple_get_item with the different fields + * overwritten. (i.e., if opt_tuple.value() != tuple_get_item->tuple, then + * ret_tuple_get_item->tuple = opt_tuple.value()). */ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), @@ -645,9 +692,21 @@ class RefCreate : public Expr { }; /*! - * \brief Returns \p ref_create with the given properties. A null property denotes 'no change'. 
- * Returns \p ref_crete if all properties are unchanged. Otherwise, returns a copy with the new + * \brief Returns the ref create with given properties. A null property denotes 'no change'. + * Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new * fields. + * \param ref_create The ref_create to copy. + * \param opt_value The (optional) value for the copied ref_create. If none, + * ret_ref_create->value = ref_create->value. + * \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none, + * ret_ref_create->virtual_device = ref_create->virtual_device. + * \param opt_span The (optional) span for the copied ref_create. If none, + * ret_ref_create->span = ref_create->span. + * \return If all properties are null or the same as the property in the input ref_create + * (i.e., opt_value is null or opt_value.value() == ref_create->value, etc.), then we return + * ref_create. Otherwise, we return a copy of ref_create with the different fields overwritten. + * (i.e., if opt_value.value() != ref_create->value, then + * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), Optional opt_virtual_device = Optional(), @@ -695,9 +754,20 @@ class RefRead : public Expr { }; /*! - * \brief Returns \p ref_read with the given properties. A null property denotes 'no change'. - * Returns \p ref_read if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the ref read with given properties. A null property denotes 'no change'. + * Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param ref_read The ref_read to copy. + * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = + * ref_read->ref. + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = + * ref_read->virtual_device. + * \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span = + * ref_read->span. + * \return If all properties are null or the same as the property in the input + * ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return + * ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e., + * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), Optional opt_virtual_device = Optional(), @@ -750,9 +820,22 @@ class RefWrite : public Expr { }; /*! - * \brief Returns \p ref_write with the given properties. A null property denotes 'no change'. - * Returns \p ref_write if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the ref write with given properties. A null property denotes 'no change'. + * Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param ref_write The ref_write to copy. + * \param opt_ref The (optional) ref for the copied ref_write. If none, + * ret_ref_write->ref = ref_write->ref. + * \param opt_value The (optional) value for the copied ref_write. If none, + * ret_ref_write->value = ref_write->value. + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = + * ref_write->virtual_device. 
+ * \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span = + * ref_write->span. + * \return If all properties are null or the same as the property in the input ref_write (i.e., + * opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise, + * we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value() + * != ref_write->ref, then ret_ref_write->ref = opt_ref.value()). */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 280a1f8a6c29..d8f575dfdf48 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -240,8 +240,6 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ explicit MixedModeVisitor(int visit_limit = 1); - using ExprVisitor::VisitExpr_; - /*! * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 874d4f233416..052d04fe2411 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -121,9 +121,28 @@ class Function : public BaseFunc { }; /*! - * \brief Returns \p function with the given properties. A null property denotes 'no change'. - * Returns \p function if all properties are unchanged. Otherwise, returns a copy with the new - * fields. + * \brief Returns the function with given properties. A null property denotes 'no change'. + * Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param function The function to copy. + * \param opt_params The (optional) params for the copied function. If none, + * ret_function->params = function->params. + * \param opt_body The (optional) body for the copied function. If none, + * ret_function->body = function->body. + * \param opt_ret_type The (optional) return type for the copied function. If none, + * ret_function->ret_type = function->ret_type. + * \param opt_ty_params The (optional) type params for the copied function. If none, + * ret_function->type_params = function->type_params. + * \param opt_attrs + * The (optional) attributes for the copied function. If none, + * ret_function->attrs = function->attrs. + * \param opt_virtual_device The (optional) virtual_device for the copied function. If none, + * ret_function->virtual_device = function->virtual_device. + * \param opt_span The (optional) span for the copied function. If none, + * ret_function->span = function->span. + * \return If all properties are null or the same as the property in the input function + * (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return + * function. Otherwise, we return a copy of function with the different fields overwritten. (i.e., + * if opt_params.value() != function->params, then ret_function->params = opt_params.value()). */ Function WithFields(Function function, Optional> opt_params = Optional>(), Optional opt_body = Optional(), diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index fad033ee3f1d..6e3bddf9adf5 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -281,11 +281,6 @@ TVM_DLL Pass InferType(); */ TVM_DLL Type InferTypeLocal(const Expr& expr); -/*! - * \brief Infer the types of all sub-expression of expr. - */ -TVM_DLL Expr InferTypeExpr(const Expr& expr); - /*! 
* \brief Search and eliminate common subexpression. For example, if there are * two expressions evaluated to an identical value, a single variable is created diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 150226fae1a2..fc76577bd7c0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -27,26 +27,6 @@ namespace tvm { -GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint, Optional opt_type, - Optional opt_virtual_device, Optional opt_span) { - String name_hint = opt_name_hint.value_or(global_var->name_hint); - Type type = opt_type.value_or(global_var->checked_type()); - VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); - Span span = opt_span.value_or(global_var->span); - bool all_fields_unchanged = - name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && - virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); - if (!all_fields_unchanged) { - GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); - cow_global_var_node->name_hint = name_hint; - cow_global_var_node->checked_type_ = type; - cow_global_var_node->virtual_device_ = virtual_device; - cow_global_var_node->span = span; - } - - return global_var; -} - VirtualDevice RelayExprNode::virtual_device() const { if (!this->virtual_device_.defined()) { // virtual_device_ should always be defined, unless we imported this node from JSON using an old @@ -97,25 +77,6 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } -Constant WithFields(Constant constant, Optional opt_data, - Optional opt_virtual_device, Optional opt_span) { - runtime::NDArray data = opt_data.value_or(constant->data); - VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); - Span span = opt_span.value_or(constant->span); - - bool all_fields_unchanged = data.same_as(constant->data) && - virtual_device.same_as(constant->virtual_device()) && - span.same_as(constant->span); - - if (!all_fields_unchanged) { - ConstantNode* cow_constant_node = constant.CopyOnWrite(); - cow_constant_node->data = data; - cow_constant_node->virtual_device_ = virtual_device; - cow_constant_node->span = span; - } - return constant; -} - Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -129,7 +90,6 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); - Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index d5cc6608662b..f7045305e90d 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -258,7 +258,6 @@ RELAY_REGISTER_OP("dyn.broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. 
)code" TVM_ADD_FILELINE) .set_num_inputs(2) - .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 132bcaa82922..9c01c40517f4 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -949,11 +949,6 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E return InferTypeLocal(expr); }); -Expr InferTypeExpr(const Expr& expr) { - InferTypeLocal(expr); - return expr; -} - Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 83b0c7cffe86..3971081bf8f8 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -67,7 +67,7 @@ nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par // Batch dimension cannot be modified. ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1); ICHECK_EQ(order[0], 0); - for (size_t i = 0; i + 1 < order.size(); ++i) { + for (size_t i = 0; i < order.size(); ++i) { perm.order[i] = order[i + 1] - 1; } } else { From cfb0ce6998d311eac1d7c3cdaebd72306efb3746 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 6 Jun 2022 13:22:10 -0700 Subject: [PATCH 3/5] - Stick to 'Indexed graph' terminology - More tests --- src/relay/ir/dataflow_matcher.cc | 10 ++-- src/relay/ir/dataflow_matcher_impl.h | 1 - src/relay/ir/indexed_graph.cc | 14 +++-- src/relay/ir/indexed_graph.h | 75 ++++++++++++------------ tests/cpp/relay/ir/indexed_graph_test.cc | 21 +++++++ 5 files changed, 72 insertions(+), 49 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index c537756d0987..df896cb690eb 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -722,11 +722,13 @@ void PatternGrouper::CreateGroup(const Expr& expr) { for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group if (memo.count(output->ref()) == 0) { - // TODO(mbs): The dominates relation is backwards here, and will always fail. - // So all we're doing is failing if an internal node connects to an outside node. - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid + // TODO(mbs): This condition used to also include the following test, which since + // the dominators relation is used back-to-front was always vacuously true. So the + // code is just rejecting the match if a strictly internal node happened to connect + // to an outside node. 
ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); + // Exit because nodes in this pattern's body are used outside the pattern, fusing it + // would be invalid return; } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index ae5168a84d0a..f04190f72e40 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -55,7 +55,6 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual>& memo() const { return memo_; } - const IndexedGraph& dataflow_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 150cfca18b48..f39ff4850eae 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -19,9 +19,8 @@ /*! * \file src/relay/ir/indexed_graph.cc - * \brief A graph representation of the dataflow in a Relay expression or Relay dataflow - * pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control - * flow scopes. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. */ #include "indexed_graph.h" @@ -62,7 +61,10 @@ std::string RefToSummary(const Expr& expr) { return Visitor().VisitExpr(expr); } -std::string RefToSummary(const DFPattern& pattern) { return ""; } +std::string RefToSummary(const DFPattern& pattern) { + // TODO(mbs): Implement as debugging requires. + return ""; +} std::unique_ptr> CreateIndexedGraph(const Expr& expr) { /*! @@ -105,7 +107,7 @@ std::unique_ptr> CreateIndexedGraph(const Expr& expr) { if (var_node == current_let_bound_var_) { // Remember this is a recursive call to the let-rec bound function. // The Annotator functor below will not record any dependency from the let-rec bound - // var to the expression so that the dataflow graph is always a DAG. + // var to the expression so that the indexed graph is always a DAG. VLOG(1) << "Remembering recursive call to '" << var_node->name_hint() << "'"; rec_calls_->emplace(call_node); } @@ -219,7 +221,7 @@ std::unique_ptr> CreateIndexedGraph(const Expr& expr) { void VisitExpr_(const CallNode* call_node) override { auto node = graph_->item_to_node(GetRef(call_node)); if (rec_calls_->count(call_node)) { - // We want the dataflow graph to be a dag, so don't consider a call to a let-rec bound + // We want the indexed graph to be a DAG, so don't consider a call to a let-rec bound // function from inside the function to depend on the let-rec bound var. VLOG(1) << "Ignoring op in call " << RefToSummary(GetRef(call_node)); } else { diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d623029ffb70..c1ce53f40da3 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -19,11 +19,12 @@ /*! * \file src/relay/ir/indexed_graph.h - * \brief A graph representation of the dataflow in a Relay expression or Relay dataflow - * pattern. Nodes in this graph capture generic dataflow inputs, outputs, dominators and control - * flow scopes. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. Each 'indexed graph' node is 1:1 with an expression/pattern 'node', hence the + * term 'IndexedGraph'. Dataflow is captured in a generic representation which is convenient + * for analysis, particularly pattern matching and partitioning. * - * TODO(mbs): Rename to 'DataflowGraph'. 
+ * TODO(mbs): Copied from fuse_ops.cc, consider refactoring to share implementation. */ #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ @@ -45,34 +46,33 @@ namespace relay { using PostDfsIndex = size_t; /*! - * \brief Returns a brief summary of the 'reference' expression or pattern. Only used for - * debugging by IndexedGraph::ToString(). + * \brief Returns a brief summary of the 'reference' expression or pattern. Only used by + * IndexedGraph::ToString() for debugging. */ std::string RefToSummary(const Expr& expr); std::string RefToSummary(const DFPattern& pattern); /*! - * \brief Represents the dataflow of an expression (or dataflow pattern) as a graph which is - * overlaid on the underlying expression (or dataflow pattern) graph. + * \brief Represents the implied dataflow of an expression or (dataflow) pattern as a DAG who's + * nodes are 1:1 with those in the underlying expression/pattern. * - * Each indexed graph node references the corresponding sub-expression (or dataflow sub-pattern) - * node, and captures: - * - dataflow inputs - * - dataflow outputs (or a flag indicating the node is an implied output) - * - dominator parent - * - dominator children - * - basic block + * Each indexed graph node captures: + * - Dataflow inputs. + * - Dataflow outputs (or a flag indicating the node is an implied output). + * - Dominator parent (ie closest node at which all outputs of the current node re-combine). + * - Dominator children (inverse of above). + * - Basic block (ie node representing the body of a function, arm of an if, etc). * * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. * - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities below. */ template class IndexedGraph { public: using TNode = typename T::ContainerType; - /*! \brief A Node in the dataflow graph. */ + /*! \brief A Node in the graph. */ struct Node { /*! \brief Node Constructor * \param ref The expression or dataflow pattern node this indexed graph node is augmenting. @@ -165,25 +165,6 @@ class IndexedGraph { } }; - /*! \brief Construct the domination tree inside IndexedGraph */ - void PostDom() { - for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { - PostDfsIndex index = i - 1; - auto* current = topological_order_[index].get(); - if (current->is_external_) { - current->depth_ = 1; - current->dominator_parent_ = nullptr; - } else { - auto parent = LeastCommonAncestor(current->outputs_); - current->depth_ = parent ? parent->depth_ + 1 : 1; - current->dominator_parent_ = parent; - if (parent) { - parent->dominator_children_.push_back(current); - } - } - } - } - PostDfsIndex size() const { return topological_order_.size(); } Node* item_to_node(const T& item) { return item_to_node(item.get()); } @@ -298,6 +279,25 @@ class IndexedGraph { } private: + /*! \brief Construct the domination tree inside IndexedGraph */ + void PostDom() { + for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { + PostDfsIndex index = i - 1; + auto* current = topological_order_[index].get(); + if (current->is_external_) { + current->depth_ = 1; + current->dominator_parent_ = nullptr; + } else { + auto parent = LeastCommonAncestor(current->outputs_); + current->depth_ = parent ? parent->depth_ + 1 : 1; + current->dominator_parent_ = parent; + if (parent) { + parent->dominator_children_.push_back(current); + } + } + } + } + /*! 
\brief Find the least common ancestor of all outputs of a node */ Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { @@ -345,8 +345,7 @@ class IndexedGraph { } /*! - * \brief Map from underlying sub-expressions or dataflow sub-pattern graph nodes to their - * indexed graph nodes. + * \brief Map from underlying sub-expression or sub-pattern nodes to their indexed graph nodes. */ std::unordered_map node_map_; /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */ diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc index ce383c52909c..7caa8584d04b 100644 --- a/tests/cpp/relay/ir/indexed_graph_test.cc +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -138,6 +138,27 @@ TEST(IndexedGraph, RecursiveExprRegression) { ASSERT_EQ(node->inputs_[3]->index_, 21); // %4 ASSERT_EQ(node->inputs_[4]->index_, 39); // %16 } + + { + // Downstream nodes of %18 + auto node = graph->index_to_node(48); + std::unordered_set::Node*> downstreams; + node->AccumulateDownstreamNodes(&downstreams); + ASSERT_EQ(downstreams.size(), 4); + for (const auto* downstream : downstreams) { + ASSERT_TRUE(downstream->index_ > 49 && downstream->index_ < 52); + } + } + + { + // Dominates relation for %4 + auto upstream = graph->index_to_node(21); + // Path 1: 21->23->24->25->40 + // Path 2: 21->36->37->38->39->40 + // Then 40->43 + auto downstream = graph->index_to_node(43); + ASSERT_TRUE(downstream->Dominates(upstream)); + } } // A module with unused let-bound function. The 'add' operator should have no dominator From 530f25b5f9eb44b4f27d38967999161f49cf7ad0 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 6 Jun 2022 13:40:27 -0700 Subject: [PATCH 4/5] - Stick to 'Indexed graph' terminology - More tests --- src/relay/op/dyn/tensor/transform.cc | 1 + tests/cpp/relay/ir/indexed_graph_test.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index f7045305e90d..d5cc6608662b 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -258,6 +258,7 @@ RELAY_REGISTER_OP("dyn.broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. 
)code" TVM_ADD_FILELINE) .set_num_inputs(2) + .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc index 7caa8584d04b..17ec68261684 100644 --- a/tests/cpp/relay/ir/indexed_graph_test.cc +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -146,7 +146,7 @@ TEST(IndexedGraph, RecursiveExprRegression) { node->AccumulateDownstreamNodes(&downstreams); ASSERT_EQ(downstreams.size(), 4); for (const auto* downstream : downstreams) { - ASSERT_TRUE(downstream->index_ > 49 && downstream->index_ < 52); + ASSERT_TRUE(downstream->index_ >= 49 && downstream->index_ <= 52); } } From bd7200c152ff16753a4bb3feb7ba40ba0039116d Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 6 Jun 2022 13:56:19 -0700 Subject: [PATCH 5/5] - Remove silly unit test --- tests/python/relay/test_dataflow_pattern.py | 26 --------------------- 1 file changed, 26 deletions(-) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 682ed143aaff..f0474c911273 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -632,32 +632,6 @@ def test_at_most_one_parent(): assert not pattern.match(n11) -def test_parallel_injective(): - # Pattern - P = is_op("add")(wildcard(), wildcard()) # 'parent' - I = is_op("squeeze")(wildcard()) | is_op("transpose")( - wildcard() - ) # 'intermediate' ('path' in the code) - C = is_op("left_shift")(wildcard(), wildcard()) # 'child' - pattern = dominates(P, I, C) - - # - # n5(P) - # / \ - # n6(I) n8(I) - # \ / - # n9(C) - # - - x = relay.var("x", shape=(10, 20)) - n5 = relay.add(x, relay.const(1, "float32")) - n6 = relay.squeeze(n5) - n8 = relay.transpose(n5, axes=[0, 1]) - n9 = relay.left_shift(n6, n8) - - assert pattern.match(n9) - - def test_match_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())