From 13f866cc56f6ffd6b271bf1f9c381a78df40f2db Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 2 May 2022 09:39:04 -0700 Subject: [PATCH] [FIX] Avoid stack overflow in TargetHookVisitor with large modules (#11135) Use MixedModeVisitor to not recursively visit let nodes. --- src/relay/transforms/target_hooks.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 1662755ea472..b0ac883623d2 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -35,6 +35,7 @@ class TargetHookVisitor : public tvm::relay::MixedModeVisitor { std::vector pass_list_; /*! \brief Attribute map for all registered targets */ TargetKindAttrMap target_attr_map_; + using tvm::relay::MixedModeVisitor::VisitExpr_; public: TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap("RelayToTIR")) {} @@ -48,6 +49,18 @@ class TargetHookVisitor : public tvm::relay::MixedModeVisitor { return pass_list_; } + void VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); + } + void VisitExpr_(const CallNode* call) override { // Descend the call tree for (auto arg : call->args) {