Commit 66225ed: support nested tuple in CallNode's return type

masahi committed Nov 3, 2020 (1 parent: 33085d1)

Showing 6 changed files with 89 additions and 15 deletions.
src/relay/backend/compile_engine.cc (1 addition, 3 deletions)

@@ -270,10 +270,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
     Array<te::Tensor> fields;
     for (Expr field : op->fields) {
-      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
       Array<te::Tensor> res = VisitExpr(field);
-      ICHECK_EQ(res.size(), 1);
-      fields.push_back(res[0]);
+      fields.insert(fields.end(), res.begin(), res.end());
     }
     return fields;
   }
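The hunk above is what makes nested tuples work in the schedule getter: `VisitExpr` on an inner `TupleNode` re-enters this same visitor and returns an already flattened tensor list, so the parent only needs to concatenate, instead of asserting exactly one tensor per field as the removed `ICHECK`s did. A standalone sketch of that pattern (plain C++ with invented stand-in types, not the TVM API):

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for a Relay expression: either a leaf tensor
// (named for illustration) or a tuple of sub-expressions.
struct ToyExpr {
  std::string tensor;           // non-empty => leaf tensor
  std::vector<ToyExpr> fields;  // non-empty => tuple
};

// Mirrors the visitor above: a leaf yields one tensor; a tuple
// concatenates the already flattened results of its fields.
std::vector<std::string> Visit(const ToyExpr& e) {
  if (!e.tensor.empty()) return {e.tensor};
  std::vector<std::string> out;
  for (const ToyExpr& f : e.fields) {
    std::vector<std::string> res = Visit(f);        // already flat
    out.insert(out.end(), res.begin(), res.end());  // concatenate
  }
  return out;
}

int main() {
  // (A, (B, C)) flattens to A, B, C.
  ToyExpr nested{"", {ToyExpr{"A", {}}, ToyExpr{"", {ToyExpr{"B", {}}, ToyExpr{"C", {}}}}}};
  for (const std::string& t : Visit(nested)) std::printf("%s\n", t.c_str());
  return 0;
}

Under the old code, a field that was itself a tuple tripped the "Only allow Tuple of Tensor" check; with concatenation, arbitrarily deep nesting reduces to a flat tensor list.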
src/relay/backend/graph_plan_memory.cc (5 additions, 2 deletions)

@@ -28,6 +28,7 @@
 #include <tvm/tir/op.h>
 
 #include "../../support/arena.h"
+#include "../op/memory/utils.h"
 
 namespace tvm {
 namespace relay {

@@ -145,8 +146,10 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
     std::vector<StorageToken*> tokens;
     int device_type =
         node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[GetRef<Expr>(op)]->value : 0;
-    if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
-      for (Type t : tuple_type->fields) {
+    const Type checked_type = op->checked_type();
+    if (checked_type.as<TupleTypeNode>()) {
+      std::vector<TensorType> fields = FlattenTupleType(checked_type);
+      for (TensorType t : fields) {
         const auto* ttype = t.as<TensorTypeNode>();
         ICHECK(ttype);
         StorageToken* token = arena_->make<StorageToken>();
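Storage planning now sees one `TensorType` per leaf, so one `StorageToken` is created for every tensor in the (possibly nested) tuple; previously the loop walked only the top-level fields and `ICHECK(ttype)` failed on any field that was itself a tuple. A minimal sketch of what `FlattenTupleType` produces here (TVM C++ API; the shapes and the `ExampleLeafCount` wrapper are invented for illustration, and it assumes compilation inside the TVM tree so the relative include resolves):

#include <tvm/relay/type.h>

#include <vector>

#include "../op/memory/utils.h"

namespace tvm {
namespace relay {

// Invented example: a call whose checked_type() is the nested tuple
// (Tensor[(10,), float32], (Tensor[(5,), float32], Tensor[(5,), float32])).
size_t ExampleLeafCount() {
  TensorType t1({10}, DataType::Float(32));
  TensorType t2({5}, DataType::Float(32));
  TensorType t3({5}, DataType::Float(32));
  Type nested = TupleType({t1, TupleType({t2, t3})});
  std::vector<TensorType> leaves = FlattenTupleType(nested);
  return leaves.size();  // 3: the loop above makes one StorageToken per leaf
}

}  // namespace relay
}  // namespace tvm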
src/relay/backend/graph_runtime_codegen.cc (5 additions, 7 deletions)

@@ -32,6 +32,7 @@
 #include <string>
 #include <vector>
 
+#include "../op/memory/utils.h"
 #include "compile_engine.h"
 #include "utils.h"
 

@@ -273,14 +274,11 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>>
     std::vector<GraphNodeRef> ret;
     ShapeVector shape;
     std::vector<std::string> dtype;
+    std::vector<TensorType> fields = FlattenTupleType(checked_type);
     for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
-      if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
-        ret.push_back(GraphNodeRef(node_id, i));
-        shape.emplace_back(_ShapeToJSON(typ->shape));
-        dtype.emplace_back(DType2String(typ->dtype));
-      } else {
-        LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
-      }
+      ret.push_back(GraphNodeRef(node_id, i));
+      shape.emplace_back(_ShapeToJSON(fields[i]->shape));
+      dtype.emplace_back(DType2String(fields[i]->dtype));
     }
     ICHECK_EQ(node->Type(), kGraphOpNode);
     auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
src/relay/op/memory/memory.cc (1 addition)

@@ -32,6 +32,7 @@
#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"
#include "../type_relations.h"
#include "utils.h"

namespace tvm {
namespace relay {
Expand Down
src/relay/op/memory/utils.h (new file, 39 additions)

@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/memory/utils.h
+ * \brief Utilities related to memory allocation
+ */
+
+#ifndef TVM_RELAY_OP_MEMORY_UTILS_H_
+#define TVM_RELAY_OP_MEMORY_UTILS_H_
+
+#include <tvm/relay/type.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+std::vector<TensorType> FlattenTupleType(const Type& type);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_MEMORY_UTILS_H_
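The header only declares `FlattenTupleType`; its definition is not part of the hunks shown on this page. A minimal sketch of the expected behavior, assuming the usual depth-first flattening (the `FlattenTupleTypeAux` helper name is an invention of this sketch):

#include <tvm/relay/type.h>

#include <vector>

namespace tvm {
namespace relay {

// Collect every leaf TensorType, recursing into nested tuples in
// field order, so (t1, (t2, t3)) yields {t1, t2, t3}.
static void FlattenTupleTypeAux(const Type& type, std::vector<TensorType>* out) {
  if (const auto* tt = type.as<TensorTypeNode>()) {
    out->push_back(GetRef<TensorType>(tt));
  } else if (const auto* tuple_ty = type.as<TupleTypeNode>()) {
    for (Type field : tuple_ty->fields) {
      FlattenTupleTypeAux(field, out);
    }
  } else {
    LOG(FATAL) << "Expected a (possibly nested) tuple of tensors, but got " << type;
  }
}

std::vector<TensorType> FlattenTupleType(const Type& type) {
  std::vector<TensorType> out;
  FlattenTupleTypeAux(type, &out);
  return out;
}

}  // namespace relay
}  // namespace tvm

A bare `TensorType` would also flatten to a single-element vector, which lets callers treat tensor and tuple return types uniformly.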
tests/python/relay/test_backend_graph_runtime.py (38 additions, 3 deletions)

@@ -20,8 +20,10 @@
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.op import add
+from tvm.relay import transform
 import tvm.testing
 
+
 # @tq, @jr should we put this in testing ns?
 def check_rts(expr, args, expected_result, mod=None):
     """

@@ -184,7 +186,7 @@ def unit_numpy(X, W):
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
-def test_compile_nested_tuples():
+def test_return_nested_tuples():
     x = relay.var("x", shape=(10,))
     x1 = x + relay.const(1.0)
     x2 = x1 + relay.const(1.0)

@@ -193,7 +195,9 @@ def test_compile_nested_tuples():
     out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])])
     func = relay.Function([x], out)
 
-    graph, lib, _ = relay.build(tvm.IRModule.from_expr(func), "llvm")
+    with tvm.transform.PassContext(opt_level=3):
+        graph, lib, _ = relay.build(tvm.IRModule.from_expr(func), "llvm")
+
     mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
 
     x_data = np.random.uniform(size=(10,)).astype(np.float32)

@@ -209,11 +213,42 @@
     ref = ref + 1
 
 
+def test_compile_nested_tuples_call_output():
+    mod = tvm.IRModule()
+    x = relay.var("x", shape=(10, 10))
+    a_split = relay.split(x, 2)
+    a_split_0 = relay.TupleGetItem(a_split.astuple(), 0)
+    a_split_1 = relay.TupleGetItem(a_split.astuple(), 1)
+    tuple_out = relay.Tuple((a_split_0, relay.Tuple([a_split_1])))
+    func0 = relay.Function([x], tuple_out)
+    func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+    data = relay.var("x", shape=(10, 10))
+    call = relay.Call(func0, [data])
+    mod["main"] = relay.Function([data], call)
+
+    with tvm.transform.PassContext(opt_level=3):
+        graph, lib, _ = relay.build(mod, "llvm")
+
+    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+    x_data = np.random.uniform(size=(10, 10)).astype(np.float32)
+    mod.set_input(x=x_data)
+    mod.run()
+
+    assert mod.get_num_outputs() == 2
+
+    ref = np.split(x_data, 2)
+    for i in range(mod.get_num_outputs()):
+        out = mod.get_output(i).asnumpy()
+        tvm.testing.assert_allclose(out, ref[i], rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
     test_gru_like()
-    test_compile_nested_tuples()
+    test_return_nested_tuples()
+    test_compile_nested_tuples_call_output()