Skip to content

Commit

Permalink
[Unity] Relax op: datatype (#13986)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored and tqchen committed Mar 4, 2023
1 parent 7489379 commit c8869b7
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 0 deletions.
44 changes: 44 additions & 0 deletions include/tvm/relax/attrs/datatype.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/datatype.h
* \brief Attributes for datatype operators.
*/
#ifndef TVM_RELAX_ATTRS_DATATYPE_H_
#define TVM_RELAX_ATTRS_DATATYPE_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in astype operator */
struct AstypeAttrs : public tvm::AttrsNode<AstypeAttrs> {
DataType dtype;

TVM_DECLARE_ATTRS(AstypeAttrs, "relax.attrs.AstypeAttrs") {
TVM_ATTR_FIELD(dtype).describe("Target data type");
}
}; // struct AstypeAttrs.

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_DATATYPE_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# Operators
from .base import *
from .binary import *
from .datatype import *
from .index import *
from .manipulate import *
from .op_attrs import *
Expand Down
42 changes: 42 additions & 0 deletions python/tvm/relax/op/datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Datatype operators."""
from typing import Union

from tvm import DataType

from . import _ffi_api
from ..expr import Expr


def astype(x: Expr, dtype: Union[str, DataType]) -> Expr:
"""Cast input tensor to the given data type.
Parameters
----------
x : relax.Expr
The input data to the operator.
dtype: Union[str, DataType]
The target data type
Returns
-------
result : relax.Expr
The casted result.
"""
return _ffi_api.astype(x, dtype) # type: ignore
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import tvm._ffi


@tvm._ffi.register_object("relax.attrs.AstypeAttrs")
class AstypeAttrs(Attrs):
"""Attributes used in astype operator"""


@tvm._ffi.register_object("relax.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes used in take operator"""
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tvm.relax.op import (
add,
assert_op,
astype,
builtin,
call_builtin_with_ctx,
call_tir,
Expand Down Expand Up @@ -403,6 +404,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"add",
"arg",
"assert_op",
"astype",
"builtin",
"call_packed",
"call_tir",
Expand Down
60 changes: 60 additions & 0 deletions src/relax/op/tensor/datatype.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file datatype.cc
* \brief Datatype operators.
*/

#include "datatype.h"

#include <utility>

namespace tvm {
namespace relax {

/* relax.astype */
TVM_REGISTER_NODE_TYPE(AstypeAttrs);

Expr astype(Expr x, DataType dtype) {
ObjectPtr<AstypeAttrs> attrs = make_object<AstypeAttrs>();
attrs->dtype = dtype;

static const Op& op = Op::Get("relax.astype");
return Call(op, {std::move(x)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype);

StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<AstypeAttrs>();
ObjectPtr<TensorStructInfoNode> new_sinfo = make_object<TensorStructInfoNode>(*sinfo.get());
new_sinfo->dtype = attrs->dtype;
return TensorStructInfo(new_sinfo);
}

TVM_REGISTER_OP("relax.astype")
.set_attrs_type<AstypeAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype);

} // namespace relax
} // namespace tvm
45 changes: 45 additions & 0 deletions src/relax/op/tensor/datatype.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file datatype.h
* \brief The functions to make Relax datatype operator calls.
*/
#ifndef TVM_RELAX_OP_TENSOR_DATATYPE_H_
#define TVM_RELAX_OP_TENSOR_DATATYPE_H_

#include <tvm/relax/attrs/datatype.h>

#include "../op_common.h"

namespace tvm {
namespace relax {

/*!
* \brief Cast input tensor to the given data type.
* \param x The input data to the operator.
* \param dtype The target data type
* \return The casted result.
*/
Expr astype(Expr x, DataType dtype);

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_OP_TENSOR_DATATYPE_H_
105 changes: 105 additions & 0 deletions tests/python/relax/test_op_datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op
from tvm.script import relax as R


def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.astype(x, "float16").op == Op.get("relax.astype")


def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
ret = bb.normalize(call)
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)


def test_astype_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=2))
x2 = relax.Var("x", R.Tensor("float32"))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor(ndim=2))
x5 = relax.Var("x", R.Tensor())

_check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((2, 3), "float16"))
_check_inference(
bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2)
)
_check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(dtype="float16"))
_check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorStructInfo((2, 3), "float16"))
_check_inference(
bb, relax.op.astype(x4, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2)
)
_check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorStructInfo(dtype="float16"))


def test_astype_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x0 = relax.Var("x", R.Tensor((m, n), "float32"))
x1 = relax.Var("x", R.Tensor((m, n)))

_check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((m, n), "float16"))
_check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo((m, n), "float16"))


def test_astype_infer_struct_info_shape_var():
bb = relax.BlockBuilder()
s0 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
s2 = relax.Var("s", relax.ShapeStructInfo())
x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))

_check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo(s0, "float16"))
_check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(s1, "float16"))
_check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(s2, "float16"))


def test_astype_infer_struct_info_more_input_dtype():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float16"))
x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
x2 = relax.Var("x", R.Tensor((2, 3), "int32"))

_check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorStructInfo((2, 3), "int32"))
_check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorStructInfo((2, 3), "int8"))


def test_astype_infer_struct_info_wrong_input_type():
bb = relax.BlockBuilder()
x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))

with pytest.raises(TVMError):
bb.normalize(relax.op.astype(x0, "float16"))
with pytest.raises(TVMError):
bb.normalize(relax.op.astype(x1, "float16"))


if __name__ == "__main__":
tvm.testing.main()
54 changes: 54 additions & 0 deletions tests/python/relax/test_tvmscript_parser_op_datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Union

import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax
from tvm.script import relax as R


def _check(
parsed: Union[relax.Function, IRModule],
expect: Optional[Union[relax.Function, IRModule]],
):
test = parsed.script(show_meta=True)
roundtrip_mod = tvm.script.from_source(test)
tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
if expect:
tvm.ir.assert_structural_equal(parsed, expect)


def test_astype():
@R.function
def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16"):
gv: R.Tensor((2, 3, 4), "float16") = R.astype(x, "float16")
return gv

x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x]):
gv = bb.emit(relax.op.astype(x, "float16"))
bb.emit_func_output(gv)

_check(expected, bb.get()["main"])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c8869b7

Please sign in to comment.