Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][tensor_array] test tensor_array in vm #4608

Merged
merged 2 commits into from
Jan 3, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from . import _vm
from . import vmobj as _obj
from .interpreter import Executor
Expand All @@ -34,7 +35,9 @@
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, _obj.Object):
if isinstance(arg, _expr.Constant):
cargs.append(_obj.Tensor(arg.data))
elif isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg))
Expand All @@ -43,8 +46,12 @@ def _convert(arg, cargs):
for field in arg:
_convert(field, field_args)
cargs.append(_obj.tuple_object(field_args))
elif isinstance(arg, (_base.numeric_types, bool)):
dtype = "int32" if isinstance(arg, (int, bool)) else "float32"
value = _obj.Tensor(np.array(arg, dtype=dtype))
cargs.append(value)
else:
raise "Unsupported type: %s" % (type(arg))
raise TypeError("Unsupported type: %s" % (type(arg)))


def convert(args):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __init__(self, prelude, dtype):
self.dtype = dtype

def get_name(self, canonical):
"""Get name corresponding to the caninical name"""
"""Get name corresponding to the canonical name"""
return self.prelude.get_name(canonical, self.dtype)

def get_var(self, canonical):
"""Get var corresponding to the caninical name"""
"""Get var corresponding to the canonical name"""
return self.prelude.get_var(canonical, self.dtype)

def define_tensor_adt(self):
Expand Down
36 changes: 12 additions & 24 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
Expand Down Expand Up @@ -73,8 +69,6 @@ using namespace relay::transform;
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

void InstructionPrint(std::ostream& os, const Instruction& instr);

// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
virtual ~MatchValue() {}
Expand Down Expand Up @@ -156,12 +150,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
if (pattern.as<PatternWildcardNode>()) {
// We ignore wildcard binding since it's not producing new vars
return then_branch;
} else if (pattern.as<PatternVarNode>()) {
auto pat = pattern.as<PatternVarNode>();
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
} else if (const auto* pvn = pattern.as<PatternVarNode>()) {
auto cond = std::make_shared<VarBinding>(pvn->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
} else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;

size_t field_index = 0;
Expand All @@ -173,13 +165,12 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
const auto* pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << AsText(pattern, false);
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
auto d = std::make_shared<AccessField>(data, field_index++);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
Expand Down Expand Up @@ -633,7 +624,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
// and emit a call to allocate the data structure.
auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
// If we are calling a variable, it must be the case that it is a closure so we
// emit invoke closure here.
Expand Down Expand Up @@ -675,16 +666,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}

void CompileTreeNode(TreeObjectPtr tree) {
if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
VisitExpr(node->body);
} else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
Emit(Instruction::Fatal());
} else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
} else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
// For Tag compariton, generate branches
auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
auto r = CompileMatchValue(cond->obj);
Emit(Instruction::GetTag(r, NewRegister()));
auto operand1 = last_register_;
Expand All @@ -707,8 +695,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
} else {
// For other non-branch conditions, move to then_branch directly
auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[cond->var] = CompileMatchValue(cond->val);
auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
CompileTreeNode(node->then_branch);
}
}
Expand Down
18 changes: 11 additions & 7 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break;
}
case Opcode::AllocStorage: {
os << "alloc_storage " <<
instr.dst << " " <<
instr.alloc_storage.allocation_size << " " <<
os << "alloc_storage $" <<
instr.dst << " $" <<
instr.alloc_storage.allocation_size << " $" <<
instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint);
break;
Expand Down Expand Up @@ -771,12 +771,14 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
auto obj = (*dt_cell)[fi];
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
setter(idx++, tensor->data);
}
} else {
const auto* tensor = args[i].as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< args[i]->GetTypeKey();
setter(idx++, tensor->data);
}
}
Expand Down Expand Up @@ -823,7 +825,8 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result;
const auto& obj = ReadRegister(r);
const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< obj->GetTypeKey();
NDArray array = tensor->data.CopyTo({kDLCPU, 0});

if (array->dtype.bits <= 8) {
Expand Down Expand Up @@ -984,7 +987,8 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
CHECK(tensor != nullptr) << "Expect tensor object, but received: "
<< shape_tensor_obj->GetTypeKey();
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
const DLTensor* dl_tensor = shape_tensor.operator->();
CHECK_EQ(dl_tensor->dtype.code, 0u);
Expand Down
Loading