Change call_tir convention; Unify shape/type deduction rule (mlc-ai#94)
* Change call_tir convention and fix shape/type deduction.

* test

* output shape as 3rd arg.

* address comments.

* lint
YuchenJin authored and junrushao committed Feb 9, 2023
1 parent 5b939a9 commit 8df16c6
Showing 23 changed files with 419 additions and 189 deletions.
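
For orientation, here is a minimal before/after sketch of the call_tir convention change, adapted from the docstring examples updated below (te_func, x, and y are taken from those examples):

# Before this commit: the output shape came first and no dtype was given.
gv = relax.call_tir((128, 128), "te_func", (x, y))

# After this commit: the callee and its arguments come first, the output shape
# is the third argument, and the output dtype is a keyword argument.
gv = relax.call_tir("te_func", (x, y), (128, 128), dtype="float32")
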
24 changes: 18 additions & 6 deletions include/tvm/relax/attrs/memory.h
@@ -28,14 +28,26 @@

namespace tvm {
namespace relax {

/*!
* \brief Attributes for allocating tensor.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype).describe("The datatype of the tensor to be allocated.");
}
};

/*!
* \brief Attributes for allocating storage.
* \brief Attributes for allocating storage on Relax VM.
*/
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
struct VMAllocStorageAttrs : public tvm::AttrsNode<VMAllocStorageAttrs> {
int device_type;
DataType dtype;

TVM_DECLARE_ATTRS(AllocStorageAttrs, "relax.attrs.AllocStorageAttrs") {
TVM_DECLARE_ATTRS(VMAllocStorageAttrs, "relax.attrs.VMAllocStorageAttrs") {
TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory.");
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
@@ -44,13 +56,13 @@ struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
};

/*!
* \brief Attributes for allocating tensors.
* \brief Attributes for allocating tensor on Relax VM.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
struct VMAllocTensorAttrs : public tvm::AttrsNode<VMAllocTensorAttrs> {
int offset;
DataType dtype;

TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") {
TVM_DECLARE_ATTRS(VMAllocTensorAttrs, "relax.attrs.VMAllocTensorAttrs") {
TVM_ATTR_FIELD(offset).describe("Storage offset to allocate the tensor.").set_default(0);
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
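
Not part of the diff: a hedged Python sketch of constructing the renamed attribute nodes by their registered keys via tvm.ir.attrs.make_node (the same helper the parser uses below). Field names follow the struct definitions above; passing the dtype as a plain string assumes the usual FFI string-to-DataType conversion.

import tvm

# Illustrative only: build the VM allocation attrs from their registered type keys.
storage_attrs = tvm.ir.attrs.make_node(
    "relax.attrs.VMAllocStorageAttrs", device_type=1, dtype="float32"
)
tensor_attrs = tvm.ir.attrs.make_node(
    "relax.attrs.VMAllocTensorAttrs", offset=0, dtype="float32"
)
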
3 changes: 2 additions & 1 deletion python/tvm/relax/__init__.py
@@ -57,6 +57,7 @@
ShapeType = ty.ShapeType
DynTensorType = ty.DynTensorType
DimType = ty.DimType
TupleType = ty.TupleType

# VM
ExecBuilder = exec_builder.ExecBuilder
@@ -65,7 +66,7 @@

# Operator
from .op.base import call_tir
from .op.op_attrs import AllocStorageAttrs, AllocTensorAttrs
from .op.op_attrs import VMAllocStorageAttrs, VMAllocTensorAttrs

# IRBuilder
BlockBuilder = block_builder.BlockBuilder
17 changes: 12 additions & 5 deletions python/tvm/relax/block_builder.py
@@ -363,7 +363,7 @@ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle,
@R.function
def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
# block 0
gv = relax.call_tir((128, 128), "te_func", (x, y))
gv = relax.call_tir("te_func", (x, y), (128, 128), dtype="float32")
return gv
Example
@@ -411,8 +411,7 @@ def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> N
def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"])
-> Tensor[_, "float32"]:
# block 0
gv: Tensor[((n + 1),), "float32"]
= relax.call_tir(((n + 1),), te_func, (y,), (n,))
gv = relax.call_tir(te_func, (y,), ((n + 1),), (n,), dtype="float32")
return gv
"""
primfunc_name_hint = kwargs.pop("primfunc_name_hint", None)
@@ -442,16 +441,24 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"])
gvar = self.add_func(tir_func, func.__name__)

call_args = [x.op.value for x in te_args]

output_shape = (
outs[0].shape
if isinstance(te_out, tvm.te.tensor.Tensor)
else Tuple([ShapeExpr(x.shape) for x in outs])
)

output_dtype = (
te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs]
)

# add arguments for extra parameters from unbound var
if len(unbound_tir_vars) > 0:
call = call_tir(output_shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
call = call_tir(
gvar, call_args, output_shape, output_dtype, tir_vars=ShapeExpr(unbound_tir_vars)
)
else:
call = call_tir(output_shape, gvar, call_args)
call = call_tir(gvar, call_args, output_shape, output_dtype)
return self.emit(call)

def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
39 changes: 28 additions & 11 deletions python/tvm/relax/op/base.py
@@ -14,33 +14,39 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
"""The base Relax operators."""
from typing import Union, List
from typing import Union, List, Optional

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Tuple, Call
from ..ty import DynTensorType, TupleType
from ...ir import Array


def call_tir(
shape: Union[Tuple, ShapeExpr, List[int]],
func: Expr,
args: Union[Tuple, List[Expr]],
tir_vars: ShapeExpr = None,
shape: Union[Tuple, ShapeExpr, List[int]],
dtype: Union[str, List[str]],
tir_vars: Optional[ShapeExpr] = None,
) -> Call:
"""
Call a destination-passing-style function and return the output.
Parameters
----------
shape: Tuple[ShapeExpr] or ShapeExpr
The output shape. Tuple[ShapeExpr] if multiple outputs, ShapeExpr is single output.
func : Expr
The destination-passing-style function, can be ExternFunc or PrimFunc.
func : ExternFunc or PrimFunc
The destination-passing-style function.
args : Tuple[Expr]
args : Union[Tuple, List[Expr]]
The input arguments.
tir_vars : ShapeExpr
shape: Union[Tuple, ShapeExpr, List[int]]
The output shape. Tuple[ShapeExpr] if multiple outputs, ShapeExpr if single output.
dtype: Union[str, List[str]]
The output dtype. List[str] if multiple outputs, str if single output.
tir_vars : ShapeExpr, optional
ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used
Returns
@@ -50,6 +56,17 @@ def call_tir(
"""
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)

if isinstance(args, (list, tuple)):
args = Tuple(args)
return _ffi_api.call_tir(shape, func, args, tir_vars)

if isinstance(dtype, str):
output_type = DynTensorType(len(shape), dtype)
elif isinstance(dtype, (list, tuple)):
if len(shape) != len(dtype):
raise ValueError("The number of output_shape and output_dtype of call_tir mismatch")
output_type = TupleType([DynTensorType(len(x), y) for x, y in zip(shape, dtype)])
else:
raise TypeError("Not supported dtype for call_tir: " + str(type(dtype)))

return _ffi_api.call_tir(func, args, shape, output_type, tir_vars)
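
A short sketch (not part of the diff) of the type deduction the rewritten helper performs; my_func, x, n, and m are placeholder names:

# Single output: one shape and one dtype string yield a DynTensorType whose
# rank matches the shape, here DynTensorType(2, "float32").
call = relax.call_tir(my_func, (x,), (n, m), dtype="float32")

# Multiple outputs: a tuple of shapes and a list of dtypes of equal length
# yield a TupleType of DynTensorTypes; mismatched lengths raise ValueError.
call = relax.call_tir(
    my_func,
    (x,),
    relax.Tuple([relax.ShapeExpr([n, m]), relax.ShapeExpr([n])]),
    dtype=["float32", "int32"],
)
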
15 changes: 10 additions & 5 deletions python/tvm/relax/op/op_attrs.py
@@ -19,11 +19,16 @@
import tvm._ffi


@tvm._ffi.register_object("relax.attrs.AllocStorageAttrs")
class AllocStorageAttrs(Attrs):
"""Attributes used in alloc_storage operators"""


@tvm._ffi.register_object("relax.attrs.AllocTensorAttrs")
class AllocTensorAttrs(Attrs):
"""Attributes used in alloc_tensor operators"""


@tvm._ffi.register_object("relax.attrs.VMAllocStorageAttrs")
class VMAllocStorageAttrs(Attrs):
"""Attributes used in VM alloc_storage operators"""


@tvm._ffi.register_object("relax.attrs.VMAllocTensorAttrs")
class VMAllocTensorAttrs(Attrs):
"""Attributes used in VM alloc_tensor operators"""
2 changes: 1 addition & 1 deletion python/tvm/relax/ty.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name, unused-import
"""The type nodes of the Relax language."""
import tvm._ffi
from tvm.ir import Type, TensorType, Span
from tvm.ir import Type, TensorType, TupleType, Span

from . import _ffi_api

79 changes: 72 additions & 7 deletions python/tvm/script/relax/parser.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-else-return
# pylint: disable=invalid-name, no-else-return, too-many-nested-blocks
# pylint: disable=inconsistent-return-statements, ungrouped-imports
"""TVM Script Parser For Relax"""
from __future__ import annotations
@@ -973,6 +973,7 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
return self.transform_Subscript(expr)

op = self.transform_expr(expr.func_name)
type_args = None

if op == SpecialOp.CALL_PACKED:
extern_func = expr.params[0]
@@ -1015,18 +1016,73 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
# check call arity eagerly
if op.name == "relax.call_tir":
# call_tir is special case because last argument is optional
if len(args) != 3 and len(args) != 4:
if len(args) != op.num_inputs and len(args) != op.num_inputs - 1:
self.report_error(
f"{op.name} expects {op.num_inputs} arguments but got {len(args)}",
f"""{op.name} expects {op.num_inputs} or {op.num_inputs - 1}
arguments but got {len(args)}""",
expr.span,
)

if len(expr.keyword_params) != 1:
self.report_error(
f"""{op.name} expects exact one keyword argument with dtype as the key but
got {len(expr.keyword_params)} keyword arguments""",
expr.span,
)

if isinstance(args[0], str):
# extern function call case: rewrite identifier to an ExternFunc
args[0] = relax.ExternFunc(args[0], self.to_tvm_span(expr.params[1].span))

for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
if key.value == "dtype":
val = self.transform_expr(val)
# single output case
if isinstance(val, str):
if not isinstance(args[2], relax.ShapeExpr):
self.report_error(
(
f"The number of output_shape and output_dtype of "
f"call_tir mismatch"
),
expr.span,
)
type_args = [relax.DynTensorType(rank=len(args[2].values), dtype=val)]
elif isinstance(val, Tuple):
# multiple outputs case
if not isinstance(args[2], Tuple) and len(args[2]) != len(val):
self.report_error(
(
f"The number of output_shape and output_dtype of "
f"call_tir mismatch"
),
expr.span,
)
types = []
for i in range(len(args[2])):
types.append(
relax.DynTensorType(rank=len(args[2][i].values), dtype=val[i])
)
type_args = [relax.TupleType(types)]
else:
self.report_error(
f"call_tir expects the output_dtype to be a string or a tuple",
expr.span,
)
else:
self.report_error(
(
f"{op.name} expects one keyword argument with dtype as the key but "
f"got {len(key.value)} as the key"
),
expr.span,
)

elif op.num_inputs != -1 and len(args) != op.num_inputs:
self.report_error(
f"{op.name} expects {op.num_inputs} arguments but got {len(args)}", expr.span
)
if op.name == "relax.call_tir" and isinstance(args[1], str):
# extern function call case: rewrite identifier to an ExternFunc
args[1] = relax.ExternFunc(args[1], self.to_tvm_span(expr.params[1].span))

elif isinstance(op, relay.Expr):
args = [self.transform_expr(arg) for arg in expr.params]
@@ -1054,7 +1110,13 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
attrs = None
if kwargs or not is_default:
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))

if type_args:
return relay.Call(
op, args, attrs=attrs, type_args=type_args, span=self.to_tvm_span(expr.span)
)
else:
return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))

# Exprs:
# - ArrayLiteral
@@ -1090,6 +1152,9 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
elif isinstance(expr, ast.Tuple):
fields = [self.transform_expr(field) for field in expr.values]

if all([isinstance(f, str) for f in fields]):
return tuple(fields)

# TODO(@altanh): this check might be too weak; we really only accept integral PrimExprs
# (e.g. int constants, dim vars, and integer operations on these)

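
A hedged TVMScript-level sketch of what the parser change accepts; rx_func, my_packed_func, m, and n are illustrative names only:

@R.function
def rx_func(x: Tensor[(m, n), "float32"]) -> Tensor:
    # dtype is now a required keyword argument; the parser uses it together with
    # the output shape to attach type_args (here DynTensorType(rank=2,
    # dtype="float32")) to the generated relay.Call.
    gv = relax.call_tir("my_packed_func", (x,), (m, n), dtype="float32")
    return gv
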
8 changes: 4 additions & 4 deletions src/relax/backend/vm/codegen_vm.cc
@@ -232,8 +232,8 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
}

// Handle attrs of the call
auto alloc_attrs = call_node->attrs.as<AllocStorageAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be AllocStorageAttrs";
auto alloc_attrs = call_node->attrs.as<VMAllocStorageAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be VMAllocStorageAttrs";
int device_type = alloc_attrs->device_type;
args.push_back(Instruction::Arg(Instruction::kImmediate, device_type));
DataType dtype = alloc_attrs->dtype;
@@ -254,8 +254,8 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
// Handle `self`
args.push_back(ConvertArg(call_node->args[0]));
// Handle `offset`
auto alloc_attrs = call_node->attrs.as<AllocTensorAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
auto alloc_attrs = call_node->attrs.as<VMAllocTensorAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be VMAllocTensorAttrs";
int offset = alloc_attrs->offset;
args.push_back(Instruction::Arg(Instruction::kImmediate, offset));
// Handle `shape`
25 changes: 12 additions & 13 deletions src/relax/backend/vm/vm_memory_lower.cc
@@ -37,14 +37,13 @@
// Example:
// x = relax.builtin.alloc_tensor((m, n))
// -->
// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs)
// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n),
// relax.attrs.VMAllocStorageAttrs)
// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n),
// relax.attrs.AllocTensorAttrs)
// relax.attrs.VMAllocTensorAttrs)

class VMMemLowerMutator : public ExprMutator {
Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
DynTensorType tensor_type = Downcast<DynTensorType>(type);
DataType dtype = DataType(tensor_type->dtype);
Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
// Question: what if the dtype of tensor_type is unknown?
// Symbolic/static shape case
if (auto* shape_expr = shape.as<ShapeExprNode>()) {
@@ -79,19 +78,19 @@
// TODO(@yuchen): memory planning
if (call->op == alloc_tensor_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);

// TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor
Type tensor_type = DynTensorType(output_shape->values.size(), DataType::Float(32));
Expr storage_size = ComputeStorageSize(output_shape, tensor_type);
auto storage_attr = make_object<AllocStorageAttrs>();
storage_attr->dtype = DataType::Float(32);
auto alloc_attrs = call->attrs.as<AllocTensorAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
DataType dtype = alloc_attrs->dtype;
Expr storage_size = ComputeStorageSize(output_shape, dtype);
auto storage_attr = make_object<VMAllocStorageAttrs>();
storage_attr->dtype = dtype;
storage_attr->device_type = 1;

Var storage =
builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage");
auto tensor_attr = make_object<AllocTensorAttrs>();
auto tensor_attr = make_object<VMAllocTensorAttrs>();
tensor_attr->offset = 0;
tensor_attr->dtype = DataType::Float(32);
tensor_attr->dtype = dtype;
Expr shape = call->args[0];
Var tensor =
builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor");
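
The body of ComputeStorageSize is not shown in this hunk. As a rough, illustrative sketch of the quantity it is now expected to produce once the dtype comes from the attrs instead of a hard-coded float32 (static-shape case only):

import numpy as np

def storage_size_bytes(shape, dtype):
    """Illustrative only: bytes needed for a statically shaped tensor of dtype."""
    # e.g. storage_size_bytes((4, 8), "float16") == 64
    return int(np.prod(shape)) * np.dtype(dtype).itemsize
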
(Diffs for the remaining changed files are not shown here.)
