# Patch summary (descriptive preamble only; it sits before the first file
# header and contributes no hunk content. git-apply/patch are expected to
# skip leading text like this -- confirm with your tool).
#
# File 1: src/relay/backend/vm/compiler.cc, VMCompiler::EmitInvokePrimitive
#   * args_registers is now taken by const reference; the function no longer
#     appends to it and instead fills a new local, unpacked_arg_regs.
#   * A new loop walks func->params and inspects each checked_type():
#       - first branch: the caller's register is forwarded unchanged and
#         arity grows by 1;
#       - tuple_ty branch: one GetField instruction is emitted per tuple
#         field into a fresh register, and arity grows by the field count;
#         every field must pass the CHECK, whose message reads
#         "only supports non-nested tuples currently";
#       - anything else hits LOG(FATAL) "unsupported parameter type".
#   * return_num is renamed return_val_count and is added to arity after the
#     return allocations are decided, replacing the old
#     `func->params.size() + return_num` computation.
#   * Output allocation registers are appended to unpacked_arg_regs, and the
#     tuple result is built from its last return_val_count entries
#     (indices arity - return_val_count .. arity - 1).
#
# File 2: tests/python/relay/test_vm.py
#   * Adds test_split_no_fuse, which feeds field 0 of a split tuple through
#     concatenate + stop_fusion and compares against np.split(...)[0].
#   * Adds calls to test_split() and test_split_no_fuse() to the trailing
#     list of test invocations.
#
# NOTE(review): template argument lists appear to have been stripped from
# this copy (e.g. `ty.as()`, `std::vector args_registers`,
# `struct VMCompiler : ExprFunctor {`), so the exact types checked in each
# branch cannot be read from here -- confirm against the upstream change
# before applying.
# NOTE(review): the three C++ hunks do not show the code between them; the
# summary above covers only what the hunks contain.
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy; hunk line counts match the @@ headers, but
# leading whitespace should be verified against the target files.
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 602e92759624..db98a9a9d3fd 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -334,15 +334,42 @@ struct VMCompiler : ExprFunctor {
     return Instruction::AllocTensor(last_register, dltype, NewRegister());
   }
 
-  void EmitInvokePrimitive(const Function& func, std::vector args_registers,
+  void EmitInvokePrimitive(const Function& func,
+                           const std::vector& args_registers,
                            const Type& ret_type) {
+    std::vector unpacked_arg_regs;
     std::vector allocs;
-    size_t return_num = 0;
+
+    // Arity calculation must flatten tuples.
+    size_t arity = 0;
+    CHECK_EQ(func->params.size(), args_registers.size());
+    for (size_t i = 0; i < func->params.size(); i++) {
+      auto ty = func->params[i]->checked_type();
+      if (ty.as()) {
+        unpacked_arg_regs.push_back(args_registers[i]);
+        arity += 1;
+      } else if (auto tuple_ty = ty.as()) {
+        for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
+          const auto& field = tuple_ty->fields[f];
+          CHECK(field.as())
+            << "only supports non-nested tuples currently "
+            << "found " << field;
+          auto dst = NewRegister();
+          Emit(Instruction::GetField(args_registers[i], f, dst));
+          unpacked_arg_regs.push_back(dst);
+        }
+        arity += tuple_ty->fields.size();
+      } else {
+        LOG(FATAL) << "unsupported parameter type " << ty;
+      }
+    }
+
+    size_t return_val_count = 0;
     if (const TensorTypeNode* ttype = ret_type.as()) {
       // Allocate space for the return tensor.
       auto alloc = AllocTensorFromType(ttype);
       allocs.push_back(alloc);
-      return_num = 1;
+      return_val_count = 1;
     } else if (const TupleTypeNode* ttype = ret_type.as()) {
       std::vector fields_registers;
 
@@ -352,14 +379,15 @@ struct VMCompiler : ExprFunctor {
         allocs.push_back(AllocTensorFromType(f_type));
         fields_registers.push_back(allocs.back().dst);
       }
-      return_num = ttype->fields.size();
+      return_val_count = ttype->fields.size();
     } else {
       LOG(FATAL) << "Unsupported return value type";
     }
 
+    arity += return_val_count;
     for (auto& alloc : allocs) {
       Emit(alloc);
-      args_registers.push_back(alloc.dst);
+      unpacked_arg_regs.push_back(alloc.dst);
     }
 
     // Next generate the invoke instruction.
@@ -378,17 +406,15 @@ struct VMCompiler : ExprFunctor {
       op_index = seen_funcs[cfunc->funcs[0]];
     }
 
-    // If Tensor, 1
-    // If Tuple, size of tuple
-    size_t arity = func->params.size() + return_num;
-    Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
-    if (return_num > 1) {
+    Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+
+    if (return_val_count > 1) {
       // return value is a tuple, we need to create a tuple
       std::vector fields_registers;
-      for (size_t i = func->params.size(); i < arity; ++i) {
-        fields_registers.push_back(args_registers[i]);
+      for (size_t i = arity - return_val_count; i < arity; ++i) {
+        fields_registers.push_back(unpacked_arg_regs[i]);
       }
-      Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+      Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
     }
   }
 
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index bc99418d5da4..d727e776cbcd 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -49,6 +49,17 @@ def test_split():
     res = veval(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
 
+def test_split_no_fuse():
+    x = relay.var('x', shape=(12,))
+    y = relay.split(x, 3, axis=0).astuple()
+    z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+    z = relay.annotation.stop_fusion(z)
+    f = relay.Function([x], z)
+    x_data = np.random.rand(12,).astype('float32')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+
 def test_id():
     x = relay.var('x', shape=(10, 10))
     f = relay.Function([x], x)
@@ -259,6 +270,8 @@ def test_closure():
     test_tuple_second()
     test_let_scalar()
     test_let_tensor()
+    test_split()
+    test_split_no_fuse()
     # TODO(@jroesch): restore when match is supported
     # test_list_constructor()
     test_closure()