
make resnet and vgg work
zhiics authored and jroesch committed Mar 26, 2019
1 parent d73204c commit 56b036c
Showing 4 changed files with 82 additions and 33 deletions.
35 changes: 31 additions & 4 deletions python/tvm/relay/vm.py
@@ -23,7 +23,9 @@ def tag(self):
def optimize(expr, mod=None):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=mod)
fused_expr = ir_pass.fuse_ops(ck_expr, mod=mod)
simplified_expr = ir_pass.simplify_inference(ck_expr)
simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod)
fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=mod)
return ck_fused
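
As a usage sketch, this pipeline can be driven by hand (hypothetical shapes and function; optimize is the helper defined above, assuming it is importable from tvm.relay.vm):

import tvm
from tvm import relay
from tvm.relay import vm

# Hypothetical example: a small dense layer to run through the pipeline.
x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(8, 16), dtype="float32")
f = relay.Function([x, w], relay.nn.dense(x, w))

mod = relay.Module()
# infer_type -> simplify_inference -> infer_type -> fuse_ops -> infer_type
opt_f = vm.optimize(f, mod)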

@@ -64,22 +66,27 @@ def convert(args):

return cargs

def eval_vm(mod, ctx, *args):
def eval_vm(mod, ctx, *args, **kwargs):
"""
Evaluate a module on a given context with the provided arguments.
Parameters
----------
mod: relay.Module
The module to evaluate; its entry_func is executed.
ctx: tvm.Context
The TVM context to execute on.
args: ...
args: List[Union[tvm.NDArray, np.ndarray]]
The positional arguments to the entry function.
kwargs: Dict[str, Union[tvm.NDArray, np.ndarray]]
The keyword arguments to the entry function.
"""
main_func = mod[mod.entry_func]

if len(main_func.params) == 0 and isinstance(main_func.body, GlobalVar):
if not main_func.params and isinstance(main_func.body, GlobalVar):
main_func = eta_expand(main_func.body, mod)

assert isinstance(main_func, Function)
@@ -88,6 +95,26 @@ def eval_vm(mod, ctx, *args):

args = list(args)
assert isinstance(args, list)

params = main_func.params
if kwargs:
param_names = [param.name_hint for param in params]
arg_count = len(args)

for i, name in enumerate(param_names):
if i < arg_count:
if name in kwargs:
raise Exception("Duplicate argument found in both inputs "
"(at position: {0}) and keyword argument "
"(with name: {1})".format(i, name))
else:
args.append(kwargs[name])

if len(args) != len(params):
raise Exception("Mismatch found between the expected and provided "
"arguments, expected: {0}, provided: "
"{1}".format(len(params), len(args)))

cargs = convert(args)

result = _vm._evaluate_vm(mod, ctx.device_type, ctx.device_id, *cargs)
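
A usage sketch of the new keyword handling (assumptions: eval_vm is importable from tvm.relay.vm, and the module's entry function binds parameters named x and y):

import numpy as np
import tvm
from tvm import relay
from tvm.relay import vm

x = relay.var("x", shape=(2, 2), dtype="float32")
y = relay.var("y", shape=(2, 2), dtype="float32")
mod = relay.Module()
mod[mod.entry_func] = relay.Function([x, y], relay.add(x, y))

a = np.ones((2, 2), dtype="float32")
b = np.ones((2, 2), dtype="float32")

# Positional, keyword, or mixed calls resolve to the same parameters;
# passing a positionally and also as x=a raises the duplicate error.
res = vm.eval_vm(mod, tvm.cpu(0), a, y=b)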
23 changes: 18 additions & 5 deletions src/relay/vm/compiler.cc
@@ -10,11 +10,14 @@
#include <tvm/relay/error.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include "../backend/compile_engine.h"
#include "../../runtime/naive_allocator.h"

#include <vector>
#include <iostream>
#include <unordered_map>
#include <unordered_set>

using namespace tvm::runtime;

@@ -82,6 +85,11 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
std::unordered_map<Var, size_t, NodeHash, NodeEqual> var_map;
size_t stack_index;
CompileEngine engine;

/*! \brief The functions that have been lowered. */
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;

/*! \brief Global shared metadata. */
VMCompilerContext* context;

VMCompiler(VMCompilerContext* context) :
@@ -293,8 +301,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine->Lower(key);
// TODO: support lowered funcs for multiple targets
CHECK(cfunc->funcs.size() == 1);
auto op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
auto op_index = -1;
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
seen_funcs[cfunc->funcs[0]] = op_index;
LOG(INFO) << "lowered_funcs: " << cfunc->funcs[0].operator->()->name;
} else {
op_index = seen_funcs[cfunc->funcs[0]];
}
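
The memoization above gives each distinct lowered function a single slot in lowered_funcs, so operators repeated across ResNet/VGG blocks share one packed-function index. Schematically, in Python (hypothetical names, not part of the commit):

lowered_funcs = []  # analogous to context->lowered_funcs
seen_funcs = {}     # lowered func -> its index in lowered_funcs

def intern(func):
    # Return the existing slot if this function was lowered before.
    if func not in seen_funcs:
        seen_funcs[func] = len(lowered_funcs)
        lowered_funcs.append(func)
    return seen_funcs[func]

assert intern("conv2d_0") == intern("conv2d_0")  # duplicates share a slot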

// If Tensor, 1
// If Tuple, size of tuple
@@ -486,13 +501,11 @@ void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
}
}

// Verify

VirtualMachine CompileModule(const Module& mod_ref) {
Module mod = mod_ref;

// Run some optimizations first; this code should
// be moved to the pass manager.

mod = OptimizeModule(mod);

VirtualMachine vm;
29 changes: 14 additions & 15 deletions src/relay/vm/vm.cc
@@ -14,6 +14,7 @@

#include <vector>
#include <iostream>
#include <chrono>

using namespace tvm::runtime;

@@ -389,24 +390,26 @@ VMObject VirtualMachine::Invoke(const VMFunction& func, const std::vector<VMObject>& args) {
return stack.back();
}

VMObject VirtualMachine::Invoke(const GlobalVar& global, const std::vector<VMObject>& args) {
VMObject VirtualMachine::Invoke(const GlobalVar& global,
const std::vector<VMObject>& args) {
auto func_index = this->global_map[global];
RELAY_LOG(INFO) << "Invoke Global " << global << " at index " << func_index
<< std::endl;
return Invoke(this->functions[func_index], args);
}

void InvokePacked(const PackedFunc& func, size_t arg_count, size_t output_size, std::vector<VMObject>& stack) {
void InvokePacked(const PackedFunc& func, size_t arg_count, size_t output_size,
std::vector<VMObject>& stack) {
auto stack_end = stack.size() - 1;
RELAY_LOG(INFO) << "arg_count: " << arg_count;
RELAY_LOG(INFO) << "arg_count: " << arg_count << " output_size: " << output_size;
CHECK(arg_count <= stack.size());

std::vector<TVMValue> values(arg_count);
std::vector<int> codes(arg_count);
runtime::TVMArgsSetter setter(values.data(), codes.data());

auto argument_start = stack.size() - arg_count;
RELAY_LOG(INFO) << "ArgumentStart=" << argument_start << std::endl;
RELAY_LOG(INFO) << "argument_start = " << argument_start << std::endl;
for (size_t i = 0; i < arg_count; i++) {
NDArray data = ToNDArray(stack[argument_start + i]);
setter(i, data);
@@ -415,20 +418,10 @@ void InvokePacked(const PackedFunc& func, size_t arg_count, size_t output_size,
TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arg_count), &rv);

// // Fix the object at return value position
// if (output_size == 1) {
// stack[stack.size() - 1] = stack[stack.size() - 2];
// } else {
// auto adt = std::dynamic_pointer_cast<VMDatatypeCell>(stack.back().ptr);
// for (size_t i = 0; i < output_size; ++i) {
// adt->fields[i] = stack[stack.size() - output_size - 1 + i];
// }
// }

// We can do this more efficiently by reverse laying out the arguments
// and just shrinking the stack.
stack[stack.size() - arg_count] = stack[stack_end];
RELAY_LOG(INFO) << "ShrinkBy=" << arg_count - output_size << std::endl;
RELAY_LOG(INFO) << "ShrinkBy = " << arg_count - output_size << std::endl;
stack.resize(stack.size() - (arg_count - output_size));
}
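
To make the shrink arithmetic concrete, a Python sketch of the single-output case (the output object pushed by the caller counts among the arg_count stack entries):

# Stack before the packed call: the return-value slot is on top.
stack = ["r0", "arg0", "arg1", "out"]
arg_count, output_size = 3, 1
stack[len(stack) - arg_count] = stack[-1]           # result moves into arg0's slot
del stack[len(stack) - (arg_count - output_size):]  # pop the leftover slots
assert stack == ["r0", "out"]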

@@ -770,7 +763,13 @@ TVM_REGISTER_API("relay._vm._evaluate_vm")
vm_args.push_back(obj);
}

auto start = std::chrono::high_resolution_clock::now();
auto result = EvaluateModule(module, {ctx}, vm_args);
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
LOG(INFO) << "Inference time: " << duration << " us";
RELAY_LOG(INFO) << "Returning results\n";
*ret = VMToValue(std::get<1>(result), std::get<0>(result));
});
28 changes: 19 additions & 9 deletions tests/python/frontend/mxnet/benchmark_vm.py
@@ -10,6 +10,7 @@


def benchmark_execution(mx_symbol,
measure=False,
data_shape=(1, 3, 224, 224),
out_shape=(1, 1000),
dtype='float32'):
@@ -37,19 +38,26 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):

m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("data", tvm.nd.array(x.astype(dtype)))
m.set_input("data", x)
m.set_input(**params)
m.run()
out = m.get_output(0, tvm.nd.empty(out_shape, dtype))

if measure:
print("Evaluate graph runtime inference time cost...")
ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
# Measured in milliseconds.
prof_res = np.array(ftimer().results) * 1000
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))

return out.asnumpy()

def get_tvm_vm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
func, params = get_func_param(symbol, x, args, auxs)
params = [params[k] for k in params]
params = [x] + params
ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
result = ex.evaluate(func)(*params)
return result.asnumpy()
result = ex.evaluate(func)(x, **params)
return result.asnumpy().astype(dtype)

# random input
x = np.random.uniform(size=data_shape).astype(dtype)
@@ -58,8 +66,10 @@ def get_tvm_vm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):

_, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
assert "data" not in args
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
vm_out = get_tvm_vm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
tvm_out = get_tvm_output(mx_symbol, tvm.nd.array(x.astype(dtype)), args,
auxs, target, ctx, dtype)
vm_out = get_tvm_vm_output(mx_symbol, tvm.nd.array(x.astype(dtype)), args,
auxs, target, ctx, dtype)
tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
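
For a concrete call into this harness, a hypothetical MXNet symbol (not from the commit) could be benchmarked as:

import mxnet as mx

# Hypothetical smoke test: a one-layer MLP. "data" stays a free input,
# so the assert above ("data" not in args) holds.
data = mx.sym.Variable("data")
fc = mx.sym.FullyConnected(data, num_hidden=1000, name="fc")
net = mx.sym.softmax(fc, name="softmax")
benchmark_execution(net, measure=False)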


@@ -126,8 +136,8 @@ def relay_compose(F, **kwargs):


if __name__ == '__main__':
test_mlp()
# test_resnet()
# test_mlp()
test_resnet()
# test_vgg()
# test_multi_outputs()
# test_dqn()
