Cleanup type pack and unpack for tuples.
jroesch committed Apr 13, 2020
1 parent 5958d60 commit c9b82e3
Showing 4 changed files with 130 additions and 17 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/device_copy.h
@@ -40,11 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
   TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
     TVM_ATTR_FIELD(src_dev_type)
         .describe(
-            "The virutal device/context type where the op copies data from.")
+            "The virtual device/context type where the op copies data from.")
         .set_default(0);
     TVM_ATTR_FIELD(dst_dev_type)
         .describe(
-            "The virutal device/context type where the op copies data to.")
+            "The virtual device/context type where the op copies data to.")
         .set_default(0);
   }
 };
24 changes: 24 additions & 0 deletions include/tvm/relay/attrs/memory.h
@@ -31,6 +31,30 @@
namespace tvm {
namespace relay {

+/*!
+ * \brief Options for allocating storage.
+ */
+struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
+  DataType dtype;
+  int device_id;
+  int device_type;
+
+  TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
+    TVM_ATTR_FIELD(dtype)
+        .describe(
+            "The dtype of the tensor to allocate.")
+        .set_default(DataType::Float(32, 1));
+    TVM_ATTR_FIELD(device_id)
+        .describe(
+            "The device id on which to allocate memory.");
+    TVM_ATTR_FIELD(device_type)
+        .describe(
+            "The device type on which to allocate memory.");
+  }
+};

/*!
 * \brief Options for allocating tensors.
 */
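For context, a minimal sketch of how the new ctx parameter is expected to flow into these attributes from the Python side; the import path and the constant sizes are assumptions for illustration, not part of this commit:

import tvm
from tvm import relay
from tvm.relay.op.memory import memory as mem

# Allocate 1024 bytes with 64-byte alignment on CPU 0; the TVMContext is
# unpacked into the device_id / device_type fields of AllocStorageAttrs
# by the C++ registration further down in this diff.
size = relay.const(1024, dtype="int64")
alignment = relay.const(64, dtype="int64")
storage = mem.alloc_storage(size, alignment, tvm.cpu(0), dtype_hint="float32")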
16 changes: 14 additions & 2 deletions python/tvm/relay/op/memory/memory.py
@@ -23,6 +23,9 @@ def invoke_tvm_op(func, inputs, outputs):
    Parameters
    ----------
+    func : tvm.relay.Expr
+        The input expr.
    inputs : tvm.relay.Expr
        A tuple of the inputs to pass to the TVM function.
@@ -59,7 +62,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
"""
return _make.alloc_tensor(storage, shape, dtype, assert_shape)

def alloc_storage(size, alignment, dtype_hint='float32'):
def alloc_storage(size, alignment, ctx, dtype_hint='float32'):
"""Allocate a piece of tensor storage.
Parameters
@@ -76,7 +79,7 @@ def alloc_storage(size, alignment, dtype_hint='float32'):
    result : tvm.relay.Expr
        The alloc_storage expression.
    """
-    return _make.alloc_storage(size, alignment, dtype_hint)
+    return _make.alloc_storage(size, alignment, dtype_hint, ctx)

def shape_func(func, inputs, outputs, dependent=False):
    """Invoke the shape function of the passed function.
@@ -96,3 +99,12 @@ def shape_func(func, inputs, outputs, dependent=False):
        The shape function expression.
    """
    return _make.shape_func(func, inputs, outputs, dependent)
+
+def flatten_tuple_type(ty):
+    return _make.FlattenTupleType(ty)
+
+def from_tuple_type(ty, expr):
+    return _make.FromTupleType(ty, expr)
+
+def to_tuple_type(ty, exprs):
+    return _make.ToTupleType(ty, exprs)
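A small usage sketch for the new helpers: flatten_tuple_type linearizes a nested TupleType depth-first into its leaf tensor types (from_tuple_type and to_tuple_type are wired up below, but their C++ bodies are still stubs in this commit). The import path and the example types are assumptions:

from tvm import relay
from tvm.relay.op.memory import memory as mem

nested = relay.TupleType([
    relay.TensorType((1,), "float32"),
    relay.TupleType([relay.TensorType((2, 2), "float32")]),
])
# Depth-first traversal yields the two leaf tensor types, in order:
# [TensorType([1], float32), TensorType([2, 2], float32)]
flat = mem.flatten_tuple_type(nested)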
103 changes: 90 additions & 13 deletions src/relay/op/memory/memory.cc
@@ -35,16 +35,19 @@
namespace tvm {
namespace relay {

+TVM_REGISTER_NODE_TYPE(AllocStorageAttrs);
TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);

// The passing value in attrs and args doesn't seem super great.
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
-    .set_body_typed([](Expr size, Expr alignment, DataType dtype) {
-      auto attrs = make_object<AllocTensorAttrs>();
+    .set_body_typed([](Expr size, Expr alignment, DataType dtype, TVMContext ctx) {
+      auto attrs = make_object<AllocStorageAttrs>();
      attrs->dtype = dtype;
+      attrs->device_id = ctx.device_id;
+      attrs->device_type = ctx.device_type;
      static const Op& op = Op::Get("memory.alloc_storage");
      return Call(op, {size, alignment}, Attrs(attrs), {});
    });
@@ -209,10 +212,20 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op")
-    .set_body_typed(
-        [](Expr func, Expr inputs, Expr outputs) {
-          return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
-        });
+    .set_body_typed(
+        [](Expr func, Expr inputs, Expr outputs) {
+          Attrs attrs;
+          // Record the attribute of the input expression. The attribute of the master
+          // op in a fused function is used.
+          if (const auto* fn = func.as<FunctionNode>()) {
+            if (const auto* cn = fn->body.as<CallNode>()) {
+              attrs = cn->attrs;
+            }
+          } else if (const auto* cn = func.as<CallNode>()) {
+            attrs = cn->attrs;
+          }
+          return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, attrs);
+        });

RELAY_REGISTER_OP("memory.invoke_tvm_op")
    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
@@ -265,29 +278,93 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func")
      return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
    });

-static void FlattenTypeAux(const Type& type, std::vector<TensorType>* out) {
+// # TODO(@jroesch): port to c++ and unify with existing code
+// class LinearizeRetType:
+//     """A linear view of a Relay type, handles a linear order
+//     for nested tuples, and tensor types.
+//     """
+
+//     def __init__(self, typ):
+//         """Initialize the linearizer."""
+//         self.typ = typ
+
+//     def unpack(self):
+//         """Return the linear representation of the type."""
+//         def _unpack(typ, out):
+//             # TODO(@jroesch): replace with new flattening pass
+//             if isinstance(typ, ty.TensorType):
+//                 out.append(typ)
+//             elif isinstance(typ, ty.TupleType):
+//                 for field_ty in typ.fields:
+//                     _unpack(field_ty, out)
+//             else:
+//                 raise Exception("unsupported Relay type: {0}".format(typ))
+
+//         output = []
+//         _unpack(self.typ, output)
+//         return output
+
+//     def pack(self, seq):
+//         """Repack a linear type as a nested type."""
+//         def _pack(value, typ, out):
+//             if isinstance(typ, ty.TensorType):
+//                 out.append(value)
+//             elif isinstance(typ, ty.TupleType):
+//                 tuple_out = []
+//                 for i, field_ty in enumerate(typ.fields):
+//                     _pack(value[i], field_ty, tuple_out)
+//                 out.append(expr.Tuple(tuple_out))
+//             else:
+//                 raise Exception("unsupported Relay type: {0}".format(typ))
+
+//         if len(seq) == 1:
+//             return seq[0]
+//         else:
+//             out = []
+//             _pack(seq, self.typ, out)
+//             assert len(out) == 1, "must return fully packed type"
+//             return out[0]
+
+static void FlattenTupleTypeAux(const Type& type, std::vector<TensorType>* out) {
  if (auto tt = type.as<TensorTypeNode>()) {
    out->push_back(GetRef<TensorType>(tt));
  } else if (auto tuple_ty = type.as<TupleTypeNode>()) {
    for (auto field : tuple_ty->fields) {
-      FlattenTypeAux(field, out);
+      FlattenTupleTypeAux(field, out);
    }
  } else {
    LOG(FATAL) << "unsupported " << type;
  }
}

-std::vector<TensorType> FlattenType(const Type& type) {
+std::vector<TensorType> FlattenTupleType(const Type& type) {
  std::vector<TensorType> out;
-  FlattenTypeAux(type, &out);
+  FlattenTupleTypeAux(type, &out);
  return out;
}

-Expr PackByType(const Type& t, const Array<Expr>& exprs) {
+Array<Expr> FromTupleType(const Type& type, const Expr& expr) {
  LOG(FATAL) << "NYI";
}

+// Pack the sequence of expressions according to the provided TupleType.
+Expr ToTupleType(const Type& t, const Array<Expr>& exprs) {
+  LOG(FATAL) << "NYI";
+  return Expr();
+}

TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType")
.set_body_typed([](Type type) {
auto types = FlattenTupleType(type);
return Array<Type>(types.begin(), types.end());
});

TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType")
.set_body_typed(FromTupleType);

TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType")
.set_body_typed(ToTupleType);

bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
@@ -298,8 +375,8 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
  CHECK(func_type != nullptr);

  auto tuple = TupleType(func_type->arg_types);
-  auto in_types = FlattenType(tuple);
-  auto out_types = FlattenType(func_type->ret_type);
+  auto in_types = FlattenTupleType(tuple);
+  auto out_types = FlattenTupleType(func_type->ret_type);

Array<Type> shape_func_ins, shape_func_outs;
for (size_t i = 0; i < in_types.size(); i++) {
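Since FromTupleType and ToTupleType are still NYI stubs in this commit, the commented-out Python reference above is the best description of the intended contract. A pure-Python model of that round trip, with nested lists standing in for TupleType and strings for TensorType (all names hypothetical):

# unpack linearizes a nested structure depth-first; pack rebuilds the
# nesting from the flat sequence, mirroring _unpack/_pack above.
def unpack(typ, out):
    if isinstance(typ, str):              # leaf: stands in for TensorType
        out.append(typ)
    else:                                 # tuple: recurse over fields
        for field in typ:
            unpack(field, out)

def pack(values, typ):
    it = iter(values)
    def go(t):
        return next(it) if isinstance(t, str) else [go(f) for f in t]
    return go(typ)

nested = ["a", ["b", "c"]]
flat = []
unpack(nested, flat)                      # flat == ["a", "b", "c"]
assert pack(flat, nested) == nested       # round trip restores the nesting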
