[Relay][VM] Add ReshapeTensor instruction in the VM to replace the reshape op #6089

Merged (5 commits) on Jul 21, 2020
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/vm.h
@@ -42,6 +42,17 @@ struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
}
};

/*!
* \brief Attributes for VM reshape_tensor operator.
*/
struct ReshapeTensorAttrs : public tvm::AttrsNode<ReshapeTensorAttrs> {
Array<PrimExpr> newshape;

TVM_DECLARE_ATTRS(ReshapeTensorAttrs, "relay.attrs.ReshapeTensorAttrs") {
TVM_ATTR_FIELD(newshape).describe("The new shape of the output tensor");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VM_H_
14 changes: 14 additions & 0 deletions include/tvm/runtime/vm.h
@@ -115,6 +115,7 @@ enum class Opcode {
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
ReshapeTensor = 18U,
};

/*! \brief A single virtual machine instruction.
@@ -249,6 +250,10 @@ struct Instruction {
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
struct /* ReshapeTensor Operands */ {
RegName tensor;
RegName newshape;
} reshape_tensor;
};

/*!
@@ -401,6 +406,15 @@
*/
static Instruction ShapeOf(RegName tensor, RegName dst);

/*!
* \brief Reshape the tensor given the new shape.
* \param tensor The input tensor.
* \param newshape The shape tensor.
* \param dst The destination to store the output tensor with new shape.
* \return The reshape tensor instruction.
*/
static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
@@ -246,9 +246,9 @@ def lower_call(call, inputs, target):
new_fields.append(field)
ret_type = _ty.TupleType(new_fields)

- is_dyn = _ty.type_has_any(call.checked_type)
+ is_dyn = _ty.is_dynamic(call.checked_type)
for arg in call.args:
- is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
+ is_dyn = is_dyn or _ty.is_dynamic(arg.checked_type)

# check if in the AutoTVM tracing mode, and disable if op is not in wanted list
env = autotvm.task.TaskExtractEnv.current
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/vm.py
@@ -27,7 +27,7 @@
import tvm.runtime.vm as vm_rt
from tvm import autotvm
from tvm.relay import expr as _expr
- from tvm.relay.ty import type_has_any
+ from tvm.relay.ty import is_dynamic
from tvm.relay.backend.interpreter import Executor
from . import _vm

@@ -257,7 +257,7 @@ def _make_executor(self, expr=None):
def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
ret_type = self.mod["main"].checked_type.ret_type
- if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
+ if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
self.target):
raise ValueError(
"Virtual Machine only supports dynamic graphs on CPU, got output type",
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
@@ -359,7 +359,7 @@ def _make_executor(self, expr=None):
if expr:
self.mod["main"] = expr
ret_type = self.mod["main"].checked_type.ret_type
- if _ty.type_has_any(ret_type):
+ if _ty.is_dynamic(ret_type):
raise ValueError("Graph Runtime only supports static graphs, got output type",
ret_type)
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
17 changes: 17 additions & 0 deletions python/tvm/relay/op/vm/vm.py
@@ -81,3 +81,20 @@ def shape_func(func, inputs, outputs, is_inputs):
The shape function expression.
"""
return _ffi_api.shape_func(func, inputs, outputs, is_inputs)


def reshape_tensor(data, shape, newshape):
"""Invoke the VM ReshapeTensor instruction.

Parameters
----------
data : tvm.relay.Expr
The input data.

shape : tvm.relay.Expr
The tensor that carries the new shape at run time.

newshape : List[tvm.ir.PrimExpr]
The statically known new shape, used for type checking.

Returns
-------
result : tvm.relay.Expr
The reshape_tensor call expression.
"""
return _ffi_api.reshape_tensor(data, shape, newshape)
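
A minimal usage sketch (illustrative only, not part of this diff): constructing a `vm.reshape_tensor` call by hand. In practice the ManifestAlloc pass inserts this op; the re-export path `tvm.relay.op.vm` and the constant shape tensor below are assumptions.

```python
# Sketch: building the vm.reshape_tensor call directly (ManifestAlloc normally does this).
from tvm import relay
from tvm.relay.op import vm as vm_ops  # assumes reshape_tensor is exported here

x = relay.var("x", shape=(2, 3), dtype="float32")
shape_tensor = relay.const([3, 2], dtype="int64")                # runtime shape tensor
call = vm_ops.reshape_tensor(x, shape_tensor, newshape=[3, 2])   # static shape for type inference
print(call)  # a Call node of the "vm.reshape_tensor" op carrying ReshapeTensorAttrs
```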
80 changes: 62 additions & 18 deletions python/tvm/relay/transform/memory_alloc.py
@@ -19,7 +19,7 @@
A pass for manifesting explicit memory allocations.
"""
import numpy as np
- from ..expr_functor import ExprMutator
+ from ..expr_functor import ExprVisitor, ExprMutator
from ..scope_builder import ScopeBuilder
from . import transform
from .. import op
@@ -38,13 +38,39 @@ def is_primitive(call):
return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1


class CheckReshapeOnly(ExprVisitor):
"""A pass to check if the fused op contains only reshape ops."""
def __init__(self):
super().__init__()
self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
op.get("dyn.reshape")]
self.reshape_only = True

def visit_call(self, call):
if not self.reshape_only:
return
if call.op not in self._reshape_ops:
self.reshape_only = False
for arg in call.args:
self.visit(arg)


def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
check = CheckReshapeOnly()
check.visit(func)
return check.reshape_only


class ManifestAllocPass(ExprMutator):
"""A pass for explicitly manifesting all memory allocations in Relay."""

def __init__(self, target_host):
self.invoke_tvm = op.vm.invoke_tvm_op
self.shape_func = op.vm.shape_func
self.shape_of = op.vm.shape_of
self.reshape_tensor = op.vm.reshape_tensor
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
@@ -121,8 +147,8 @@ def visit_let(self, let):

return scope.get()

- def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
- """Generate the code for invoking a TVM op with a dynamic shape."""
+ def emit_shape_func(self, scope, func, new_args):
+ """Insert the shape function given a primitive function."""
shape_func_ins = []
engine = compile_engine.get()
cfunc = engine.lower_shape_func(func, self.target_host)
@@ -165,9 +191,14 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
expr.Tuple(out_shapes), is_inputs)

scope.let("shape_func", shape_call)
return out_shapes

def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
"""Generate the code for invoking a TVM op with a dynamic shape."""
out_shapes = self.emit_shape_func(scope, func, new_args)

storages = []
- for out_shape, out_type in zip(out_shapes, out_types):
+ for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
size = self.compute_storage_in_relay(
out_shape, out_type.dtype)
alignment = self.compute_alignment(out_type.dtype)
@@ -191,8 +222,18 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
scope.let("", invoke)
return to_tuple_type(ret_type, tuple_outs.fields)

def emit_reshape_tensor(self, scope, func, new_args, ret_type):
if self.is_dynamic(ret_type):
out_shapes = self.emit_shape_func(scope, func, new_args)
shape_expr = out_shapes[0]
else:
# constant output shape
shape = [int(dim) for dim in ret_type.shape]
shape_expr = expr.const(shape, dtype=self.compute_dtype)
return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)

def is_dynamic(self, ret_type):
- is_dynamic = ty.type_has_any(ret_type)
+ is_dynamic = ty.is_dynamic(ret_type)
# TODO(@jroesch): restore this code, more complex than it seems
# for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
@@ -208,22 +249,25 @@ def visit_call(self, call):
ret_type = call.checked_type
out_types = flatten_tuple_type(ret_type)

+ if is_reshape_only(call.op):
+ # Handle fused op that only contains reshape op
+ return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)
+
if self.is_dynamic(ret_type):
# Handle dynamic case.
return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
- else:
- # Handle static case.
- outs = []
- for i, out_ty in enumerate(out_types):
- out = self.make_static_allocation(scope, out_ty, i)
- outs.append(out)
-
- output = expr.Tuple(outs)
- invoke = self.invoke_tvm(call.op, ins, output)
- scope.let("", invoke)
- return to_tuple_type(ret_type, output.fields)
- else:
- return super().visit_call(call)
+
+ # Handle static case.
+ outs = []
+ for i, out_ty in enumerate(out_types):
+ out = self.make_static_allocation(scope, out_ty, i)
+ outs.append(out)
+
+ output = expr.Tuple(outs)
+ invoke = self.invoke_tvm(call.op, ins, output)
+ scope.let("", invoke)
+ return to_tuple_type(ret_type, output.fields)
+ return super().visit_call(call)


@transform.function_pass(opt_level=0)
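For context, a minimal end-to-end sketch (assuming the usual Relay VM executor APIs) of the case this pass now special-cases: a fused reshape over a dynamically shaped input, which should be lowered to `vm.reshape_tensor` rather than a compute kernel.

```python
# Sketch: a dynamically shaped reshape executed through the VM (assumed APIs).
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
y = relay.reshape(x, newshape=(-1, 2))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
data = np.random.rand(3, 4).astype("float32")
out = ex.evaluate()(data)
print(out.asnumpy().shape)  # expected (6, 2), produced by the ReshapeTensor instruction
```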
4 changes: 2 additions & 2 deletions python/tvm/relay/ty.py
@@ -25,8 +25,8 @@

Any = _ffi_api.Any

- def type_has_any(tensor_type):
- """Check whether type has any as a shape.
+ def is_dynamic(tensor_type):
+ """Check whether the type has Any or symbolic variables in its shape.

tensor_type : Type
The type to be inspected
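A rough illustration of the widened semantics (the helper name comes from this diff; the behavior on symbolic shape variables is the point of the `IsDynamicVisitor` change below and is assumed here):

```python
# Sketch: is_dynamic flags both Any and symbolic dims, not just Any (assumed behavior).
import tvm
from tvm import relay, tir
from tvm.relay import ty as _ty

static_t = relay.TensorType((2, 3), "float32")
any_t = relay.TensorType((relay.Any(), 3), "float32")
sym_t = relay.TensorType((tir.Var("n", "int64"), 3), "float32")

print(_ty.is_dynamic(static_t))  # False
print(_ty.is_dynamic(any_t))     # True
print(_ty.is_dynamic(sym_t))     # True after this change
```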
2 changes: 1 addition & 1 deletion src/relay/analysis/util.cc
@@ -424,7 +424,7 @@ struct IsDynamicVisitor : public TypeVisitor {
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
- if (dim.as<AnyNode>()) {
+ if (dim.as<tir::IntImmNode>() == nullptr) {
is_dyn = true;
break;
}
10 changes: 10 additions & 0 deletions src/relay/backend/vm/compiler.cc
@@ -284,6 +284,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::ReshapeTensor:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
@@ -601,6 +602,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->VisitExpr(args[0]);
Emit(Instruction::ShapeOf(last_register_, NewRegister()));
})
.Match("vm.reshape_tensor",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 2u);
this->VisitExpr(args[0]);
auto tensor_reg = last_register_;
this->VisitExpr(args[1]);
auto shape_reg = last_register_;
Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
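A small sketch (assumed workflow) for checking that the compiler actually emits the new instruction. The exact textual form of `exe.bytecode` is an assumption, but a dynamic reshape compiled for the VM should show a reshape_tensor instruction instead of an invocation of a packed reshape kernel.

```python
# Sketch: compile a dynamic reshape and inspect the VM bytecode (assumed APIs).
import tvm
from tvm import relay
from tvm.relay.backend import vm as relay_vm

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, newshape=(-1, 2))))
exe = relay_vm.compile(mod, target="llvm")
print(exe.bytecode)  # expect a reshape_tensor instruction in main
```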
2 changes: 2 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -576,6 +576,8 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
infer_dim = indexdiv(infer_dim, oshape[i]);
}
}
arith::Analyzer ana;
infer_dim = ana.Simplify(infer_dim);
oshape.Set(infer_idx, infer_dim);
}

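The added `Simplify` call mainly matters when the `-1` dimension is inferred from symbolic input extents; a standalone illustration of the same simplification through the Python arithmetic analyzer (assumed to mirror the C++ `ana.Simplify` call):

```python
# Sketch: simplifying the inferred -1 dimension, e.g. reshaping (n, 4) to (-1, 2).
import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
n = tir.Var("n", "int64")
infer_dim = tir.indexdiv(n * 4, 2)  # raw inferred extent before simplification
print(ana.simplify(infer_dim))      # expected to fold to n*2
```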
37 changes: 37 additions & 0 deletions src/relay/op/vm/vm.cc
@@ -37,6 +37,7 @@
namespace tvm {
namespace relay {

// vm.shape_func
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);

RELAY_REGISTER_OP("vm.shape_of")
@@ -133,6 +134,7 @@ RELAY_REGISTER_OP("vm.shape_func")
return {topi::identity(inputs[0])};
});

// vm.invoke_tvm_op
bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
@@ -181,5 +183,40 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
return {topi::identity(inputs[0])};
});

// vm.reshape_tensor
TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);

bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3u);
auto reshape_attrs = attrs.as<ReshapeTensorAttrs>();
CHECK(reshape_attrs);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt) << "input must be tensor type";
reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype));
return true;
}

RELAY_REGISTER_OP("vm.reshape_tensor")
.describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor")
.add_argument("shape", "Tensor", "The output shape tensor")
.add_type_rel("ReshapeTensor", ReshapeTensorRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
.set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
10 changes: 10 additions & 0 deletions src/runtime/vm/executable.cc
@@ -422,6 +422,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
case Opcode::ReshapeTensor: {
// Number of fields = 3
fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
@@ -693,6 +698,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
case Opcode::ReshapeTensor: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
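A quick round-trip sketch (assumed Executable APIs) of what these serializer cases enable: saving and reloading an executable whose bytecode contains the new instruction.

```python
# Sketch: serialize and reload a VM executable containing ReshapeTensor (assumed APIs).
import tvm
from tvm import relay
from tvm.relay.backend import vm as relay_vm
from tvm.runtime import vm as vm_rt

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, newshape=(-1, 2))))
exe = relay_vm.compile(mod, target="llvm")

code, lib = exe.save()                        # serialized bytecode + runtime module
loaded = vm_rt.Executable.load_exec(code, lib)
print(loaded.bytecode == exe.bytecode)        # round-trip should preserve the instruction
```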