[RELAY][OP] Move computes to cxx, enable concat as injective (apache#…
tqchen authored and AWS Neo committed Feb 20, 2019
1 parent 783599a commit c9e42c0
Showing 8 changed files with 145 additions and 271 deletions.
16 changes: 6 additions & 10 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -236,18 +236,14 @@ def visit_call(self, call):
self.lowered_funcs.add(loweredf)

inputs = []
tuple_arg_count = 0
# flatten tuple in the call.
for arg in call.args:
res = self.visit(arg)
if isinstance(arg.checked_type, TupleType):
tuple_arg_count += 1
inputs.append(self.visit(arg))
# We need to specially handle tuple inputs and
# tuple output cases.
# Tuple input function(e.g. concat)
if tuple_arg_count:
assert len(call.args) == 1
assert isinstance(inputs[0], tuple)
inputs = list(inputs[0])
assert isinstance(res, tuple)
inputs += res
else:
inputs.append(res)

inputs = [x.to_json() for x in inputs]
op_name = cached_func.func_name
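For readability, here is how the flattening loop reads after this change, reassembled from the hunk above (surrounding context and indentation are assumed):

inputs = []
# flatten tuple in the call.
for arg in call.args:
    res = self.visit(arg)
    if isinstance(arg.checked_type, TupleType):
        # Tuple-typed arguments (e.g. the inputs of concatenate) come back
        # as a Python tuple of graph node references; splice them in.
        assert isinstance(res, tuple)
        inputs += res
    else:
        inputs.append(res)

inputs = [x.to_json() for x in inputs]
op_name = cached_func.func_name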
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/mxnet.py
@@ -589,11 +589,11 @@ def from_mxnet(symbol,
shape, dtype = _update_shape_dtype(shape, dtype, params)
sym = _from_mxnet_impl(symbol, shape, dtype)
elif isinstance(symbol, mx.gluon.HybridBlock):
if args_params is not None or aux_params is not None:
if arg_params is not None or aux_params is not None:
raise ValueError("arg_params and aux_params are not used when importing HybridBlock")
params = {}
for k, v in symbol.collect_params().items():
params[k] = tvm.nd.array(v.data().asnumpy())
params[k] = _nd.array(v.data().asnumpy())
data = mx.sym.Variable("data")
sym = symbol(data)
shape, dtype = _update_shape_dtype(shape, dtype, params)
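The Gluon HybridBlock branch touched here requires arg_params/aux_params to be left unset; parameters are collected from the block itself. A rough usage sketch (the model constructor and the (sym, params) return signature are assumptions based on the API of this period):

import mxnet as mx
from tvm import relay

block = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True)
# Do not pass arg_params/aux_params for a HybridBlock; they are taken
# from block.collect_params() as shown in the diff above.
sym, params = relay.frontend.from_mxnet(block, shape={"data": (1, 3, 224, 224)})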
193 changes: 3 additions & 190 deletions python/tvm/relay/op/_tensor.py
@@ -5,223 +5,37 @@
from .op import register_compute, register_schedule, register_pattern
from .op import schedule_injective, OpPattern


schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective

# log
@register_compute("log")
def log_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.log(inputs[0])]

register_schedule("log", schedule_broadcast)

# exp
@register_compute("exp")
def exp_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.exp(inputs[0])]

register_schedule("exp", schedule_broadcast)

# sqrt
@register_compute("sqrt")
def sqrt_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sqrt(inputs[0])]

register_schedule("sqrt", schedule_broadcast)

# sigmoid
@register_compute("sigmoid")
def sigmoid_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.sigmoid(inputs[0])]

register_schedule("sigmoid", schedule_broadcast)

# floor
@register_compute("floor")
def floor_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.floor(inputs[0])]

register_schedule("floor", schedule_broadcast)

# ceil
@register_compute("ceil")
def ceil_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.ceil(inputs[0])]

register_schedule("ceil", schedule_broadcast)

# trunc
@register_compute("trunc")
def trunc_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.trunc(inputs[0])]

register_schedule("trunc", schedule_broadcast)

# round
@register_compute("round")
def round_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.round(inputs[0])]

register_schedule("round", schedule_broadcast)

# abs
@register_compute("abs")
def abs_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.abs(inputs[0])]

register_schedule("abs", schedule_broadcast)

# tanh
@register_compute("tanh")
def tanh_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.tanh(inputs[0])]

register_schedule("tanh", schedule_broadcast)

# negative
@register_compute("negative")
def negative_compute(attrs, inputs, output_type, target):
assert len(inputs) == 1
return [topi.negative(inputs[0])]

register_schedule("negative", schedule_broadcast)

# add
@register_compute("add")
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]

register_schedule("add", schedule_injective)

# subtract
@register_compute("subtract")
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]

register_schedule("add", schedule_broadcast)
register_schedule("subtract", schedule_broadcast)

# multiply
@register_compute("multiply")
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]

register_schedule("multiply", schedule_broadcast)

# divide
@register_compute("divide")
def divide_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.divide(inputs[0], inputs[1])]

register_schedule("divide", schedule_broadcast)

# power
@register_compute("power")
def power_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.power(inputs[0], inputs[1])]

register_schedule("power", schedule_injective)

# mod
@register_compute("mod")
def mod_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.mod(inputs[0], inputs[1])]

register_schedule("mod", schedule_broadcast)

# equal
@register_compute("equal")
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]

register_schedule("equal", schedule_broadcast)

# not_equal
@register_compute("not_equal")
def not_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.not_equal(inputs[0], inputs[1])]

register_schedule("not_equal", schedule_broadcast)

# less
@register_compute("less")
def less_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less(inputs[0], inputs[1])]

register_schedule("less", schedule_broadcast)

# less equal
@register_compute("less_equal")
def less_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.less_equal(inputs[0], inputs[1])]

register_schedule("less_equal", schedule_broadcast)

# greater
@register_compute("greater")
def greater_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater(inputs[0], inputs[1])]

register_schedule("greater", schedule_broadcast)

# greater equal
@register_compute("greater_equal")
def greater_equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.greater_equal(inputs[0], inputs[1])]

register_schedule("greater_equal", schedule_broadcast)

# maximum
@register_compute("maximum")
def maximum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.maximum(inputs[0], inputs[1])]

register_schedule("maximum_compute", schedule_injective)

# minimum
@register_compute("minimum")
def minimum_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.minimum(inputs[0], inputs[1])]

register_schedule("minimum", schedule_injective)

# right shift
@register_compute("right_shift")
def right_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.right_shift(inputs[0], inputs[1])]

register_schedule("right_shift", schedule_injective)

# left shift
@register_compute("left_shift")
def left_shift_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.left_shift(inputs[0], inputs[1])]

register_schedule("left_shift", schedule_injective)

# zeros
@@ -273,5 +87,4 @@ def concatenate_compute(attrs, inputs, output_type, target):
return [topi.concatenate(inputs, axis=attrs.axis)]

register_schedule("concatenate", schedule_injective)
# TODO(tqchen): renable concat as injective
register_pattern("concatenate", OpPattern.OPAQUE)
register_pattern("concatenate", OpPattern.INJECTIVE)
37 changes: 17 additions & 20 deletions src/relay/backend/compile_engine.cc
@@ -56,30 +56,26 @@ class ScheduleGetter :
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
auto cache_node = make_node<CachedFuncNode>();
cache_node->target = target_;

if (prim_func->params.size() == 1 &&
prim_func->params[0]->checked_type().as<TupleTypeNode>()) {
// Handle tuple input type by flattening them.
// This is the current calling convention of tuple input.
for (Var param : prim_func->params) {
Array<tvm::Tensor> inputs;
for (Type field : prim_func->params[0]->type_as<TupleTypeNode>()->fields) {
const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype != nullptr);
if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
} else {
// flatten tuple of tensor type.
const auto* tuple_type = param->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
CHECK(ttype != nullptr);
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
inputs.push_back(tensor);
}
}
memo_[prim_func->params[0]] = inputs;

} else {
for (Var param : prim_func->params) {
const auto* ttype = param->type_as<TensorTypeNode>();
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
cache_node->inputs.push_back(tensor);
memo_[param] = Array<Tensor>({tensor});
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
cache_node->outputs = this->VisitExpr(prim_func->body);
@@ -161,8 +157,9 @@ class ScheduleGetter :

int op_pattern = fpattern[op];
if (op_pattern >= kCommReduce) {
CHECK(!master_op_.defined())
<< "Two complicated op in a primitive function";
CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce)
<< "Two complicated op in a primitive function "
<< " master=" << master_op_ << " current=" << op;
}
if (op_pattern >= master_op_patetrn_) {
master_op_ = op;
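With tuple parameters flattened into placeholder tensors here, and concatenate registered as injective above, a tuple-input op can now be compiled and fused with neighboring injective ops like any other. A rough end-to-end sketch, assuming the relay.build API of this period:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3))
y = relay.var("y", shape=(1, 3))
# concatenate takes a tuple of tensors; add keeps the fused group injective.
z = relay.add(relay.concatenate((x, y), axis=1), relay.const(1.0))
func = relay.Function([x, y], z)
graph_json, lib, params = relay.build(func, target="llvm")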
27 changes: 12 additions & 15 deletions src/relay/backend/interpreter.cc
@@ -212,7 +212,7 @@ class Interpreter :
// Marshal the arguments.
// Handle tuple input/output by flattening them.
size_t arg_len = 0;
for (size_t i = 0; i < args.size(); i++) {
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<TensorValueNode>()) {
++arg_len;
} else {
@@ -242,22 +242,19 @@
<< context_ << ", but get " << arg_ctx;
};

if (func->params.size() == 1 &&
func->params[0]->checked_type().as<TupleTypeNode>()) {
// handle tuple input.
const TupleValueNode* tuple = args[0].as<TupleValueNode>();
CHECK(tuple);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(i, tuple->fields[i]);
}
} else {
CHECK_EQ(num_inputs, args.size());
// Decide the target context.
// Primitive functions always sit in the same context.
for (size_t i = 0; i < args.size(); i++) {
fset_input(i, args[i]);
int arg_counter = 0;
for (Value arg : args) {
if (arg.as<TensorValueNode>()) {
fset_input(arg_counter++, arg);
} else {
const TupleValueNode* tuple = arg.as<TupleValueNode>();
CHECK(tuple != nullptr);
for (size_t i = 0; i < tuple->fields.size(); ++i) {
fset_input(arg_counter++, tuple->fields[i]);
}
}
}

// TVM's calling convention is that the final argument is the output
// buffer. To preserve the illusion of being a functional language
// we need to allocate space for the output buffer based on the
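The interpreter now flattens each tuple argument in order when marshalling inputs to a primitive function, so the same graph runs through the debug executor as well. A rough sketch reusing func from the previous sketch (the executor kind and evaluate call are assumptions based on the API of this period):

import numpy as np

x_data = np.random.rand(1, 3).astype("float32")
y_data = np.random.rand(1, 3).astype("float32")
intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
# Internally the fused concatenate receives its tuple input flattened
# into individual tensor arguments, as in the loop above.
out = intrp.evaluate(func)(x_data, y_data)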