Skip to content

Commit

Permalink
Interpreter call in FoldConstant now always uses graph executor with …
Browse files Browse the repository at this point in the history
…link-params=0

Addressed issue apache#10390

Change-Id: I1a6b2dd27845f9292f1e07f9da1b9be722481f46
  • Loading branch information
d-smirnov committed Mar 9, 2022
1 parent 060d9d2 commit c0e19f4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/op.h>
Expand Down Expand Up @@ -254,8 +255,13 @@ class ConstantFolder : public MixedModeMutator {
// needed for both execution and creation(due to JIT)
With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());

Map<String, ObjectRef> dict =
(module_->attrs.defined()) ? module_->attrs->dict : Map<String, ObjectRef>();
Map<String, ObjectRef> dict = (module_->attrs.defined())
? Map<String, ObjectRef>(module_->attrs.CopyOnWrite()->dict)
: Map<String, ObjectRef>();

// always use graph executor with no link-params
dict.Set(tvm::attr::kExecutor,
relay::Executor::Create("graph", {{"link-params", Bool(false)}}));
Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(),
eval_cpu_dev_, eval_cpu_target_, dict));
VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result);
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.backend import Executor
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload
Expand Down Expand Up @@ -369,6 +370,25 @@ def before():
tvm.ir.assert_structural_equal(run_infer_type(before_mod["main"]), after_mod["main"])


def test_pass_link_params():
"""
This test checks ensures that proper executor is passed to interpreter instance
The test will fail if FoldConstant does not override the executor due to "int8"
is not supported in ScheduleBuilder
"""
def expr():
z = relay.const(10, dtype="int8")
return relay.cast(z, dtype="int32")



mod = tvm.IRModule.from_expr(expr())
mod = tvm.relay.transform.InferType()(mod)
# Add executor with link-params
mod = mod.with_attr("executor", Executor('aot', {'link-params': True}))
mod = tvm.relay.transform.FoldConstant()(mod)


if __name__ == "__main__":
import sys
import pytest
Expand Down

0 comments on commit c0e19f4

Please sign in to comment.