diff --git a/src/relay/transforms/label_ops.cc b/src/relay/transforms/label_ops.cc index 861342b03a76..b23b1d92e77a 100644 --- a/src/relay/transforms/label_ops.cc +++ b/src/relay/transforms/label_ops.cc @@ -77,6 +77,25 @@ class LabelOpsMutator : public MixedModeMutator { } return std::move(f); } + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + this->Mutate(op->var); + this->Mutate(op->value); + }; + auto post_visit = [this](const LetNode* op) { + Var var = Downcast(this->Mutate(op->var)); + auto value = this->Mutate(op->value); + auto body = this->Mutate(op->body); + auto expr = GetRef(op); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } Expr Rewrite_(const CallNode* op, const Expr& post) final { auto updated = MixedModeMutator::Rewrite_(op, post);