[Relay][VM] Fix code generation for packed functions + tuples (apache…
jroesch authored and Wei Chen committed Jun 26, 2019
1 parent 1dbe83d commit cd3248e
Showing 2 changed files with 52 additions and 13 deletions.
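What changed, as read from the diff below: EmitInvokePrimitive previously computed the packed function's arity as func->params.size() + return_num, so a tuple-typed parameter counted as a single argument even though the lowered packed function receives one tensor per tuple field, and the VM invoked it with too few registers. The patch flattens each tuple parameter into one register per field via GetField instructions, derives the arity from the flattened register list (unpacked_arg_regs), and renames return_num to the clearer return_val_count. For example, a primitive taking one 3-field tuple and returning one tensor now gets arity 3 + 1 = 4 instead of the old, incorrect 1 + 1 = 2.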
src/relay/backend/vm/compiler.cc (39 additions, 13 deletions)

@@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
     return Instruction::AllocTensor(last_register, dltype, NewRegister());
   }
 
-  void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
+  void EmitInvokePrimitive(const Function& func,
+                           const std::vector<Index>& args_registers,
                            const Type& ret_type) {
+    std::vector<Index> unpacked_arg_regs;
     std::vector<Instruction> allocs;
-    size_t return_num = 0;
+
+    // Arity calculation must flatten tuples.
+    size_t arity = 0;
+    CHECK_EQ(func->params.size(), args_registers.size());
+    for (size_t i = 0; i < func->params.size(); i++) {
+      auto ty = func->params[i]->checked_type();
+      if (ty.as<TensorTypeNode>()) {
+        unpacked_arg_regs.push_back(args_registers[i]);
+        arity += 1;
+      } else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
+        for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
+          const auto& field = tuple_ty->fields[f];
+          CHECK(field.as<TensorTypeNode>())
+            << "only supports non-nested tuples currently "
+            << "found " << field;
+          auto dst = NewRegister();
+          Emit(Instruction::GetField(args_registers[i], f, dst));
+          unpacked_arg_regs.push_back(dst);
+        }
+        arity += tuple_ty->fields.size();
+      } else {
+        LOG(FATAL) << "unsupported parameter type " << ty;
+      }
+    }
+
+    size_t return_val_count = 0;
     if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
       // Allocate space for the return tensor.
       auto alloc = AllocTensorFromType(ttype);
       allocs.push_back(alloc);
-      return_num = 1;
+      return_val_count = 1;
     } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
       std::vector<Index> fields_registers;
 
@@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
         allocs.push_back(AllocTensorFromType(f_type));
         fields_registers.push_back(allocs.back().dst);
       }
-      return_num = ttype->fields.size();
+      return_val_count = ttype->fields.size();
     } else {
       LOG(FATAL) << "Unsupported return value type";
     }
 
+    arity += return_val_count;
     for (auto& alloc : allocs) {
       Emit(alloc);
-      args_registers.push_back(alloc.dst);
+      unpacked_arg_regs.push_back(alloc.dst);
     }
 
     // Next generate the invoke instruction.
@@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
       op_index = seen_funcs[cfunc->funcs[0]];
     }
 
-    // If Tensor, 1
-    // If Tuple, size of tuple
-    size_t arity = func->params.size() + return_num;
-    Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
-    if (return_num > 1) {
+    Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+
+    if (return_val_count > 1) {
       // return value is a tuple, we need to create a tuple
       std::vector<Index> fields_registers;
-      for (size_t i = func->params.size(); i < arity; ++i) {
-        fields_registers.push_back(args_registers[i]);
+      for (size_t i = arity - return_val_count; i < arity; ++i) {
+        fields_registers.push_back(unpacked_arg_regs[i]);
       }
-      Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+      Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
     }
   }

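To make the new register layout concrete, here is a minimal, self-contained Python sketch of the flattening rule the patched EmitInvokePrimitive implements; it is not TVM code, and the type tags, register names, and helper function are invented for illustration. Tuple inputs contribute one register per field, output registers are appended after the flattened inputs, and the arity is the length of the final register list.

# Hypothetical sketch, not TVM code: type tags stand in for checked_type().
def packed_call_registers(param_types, arg_regs, output_regs):
    """Flatten tuple inputs field by field, then append the output
    registers, mirroring the list the patch hands to InvokePacked."""
    assert len(param_types) == len(arg_regs)
    unpacked = []
    for ty, reg in zip(param_types, arg_regs):
        if ty == 'tensor':
            unpacked.append(reg)              # one register per tensor arg
        elif isinstance(ty, tuple):           # ('tuple', num_fields)
            _, num_fields = ty
            # the real code emits one GetField instruction per field; here
            # each field register is just a derived name for illustration
            unpacked.extend('%s.%d' % (reg, f) for f in range(num_fields))
        else:
            raise TypeError('unsupported parameter type: %r' % (ty,))
    regs = unpacked + list(output_regs)
    return regs, len(regs)                    # (register list, arity)

# One 3-field tuple input plus one tensor output: arity is 3 + 1 = 4,
# where the pre-patch code computed params + returns = 1 + 1 = 2.
regs, arity = packed_call_registers([('tuple', 3)], ['r0'], ['r1'])
assert (regs, arity) == (['r0.0', 'r0.1', 'r0.2', 'r1'], 4)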
tests/python/relay/test_vm.py (13 additions, 0 deletions)

@@ -49,6 +49,17 @@ def test_split():
     res = veval(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
 
+def test_split_no_fuse():
+    x = relay.var('x', shape=(12,))
+    y = relay.split(x, 3, axis=0).astuple()
+    z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+    z = relay.annotation.stop_fusion(z)
+    f = relay.Function([x], z)
+    x_data = np.random.rand(12,).astype('float32')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+
 def test_id():
     x = relay.var('x', shape=(10, 10))
     f = relay.Function([x], x)

@@ -259,6 +270,8 @@ def test_closure():
     test_tuple_second()
     test_let_scalar()
     test_let_tensor()
+    test_split()
+    test_split_no_fuse()
     # TODO(@jroesch): restore when match is supported
     # test_list_constructor()
     test_closure()
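The new test matters because stop_fusion keeps the concatenate from being fused into the split, so the VM must pass the tuple produced by split directly to the packed function for concatenate, which is exactly the tuple-argument path fixed in compiler.cc above. Stripped of the VM, the asserted behavior reduces to the following numpy identity (a standalone sketch; veval is a test helper defined earlier in test_vm.py):

import numpy as np

# What test_split_no_fuse asserts, minus the VM: splitting 12 elements into
# three parts and concatenating a 1-tuple of the first part is the identity.
x_data = np.random.rand(12,).astype('float32')
parts = np.split(x_data, 3, axis=0)            # three tensors of shape (4,)
result = np.concatenate([parts[0]], axis=0)    # what the compiled f computes
np.testing.assert_allclose(result, parts[0])   # must match np.split(...)[0]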
