Skip to content

Commit

Permalink
[FIX][VM] Fix relay vm optimize (apache#6322)
Browse files Browse the repository at this point in the history
* [FIX][VM] Fix relay vm optimize

* retrigger ci
  • Loading branch information
zhiics authored and Trevor Morris committed Aug 26, 2020
1 parent e78d262 commit d5edac5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def codegen(self):
"""Generate the kernel library."""
self._codegen()

def optimize(self, mod, target=None, params=None):
def optimize(self, mod, target=None, target_host=None, params=None):
"""Helper method that optimizes a Relay module via VM.
Parameters
Expand All @@ -149,6 +149,11 @@ def optimize(self, mod, target=None, params=None):
target : str, :any:`tvm.target.Target`, or dict of str (i.e.
device/context name) to str/tvm.target.Target, optional
target_host : str or :any:`tvm.target.Target`, optional
The compilation target for host.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Expand All @@ -162,9 +167,10 @@ def optimize(self, mod, target=None, params=None):
The parameters of the final module.
"""
target = self._update_target(target)
target_host = self._update_target_host(target, target_host)
if params:
self.set_params(params)
return self._optimize(mod, target), self.get_params()
return self._optimize(mod, target, target_host), self.get_params()

def get_exec(self):
"""Get the VM executable.
Expand Down
11 changes: 6 additions & 5 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Obje
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
*rv = this->OptimizeModule(args[0], args[1]);
CHECK_EQ(args.num_args, 3);
*rv = this->OptimizeModule(args[0], args[1], args[2]);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
Expand Down Expand Up @@ -835,7 +835,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
target_host_ = target_host;

// Run the optimizations necessary to target the VM.
context_.module = OptimizeModule(mod, targets_);
context_.module = OptimizeModule(mod, targets_, target_host_);

// Populate the global map.
//
Expand Down Expand Up @@ -923,7 +923,8 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
return transform::Sequential(pass_seqs);
}

IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
const Target& target_host) {
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
Expand Down Expand Up @@ -988,7 +989,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
// external codegen.
pass_seqs.push_back(transform::Inline());

pass_seqs.push_back(MemoryOpt(this->target_host_));
pass_seqs.push_back(MemoryOpt(target_host));

transform::Sequential seq(pass_seqs);
transform::PassContext pass_ctx = PassContext::Current();
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ class VMCompiler : public runtime::ModuleNode {
void Codegen();

protected:
IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets,
const Target& target_host);

void PopulateGlobalMap();

Expand Down
12 changes: 11 additions & 1 deletion tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,20 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)

def test_vm_optimize_dynamic():
dtype = 'float32'
x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype=dtype)
y = relay.var('y', shape=(relay.Any(), relay.Any()), dtype=dtype)
mod = tvm.IRModule()
mod['main'] = relay.Function([x, y], relay.add(x, y))
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, target="llvm")
assert "shape_func" in opt_mod.astext(False)

def test_vm_optimize():
mod, params = testing.synthetic.get_workload()
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, "llvm", params)
opt_mod, _ = comp.optimize(mod, target="llvm", params=params)

def test_loop_free_var():
x = relay.var('x', shape=(), dtype='int32')
Expand Down

0 comments on commit d5edac5

Please sign in to comment.