Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][VM] Fix code generation for packed functions + tuples #3287

Merged
merged 5 commits into from
Jun 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
return Instruction::AllocTensor(last_register, dltype, NewRegister());
}

void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
void EmitInvokePrimitive(const Function& func,
const std::vector<Index>& args_registers,
const Type& ret_type) {
std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
size_t return_num = 0;

// Arity calculation must flatten tuples.
size_t arity = 0;
CHECK_EQ(func->params.size(), args_registers.size());
for (size_t i = 0; i < func->params.size(); i++) {
auto ty = func->params[i]->checked_type();
if (ty.as<TensorTypeNode>()) {
unpacked_arg_regs.push_back(args_registers[i]);
arity += 1;
} else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
const auto& field = tuple_ty->fields[f];
CHECK(field.as<TensorTypeNode>())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you extend it and allow recursive tuple? it is more uniform this way.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really want to do it right now, recursive is a pain given the way the code generator is written right now. This is mostly to fix CI breakage, I would like to make some updates to compiler likely after FCRC tutorial.

<< "only supports non-nested tuples currently "
<< "found " << field;
auto dst = NewRegister();
Emit(Instruction::GetField(args_registers[i], f, dst));
unpacked_arg_regs.push_back(dst);
}
arity += tuple_ty->fields.size();
} else {
LOG(FATAL) << "unsupported parameter type " << ty;
}
}

size_t return_val_count = 0;
if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
// Allocate space for the return tensor.
auto alloc = AllocTensorFromType(ttype);
allocs.push_back(alloc);
return_num = 1;
return_val_count = 1;
} else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
std::vector<Index> fields_registers;

Expand All @@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
allocs.push_back(AllocTensorFromType(f_type));
fields_registers.push_back(allocs.back().dst);
}
return_num = ttype->fields.size();
return_val_count = ttype->fields.size();
} else {
LOG(FATAL) << "Unsupported return value type";
}

arity += return_val_count;
for (auto& alloc : allocs) {
Emit(alloc);
args_registers.push_back(alloc.dst);
unpacked_arg_regs.push_back(alloc.dst);
}

// Next generate the invoke instruction.
Expand All @@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
op_index = seen_funcs[cfunc->funcs[0]];
}

// If Tensor, 1
// If Tuple, size of tuple
size_t arity = func->params.size() + return_num;
Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
if (return_num > 1) {
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));

if (return_val_count > 1) {
// return value is a tuple, we need to create a tuple
std::vector<Index> fields_registers;
for (size_t i = func->params.size(); i < arity; ++i) {
fields_registers.push_back(args_registers[i]);
for (size_t i = arity - return_val_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
}
}

Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_split():
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])

def test_split_no_fuse():
x = relay.var('x', shape=(12,))
y = relay.split(x, 3, axis=0).astuple()
z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
z = relay.annotation.stop_fusion(z)
f = relay.Function([x], z)
x_data = np.random.rand(12,).astype('float32')
res = veval(f, x_data)
tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])


def test_id():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
Expand Down Expand Up @@ -259,6 +270,8 @@ def test_closure():
test_tuple_second()
test_let_scalar()
test_let_tensor()
test_split()
test_split_no_fuse()
# TODO(@jroesch): restore when match is supported
# test_list_constructor()
test_closure()