From 7425128c45f1ccf1f8d556d80bbd9f9966c09ed7 Mon Sep 17 00:00:00 2001
From: Hongyi Jin <3231950289@qq.com>
Date: Thu, 28 Jul 2022 23:07:23 +0800
Subject: [PATCH] Add High-level Op Support (#5)

* high-level-op support

* format

* format

* follow relay convention

* format

* fix
---
 include/tvm/relax/op_attr_types.h      |  28 ++++++
 python/tvm/relax/op/__init__.py        |   3 +-
 python/tvm/relax/op/nn/__init__.py     |  19 ++++
 python/tvm/relax/op/nn/_make.py        |  20 ++++
 python/tvm/relax/op/nn/nn.py           |  48 ++++++++++
 python/tvm/relax/op/tensor.py          |   5 +-
 src/relax/op/nn/convolution.cc         |  50 ++++++++++
 src/relax/op/nn/convolution.h          |  71 ++++++++++++++
 src/relax/op/nn/nn.cc                  |  56 +++++++++++
 src/relax/op/nn/nn.h                   | 123 +++++++++++++++++++++++++
 src/relax/op/nn/pooling.cc             |  47 ++++++++++
 src/relax/op/nn/pooling.h              |  62 +++++++++++++
 src/relax/op/op_common.h               |  10 ++
 src/relax/op/tensor/binary.cc          |  76 +++++++++++++++
 src/relax/op/tensor/binary.h           |  76 +--------------
 src/relax/op/tensor/unary.cc           |  60 +++++++++++-
 src/relax/op/tensor/unary.h            |  34 +------
 tests/python/mlc/test_high_level_op.py |  49 ++++++++++
 18 files changed, 730 insertions(+), 107 deletions(-)
 create mode 100644 python/tvm/relax/op/nn/__init__.py
 create mode 100644 python/tvm/relax/op/nn/_make.py
 create mode 100644 python/tvm/relax/op/nn/nn.py
 create mode 100644 src/relax/op/nn/convolution.cc
 create mode 100644 src/relax/op/nn/convolution.h
 create mode 100644 src/relax/op/nn/nn.cc
 create mode 100644 src/relax/op/nn/nn.h
 create mode 100644 src/relax/op/nn/pooling.cc
 create mode 100644 src/relax/op/nn/pooling.h
 create mode 100644 tests/python/mlc/test_high_level_op.py

diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index bd05e1814df2..30a892daf5d0 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -109,6 +109,34 @@ struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
   }
 };
 
+/*! \brief Attributes used in MaxPool2d operator */
+struct MaxPool2dAttrs : public tvm::AttrsNode<MaxPool2dAttrs> {
+  Array<PrimExpr> kernel_size;
+  Array<PrimExpr> stride;
+  Array<PrimExpr> padding;
+  Array<PrimExpr> dilation;
+  TVM_DECLARE_ATTRS(MaxPool2dAttrs, "relax.attrs.MaxPool2dAttrs") {
+    TVM_ATTR_FIELD(kernel_size).describe("The size of the window to take a max over.");
+    TVM_ATTR_FIELD(stride).describe("The stride of the window.");
+    TVM_ATTR_FIELD(padding).describe("The padding on the input.");
+    TVM_ATTR_FIELD(dilation).describe("The stride of elements in the window.");
+  }
+};  // struct MaxPool2dAttrs
+
+/*! \brief Attributes used in Conv2d operator */
+struct Conv2dAttrs : public tvm::AttrsNode<Conv2dAttrs> {
+  Array<PrimExpr> kernel_size;
+  Array<PrimExpr> stride;
+  Array<PrimExpr> padding;
+  Array<PrimExpr> dilation;
+  TVM_DECLARE_ATTRS(Conv2dAttrs, "relax.attrs.Conv2dAttrs") {
+    TVM_ATTR_FIELD(kernel_size).describe("The size of the convolving kernel.");
+    TVM_ATTR_FIELD(stride).describe("The stride of the convolution.");
+    TVM_ATTR_FIELD(padding).describe("The padding on the input.");
+    TVM_ATTR_FIELD(dilation).describe("The spacing between kernel elements.");
+  }
+};  // struct Conv2dAttrs
+
 }  // namespace relax
 }  // namespace tvm
 #endif  // TVM_RELAX_OP_ATTR_TYPES_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 59d35d4ab876..45734b1461fa 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -19,7 +19,8 @@
 
 # Operators
 from .base import *
-from .tensor import *
+from .nn import *
 from .op_attrs import *
+from .tensor import *
 from . import builtin
 from . import memory
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
new file mode 100644
index 000000000000..af2aa106bca7
--- /dev/null
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Neural network related operators."""
+from .nn import *
diff --git a/python/tvm/relax/op/nn/_make.py b/python/tvm/relax/op/nn/_make.py
new file mode 100644
index 000000000000..1785345ac1b1
--- /dev/null
+++ b/python/tvm/relax/op/nn/_make.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constructor APIs"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.op.nn", __name__)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
new file mode 100644
index 000000000000..4fec1ac26cb1
--- /dev/null
+++ b/python/tvm/relax/op/nn/nn.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from . import _make
+from ...expr import Expr
+
+
+def dense(lhs: Expr, rhs: Expr) -> Expr:
+    return _make.dense(lhs, rhs)
+
+
+def conv2d(
+    lhs: Expr, rhs: Expr, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)
+) -> Expr:
+    return _make.conv2d(lhs, rhs, kernel_size, stride, padding, dilation)
+
+
+def relu(data: Expr) -> Expr:
+    return _make.relu(data)
+
+
+def softmax(data: Expr) -> Expr:
+    return _make.softmax(data)
+
+
+def flatten(data: Expr) -> Expr:
+    return _make.flatten(data)
+
+
+def max_pool2d(data: Expr, kernel_size, stride=None, padding=(0, 0), dilation=(1, 1)) -> Expr:
+    if stride is None:
+        stride = kernel_size
+    return _make.max_pool2d(data, kernel_size, stride, padding, dilation)
+
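Note: a minimal usage sketch of the new Python wrappers above, assuming this patch is applied; the input and weight shapes here are illustrative and not part of the patch:

    import numpy as np
    import tvm
    from tvm import relax

    bb = relax.BlockBuilder()
    # 4-D NCHW input; DynTensorType records rank and dtype.
    x = relax.Var("x", (1, 1, 28, 28), relax.DynTensorType(4, "float32"))
    w = relax.const(np.zeros((32, 1, 3, 3), "float32"))
    with bb.function("f", [x]):
        y = bb.emit(relax.op.nn.conv2d(x, w, kernel_size=(3, 3)))
        y = bb.emit(relax.op.nn.relu(y))
        bb.emit_func_output(y)
    print(bb.get().script())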
diff --git a/python/tvm/relax/op/tensor.py b/python/tvm/relax/op/tensor.py
index 9ebc4cc7d5d4..6ac6772daf48 100644
--- a/python/tvm/relax/op/tensor.py
+++ b/python/tvm/relax/op/tensor.py
@@ -13,6 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
+# under the License.
 # pylint: disable=redefined-builtin, invalid-name
 """Basic tensor operations."""
 import numpy as np  # type: ignore
@@ -83,13 +84,13 @@ def numpy_unique(
 
     Uses numpy.unique to compute unique elements.
     """
-    # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True
+    # TODO(prakalp) : add support for returning a tuple when return_inverse or return_counts is True
     if bool(return_inverse) or bool(return_counts):
         raise NotImplementedError("missing support return_inverse or return_counts set to true")
     if dim < 0:
         dim = None
     a_numpy = a.numpy()
-    # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
+    # TODO(prakalp) : use torch.unique instead of numpy when torch is installed in ci.
     output_sorted_numpy, indices = np.unique(a_numpy, return_index=True)
     if sort:
         return tvm.nd.array(output_sorted_numpy)
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
new file mode 100644
index 000000000000..b3bbd38a8497
--- /dev/null
+++ b/src/relax/op/nn/convolution.cc
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "convolution.h"
+
+#include "../tensor/binary.h"
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_NODE_TYPE(Conv2dAttrs);
+
+RELAY_REGISTER_OP("relax.nn.conv2d")
+    .set_num_inputs(2)
+    .add_argument("e1", "Expr", "The input expression")
+    .add_argument("e2", "Expr", "The input expression")
+    .set_attrs_type<Conv2dAttrs>()
+    .set_attr<FInferShape>("FInferShape", InferShapeConv2d)
+    .set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast);
+
+Expr MakeConv2d(Expr expr1, Expr expr2, Array<PrimExpr> kernel_size, Array<PrimExpr> stride,
+                Array<PrimExpr> padding, Array<PrimExpr> dilation) {
+  static const Op& op = Op::Get("relax.nn.conv2d");
+  auto attrs = make_object<Conv2dAttrs>();
+  attrs->kernel_size = kernel_size;
+  attrs->stride = stride;
+  attrs->padding = padding;
+  attrs->dilation = dilation;
+  return Call(op, {expr1, expr2}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(MakeConv2d);
+
+}  // namespace relax
+}  // namespace tvm
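Note on the registration pattern above: MakeConv2d packs the attributes and is exported via TVM_REGISTER_GLOBAL under "relax.op.nn.conv2d", which is exactly the prefix that _make.py's _init_api("relax.op.nn", ...) scans, so the Python wrapper's _make.conv2d call resolves to this C++ function. A quick check, assuming the patch is built into TVM:

    import tvm
    # The packed function becomes visible in the global registry.
    assert tvm.get_global_func("relax.op.nn.conv2d") is not None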
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
new file mode 100644
index 000000000000..8d340e919598
--- /dev/null
+++ b/src/relax/op/nn/convolution.h
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_
+#define TVM_RELAX_OP_NN_CONVOLUTION_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/op_attr_types.h>
+
+#include "../op_common.h"
+namespace tvm {
+namespace relax {
+
+Optional<Expr> InferShapeConv2d(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Conv2d op should have 2 arguments");
+  }
+  Expr shape0 = call->args[0]->shape();
+  Expr shape1 = call->args[1]->shape();
+  auto* s0 = shape0.as<ShapeExprNode>();
+  auto* s1 = shape1.as<ShapeExprNode>();
+  auto* attrs = call->attrs.as<Conv2dAttrs>();
+  if (s0 && s1) {
+    std::vector<PrimExpr> output_shape;
+    size_t ndim0 = s0->values.size();
+    size_t ndim1 = s1->values.size();
+    if (ndim0 != 4 || ndim1 != 4) {
+      LOG(INFO) << ndim0;
+      LOG(INFO) << ndim1;
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "The 2 arguments of Conv2d must be 4D Tensors");
+    }
+    // N
+    output_shape.push_back(s0->values[0]);
+    // C
+    output_shape.push_back(s1->values[0]);
+    // H
+    output_shape.push_back((s0->values[2] + 2 * attrs->padding[0] -
+                            attrs->dilation[0] * (attrs->kernel_size[0] - 1) - 1) /
+                               attrs->stride[0] +
+                           1);
+    // W
+    output_shape.push_back((s0->values[3] + 2 * attrs->padding[1] -
+                            attrs->dilation[1] * (attrs->kernel_size[1] - 1) - 1) /
+                               attrs->stride[1] +
+                           1);
+    return ShapeExpr(Array<PrimExpr>{output_shape.begin(), output_shape.end()});
+  } else {
+    return NullOpt;
+  }
+}
+
+}  // namespace relax
+}  // namespace tvm
+#endif
\ No newline at end of file
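Note: the H and W expressions in InferShapeConv2d are the standard convolution output-size formula, out = (in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1. A quick numeric check, with values chosen to match the 28x28 test input at the end of this patch:

    # 28x28 input, 3x3 kernel, no padding, unit stride/dilation -> 26x26
    H, k, p, d, s = 28, 3, 0, 1, 1
    assert (H + 2 * p - d * (k - 1) - 1) // s + 1 == 26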
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
new file mode 100644
index 000000000000..d25ac86b6475
--- /dev/null
+++ b/src/relax/op/nn/nn.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "nn.h"
+
+namespace tvm {
+namespace relax {
+
+RELAY_REGISTER_OP("relax.nn.dense")
+    .set_num_inputs(2)
+    .add_argument("e1", "Expr", "The input expression")
+    .add_argument("e2", "Expr", "The input expression")
+    .set_attr<FInferShape>("FInferShape", InferShapeDense)
+    .set_attr<FInferType>("FInferType", InferTypeDense);
+
+Expr MakeDense(Expr expr1, Expr expr2) {
+  static const Op& op = Op::Get("relax.nn.dense");
+  return Call(op, {expr1, expr2}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.dense").set_body_typed(MakeDense);
+
+RELAX_REGISTER_UNARY_OP("nn.softmax");
+
+RELAX_REGISTER_UNARY_OP("nn.relu");
+
+RELAY_REGISTER_OP("relax.nn.flatten")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attr<FInferShape>("FInferShape", InferShapeFlatten)
+    .set_attr<FInferType>("FInferType", InferTypeFlatten);
+
+Expr MakeFlatten(Expr data) {
+  static const Op& op = Op::Get("relax.nn.flatten");
+  return Call(op, {data}, {}, {});
+}
+TVM_REGISTER_GLOBAL("relax.op.nn.flatten").set_body_typed(MakeFlatten);
+
+}  // namespace relax
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
new file mode 100644
index 000000000000..8ec1932a55c2
--- /dev/null
+++ b/src/relax/op/nn/nn.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RELAX_OP_NN_NN_H_
+#define TVM_RELAX_OP_NN_NN_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/type.h>
+
+#include "../op_common.h"
+#include "../tensor/unary.h"
+namespace tvm {
+namespace relax {
+
+Optional<Expr> InferShapeFlatten(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Flatten op should have 1 argument");
+  }
+  Expr shape = call->args[0]->shape();
+  auto* s = shape.as<ShapeExprNode>();
+  if (s) {
+    PrimExpr output_dim = 1;
+    for (int i = 1; i < static_cast<int>(s->values.size()); i++) {
+      output_dim *= s->values[i];
+    }
+    return ShapeExpr({s->values[0], output_dim});
+  } else {
+    return NullOpt;
+  }
+}
+
+Type InferTypeFlatten(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Flatten op should have 1 argument");
+  }
+  auto* input_ty = call->args[0]->checked_type().as<DynTensorTypeNode>();
+  if (!input_ty) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Input should be DynTensor, but got "
+                       << call->args[0]->checked_type()->GetTypeKey());
+  }
+  return DynTensorType(/*ndim=*/2, input_ty->dtype);
+}
+
+Optional<Expr> InferShapeDense(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Dense op should have 2 arguments");
+  }
+  Expr shape0 = call->args[0]->shape();
+  Expr shape1 = call->args[1]->shape();
+  auto* s0 = shape0.as<ShapeExprNode>();
+  auto* s1 = shape1.as<ShapeExprNode>();
+  if (s0 && s1) {
+    size_t ndim0 = s0->values.size();
+    size_t ndim1 = s1->values.size();
+    if (ndim0 != 2 || ndim1 != 2) {
+      LOG(INFO) << ndim0;
+      LOG(INFO) << ndim1;
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "The 2 arguments of Dense must be 2D Tensors");
+    }
+    if (!EqualCheck(s0->values[1], s1->values[1])) {
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "The 2 arguments of Dense must have the same number of columns");
+    }
+    return ShapeExpr(Array<PrimExpr>{s0->values[0], s1->values[0]});
+  } else {
+    return NullOpt;
+  }
+}
+
+Type InferTypeDense(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Dense op should have 2 arguments");
+  }
+  Type type0 = call->args[0]->checked_type();
+  Type type1 = call->args[1]->checked_type();
+  auto* t0 = type0.as<DynTensorTypeNode>();
+  auto* t1 = type1.as<DynTensorTypeNode>();
+  if (!t0 || !t1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The 2 arguments of Dense should be DynTensor");
+  }
+
+  DataType output_dtype;
+  if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) {
+    output_dtype = DataType::Void();
+  } else if (t0->dtype != t1->dtype) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Data types " << t0->dtype << " and "
+                                                     << t1->dtype << " must be equal for Dense");
+  } else {
+    output_dtype = t0->dtype;
+  }
+
+  int output_ndim;
+  if (t0->IsUnknownNdim() || t1->IsUnknownNdim()) {
+    output_ndim = -1;
+  } else {
+    output_ndim = t0->ndim;
+  }
+  return DynTensorType(output_ndim, output_dtype);
+}
+
+}  // namespace relax
+}  // namespace tvm
+#endif
\ No newline at end of file
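Note: InferShapeDense encodes the Relay convention that dense weights are laid out as (out_features, in_features), so data (N, K) times weight (M, K) yields (N, M), with the shared K dimension checked by EqualCheck. A small numeric illustration (shapes are illustrative):

    N, K, M = 4, 5408, 100
    data_shape, weight_shape = (N, K), (M, K)     # weight is (out, in)
    assert data_shape[1] == weight_shape[1]       # contraction dims must match
    out_shape = (data_shape[0], weight_shape[0])  # -> (4, 100)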
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
new file mode 100644
index 000000000000..54340ebf6804
--- /dev/null
+++ b/src/relax/op/nn/pooling.cc
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "pooling.h"
+
+#include "../tensor/unary.h"
+namespace tvm {
+namespace relax {
+
+RELAY_REGISTER_OP("relax.nn.max_pool2d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<MaxPool2dAttrs>()
+    .set_attr<FInferShape>("FInferShape", InferShapeMaxPool2d)
+    .set_attr<FInferType>("FInferType", InferTypeUnaryBroadcast);
+
+Expr MakeMaxPool2d(Expr data, Array<PrimExpr> kernel_size, Array<PrimExpr> stride,
+                   Array<PrimExpr> padding, Array<PrimExpr> dilation) {
+  auto attrs = make_object<MaxPool2dAttrs>();
+  attrs->kernel_size = kernel_size;
+  attrs->stride = stride;
+  attrs->padding = padding;
+  attrs->dilation = dilation;
+  static const Op& op = Op::Get("relax.nn.max_pool2d");
+  return Call(op, {data}, Attrs(attrs));
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(MakeMaxPool2d);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
new file mode 100644
index 000000000000..d6638c978c3f
--- /dev/null
+++ b/src/relax/op/nn/pooling.h
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RELAX_OP_NN_POOLING_H_
+#define TVM_RELAX_OP_NN_POOLING_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/op_attr_types.h>
+
+#include "../op_common.h"
+namespace tvm {
+namespace relax {
+
+Optional<Expr> InferShapeMaxPool2d(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "MaxPool2d op should have 1 argument");
+  }
+  auto* attrs = call->attrs.as<MaxPool2dAttrs>();
+  Expr shape = call->args[0]->shape();
+  auto* s = shape.as<ShapeExprNode>();
+  if (s) {
+    Array<PrimExpr> output_shape;
+    for (int i = 0; i < static_cast<int>(s->values.size()); i++) {
+      if (i == static_cast<int>(s->values.size()) - 2) {
+        output_shape.push_back((s->values[i] + 2 * attrs->padding[0] -
+                                attrs->dilation[0] * (attrs->kernel_size[0] - 1) - 1) /
+                                   attrs->stride[0] +
+                               1);
+      } else if (i == static_cast<int>(s->values.size()) - 1) {
+        output_shape.push_back((s->values[i] + 2 * attrs->padding[1] -
+                                attrs->dilation[1] * (attrs->kernel_size[1] - 1) - 1) /
+                                   attrs->stride[1] +
+                               1);
+      } else {
+        output_shape.push_back(s->values[i]);
+      }
+    }
+    return ShapeExpr(Array<PrimExpr>{output_shape.begin(), output_shape.end()});
+  } else {
+    return NullOpt;
+  }
+}
+
+}  // namespace relax
+}  // namespace tvm
+#endif
\ No newline at end of file
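Note: InferShapeMaxPool2d applies the same window formula as conv2d to the last two axes and passes the leading axes through unchanged. Combined with the Python wrapper's default stride = kernel_size, a 26x26 activation pooled with a 2x2 window halves to 13x13:

    # (26 + 0 - 1*(2 - 1) - 1) // 2 + 1 == 13
    H, k, p, d, s = 26, 2, 0, 1, 2
    assert (H + 2 * p - d * (k - 1) - 1) // s + 1 == 13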
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 7f97716ec024..bc8731ddee52 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -59,6 +59,16 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs);
       .set_attr<FInferShape>("FInferShape", InferShapeBinaryBroadcast)  \
       .set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast)
 
+#define RELAX_REGISTER_UNARY_OP(OpName)                                 \
+  TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr e) {   \
+    static const Op& op = Op::Get("relax." OpName);                     \
+    return Call(op, {e}, Attrs(), {});                                  \
+  });                                                                   \
+  RELAY_REGISTER_OP("relax." OpName)                                    \
+      .set_num_inputs(1)                                                \
+      .add_argument("e", "Tensor", "The input tensor.")                 \
+      .set_attr<FInferShape>("FInferShape", InferShapeUnaryBroadcast)   \
+      .set_attr<FInferType>("FInferType", InferTypeUnaryBroadcast)
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 7d601678789a..e758b6b8a198 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -27,6 +27,82 @@
 namespace tvm {
 namespace relax {
 
+Optional<Expr> InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Binary broadcast op should have 2 arguments");
+  }
+  Expr lhs_shape = call->args[0]->shape();
+  Expr rhs_shape = call->args[1]->shape();
+  auto* s0 = lhs_shape.as<ShapeExprNode>();
+  auto* s1 = rhs_shape.as<ShapeExprNode>();
+  if (s0 && s1) {
+    std::vector<PrimExpr> output_shape;
+    size_t ndim0 = s0->values.size();
+    size_t ndim1 = s1->values.size();
+    size_t i = 1;
+    for (; i <= std::min(ndim0, ndim1); ++i) {
+      PrimExpr dim0 = s0->values[ndim0 - i];
+      PrimExpr dim1 = s1->values[ndim1 - i];
+      if (EqualConstInt(dim0, 1)) {
+        output_shape.push_back(dim1);
+      } else if (EqualConstInt(dim1, 1)) {
+        output_shape.push_back(dim0);
+      } else if (EqualCheck(dim0, dim1)) {
+        output_shape.push_back(dim0);
+      } else {
+        // defer the computation of output shapes to runtime
+        // e.g., broadcast Tensor([m, n]), Tensor([k]) -> defer to runtime
+        return Call(ExternFunc(String("vm.binary_broadcast_shape_infer")),
+                    {call->args[0], call->args[1]}, {}, {});
+      }
+    }
+    size_t max_ndim = std::max(ndim0, ndim1);
+    auto& longer_shape = (ndim0 > ndim1) ? s0 : s1;
+    for (; i <= max_ndim; ++i) {
+      output_shape.push_back(longer_shape->values[max_ndim - i]);
+    }
+    return ShapeExpr(Array<PrimExpr>(output_shape.rbegin(), output_shape.rend()));
+  } else {
+    return NullOpt;
+  }
+}
+
+Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Binary broadcast op should have 2 arguments");
+  }
+  Type lhs_type = call->args[0]->checked_type();
+  Type rhs_type = call->args[1]->checked_type();
+  auto* t0 = lhs_type.as<DynTensorTypeNode>();
+  auto* t1 = rhs_type.as<DynTensorTypeNode>();
+  if (!t0 || !t1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Both lhs and rhs should be DynTensor for broadcasting, but got "
+                       << lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey());
+  }
+
+  DataType output_dtype;
+  if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) {
+    output_dtype = DataType::Void();
+  } else if (t0->dtype != t1->dtype) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Data types " << t0->dtype << " and " << t1->dtype
+                       << " must be equal for broadcasting operators");
+  } else {
+    output_dtype = t0->dtype;
+  }
+
+  int output_ndim;
+  if (t0->IsUnknownNdim() || t1->IsUnknownNdim()) {
+    output_ndim = -1;
+  } else {
+    output_ndim = std::max(t0->ndim, t1->ndim);
+  }
+  return DynTensorType(output_ndim, output_dtype);
+}
+
 RELAX_REGISTER_BINARY_BROADCAST_OP("add")
     .describe("Elementwise add with broadcasting")
     .set_support_level(1);
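Note: the loop in InferShapeBinaryBroadcast implements numpy-style broadcasting: shapes are aligned from the right, a dimension of 1 stretches to match the other side, and anything unprovable at compile time is deferred to the runtime helper. The same rule in plain Python, for intuition (a sketch, not part of the patch):

    def broadcast_shape(s0, s1):
        out = []
        # Walk both shapes from the right.
        for d0, d1 in zip(reversed(s0), reversed(s1)):
            if d0 == 1:
                out.append(d1)
            elif d1 == 1 or d0 == d1:
                out.append(d0)
            else:
                raise ValueError("incompatible dims (relax defers this case to runtime)")
        # Copy over the leading dims of the longer shape.
        longer = s0 if len(s0) > len(s1) else s1
        out.extend(reversed(longer[: abs(len(s0) - len(s1))]))
        return tuple(reversed(out))

    # e.g. the bias add in the test below: (4, 32, 26, 26) + (1, 32, 1, 1)
    assert broadcast_shape((4, 32, 26, 26), (1, 32, 1, 1)) == (4, 32, 26, 26)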
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 241bd856e853..99b3aee1ff19 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -36,81 +36,9 @@
 namespace tvm {
 namespace relax {
 
-Optional<Expr> InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
-  if (call->args.size() != 2) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Binary broadcast op should have 2 arguments");
-  }
-  Expr lhs_shape = call->args[0]->shape();
-  Expr rhs_shape = call->args[1]->shape();
-  auto* s0 = lhs_shape.as<ShapeExprNode>();
-  auto* s1 = rhs_shape.as<ShapeExprNode>();
-  if (s0 && s1) {
-    std::vector<PrimExpr> output_shape;
-    size_t ndim0 = s0->values.size();
-    size_t ndim1 = s1->values.size();
-    size_t i = 1;
-    for (; i <= std::min(ndim0, ndim1); ++i) {
-      PrimExpr dim0 = s0->values[ndim0 - i];
-      PrimExpr dim1 = s1->values[ndim1 - i];
-      if (EqualConstInt(dim0, 1)) {
-        output_shape.push_back(dim1);
-      } else if (EqualConstInt(dim1, 1)) {
-        output_shape.push_back(dim0);
-      } else if (EqualCheck(dim0, dim1)) {
-        output_shape.push_back(dim0);
-      } else {
-        // defer the computation of output shapes to runtime
-        // e.g., broadcast Tensor([m, n]), Tensor([k]) -> defer to runtime
-        return Call(ExternFunc(String("vm.binary_broadcast_shape_infer")),
-                    {call->args[0], call->args[1]}, {}, {});
-      }
-    }
-    size_t max_ndim = std::max(ndim0, ndim1);
-    auto& longer_shape = (ndim0 > ndim1) ? s0 : s1;
-    for (; i <= max_ndim; ++i) {
-      output_shape.push_back(longer_shape->values[max_ndim - i]);
-    }
-    return ShapeExpr(Array<PrimExpr>(output_shape.rbegin(), output_shape.rend()));
-  } else {
-    return NullOpt;
-  }
-}
+Optional<Expr> InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx);
 
-Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
-  if (call->args.size() != 2) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Binary broadcast op should have 2 arguments");
-  }
-  Type lhs_type = call->args[0]->checked_type();
-  Type rhs_type = call->args[1]->checked_type();
-  auto* t0 = lhs_type.as<DynTensorTypeNode>();
-  auto* t1 = rhs_type.as<DynTensorTypeNode>();
-  if (!t0 || !t1) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Both lhs and rhs should be DynTensor for broadcasting, but got "
-                       << lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey());
-  }
-
-  DataType output_dtype;
-  if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) {
-    output_dtype = DataType::Void();
-  } else if (t0->dtype != t1->dtype) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Data types " << t0->dtype << " and " << t1->dtype
-                       << " must be equal for broadcasting operators");
-  } else {
-    output_dtype = t0->dtype;
-  }
-
-  int output_ndim;
-  if (t0->IsUnknownNdim() || t1->IsUnknownNdim()) {
-    output_ndim = -1;
-  } else {
-    output_ndim = std::max(t0->ndim, t1->ndim);
-  }
-  return DynTensorType(output_ndim, output_dtype);
-}
+Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 228de3ae8c75..3bedddde05a7 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -27,8 +27,66 @@
 namespace tvm {
 namespace relax {
 
-TVM_REGISTER_NODE_TYPE(UniqueAttrs);
+Optional<Expr> InferShapeUnique(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
+  }
+  auto unique_attrs = call->attrs.as<UniqueAttrs>();
+  // Only default values of these attributes are supported right now.
+  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "support for return_inverse, return_counts, and dim is not implemented");
+  return relax::RuntimeDepShape(call->span);
+}
+
+Type InferTypeUnique(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
+  }
+  auto* input_ty = call->args[0]->checked_type().as<DynTensorTypeNode>();
+  if (!input_ty) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Input should be DynTensor, but got "
+                       << call->args[0]->checked_type()->GetTypeKey());
+  }
+
+  // TODO(prakalp): Add support for return_inverse, return_counts and dim attributes. Only
+  // defaults are supported right now.
+  auto unique_attrs = call->attrs.as<UniqueAttrs>();
+  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "support for return_inverse, return_counts, and dim is not implemented");
+  return DynTensorType(/*ndim=*/1, input_ty->dtype);
+}
+
+Optional<Expr> InferShapeUnaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unary op should have 1 argument");
+  }
+  Expr shape = call->args[0]->shape();
+  auto* s = shape.as<ShapeExprNode>();
+  if (s) {
+    return ShapeExpr(s->values);
+  } else {
+    return NullOpt;
+  }
+}
+
+Type InferTypeUnaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unary op should have 1 argument");
+  }
+  auto* input_ty = call->args[0]->checked_type().as<DynTensorTypeNode>();
+  if (!input_ty) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Input should be DynTensor, but got "
+                       << call->args[0]->checked_type()->GetTypeKey());
+  }
+  return GetRef<DynTensorType>(input_ty);
+}
+
+TVM_REGISTER_NODE_TYPE(UniqueAttrs);
+TVM_REGISTER_NODE_TYPE(MaxPool2dAttrs);
 
 RELAY_REGISTER_OP("relax.unique")
     .describe(
         "This operation returns the unique elements and the new index of each item in a given "
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index d033e838e1ab..971af57d65bc 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -35,37 +35,13 @@
 namespace tvm {
 namespace relax {
 
-Optional<Expr> InferShapeUnique(const Call& call, DiagnosticContext diag_ctx) {
-  if (call->args.size() != 1) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
-  }
-  auto unique_attrs = call->attrs.as<UniqueAttrs>();
-  // Only default values of these attributes are supported right now.
-  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "support for return_inverse, return_counts, and dim is not implemented");
-  return relax::RuntimeDepShape(call->span);
-}
+Optional<Expr> InferShapeUnique(const Call& call, DiagnosticContext diag_ctx);
 
-Type InferTypeUnique(const Call& call, DiagnosticContext diag_ctx) {
-  if (call->args.size() != 1) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
-  }
-  auto* input_ty = call->args[0]->checked_type().as<DynTensorTypeNode>();
-  if (!input_ty) {
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "Input should be DynTensor, but got "
-                       << call->args[0]->checked_type()->GetTypeKey());
-  }
-
-  // TODO(prakalp): Add support for return_inverse, return_counts and dim attributes. Only
-  // defaults are supported right now.
-  auto unique_attrs = call->attrs.as<UniqueAttrs>();
-  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
-    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
-                       << "support for return_inverse, return_counts, and dim is not implemented");
-  return DynTensorType(/*ndim=*/1, input_ty->dtype);
-}
+Type InferTypeUnique(const Call& call, DiagnosticContext diag_ctx);
+
+Optional<Expr> InferShapeUnaryBroadcast(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeUnaryBroadcast(const Call& call, DiagnosticContext diag_ctx);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/mlc/test_high_level_op.py b/tests/python/mlc/test_high_level_op.py
new file mode 100644
index 000000000000..a5dc72f34242
--- /dev/null
+++ b/tests/python/mlc/test_high_level_op.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pickle as pkl
+import torch
+import torch.nn.functional as F
+import torchvision
+import tvm
+import tvm.testing
+
+from matplotlib import pyplot as plt
+from torch import nn
+from torchvision import transforms
+from tvm import topi, relax, te
+from tvm.script import tir as T
+
+
+batch_size = 4
+input_shape = (batch_size, 1, 28, 28)
+weight_map = pkl.load(open("fasionmnist_mlp_assignment_params.pkl", "rb"))
+
+
+def create_model():
+    bb = relax.BlockBuilder()
+    conv2d_weight = relax.const(weight_map["conv2d_weight"], "float32")
+    conv2d_bias = relax.const(weight_map["conv2d_bias"].reshape(1, 32, 1, 1), "float32")
+    linear0_weight = relax.const(weight_map["linear0_weight"], "float32")
+    linear0_bias = relax.const(weight_map["linear0_bias"].reshape(1, 100), "float32")
+    linear1_weight = relax.const(weight_map["linear1_weight"], "float32")
+    linear1_bias = relax.const(weight_map["linear1_bias"].reshape(1, 10), "float32")
+
+    x = relax.Var("x", input_shape, relax.DynTensorType(batch_size, "float32"))
+    with bb.function("main", [x]):
+        with bb.dataflow():
+            lv0 = bb.emit(relax.op.nn.conv2d(x, conv2d_weight, (3, 3)))
+            lv1 = bb.emit(relax.op.add(lv0, conv2d_bias))
+            lv2 = bb.emit(relax.op.nn.relu(lv1))
+            lv3 = bb.emit(relax.op.nn.max_pool2d(lv2, (2, 2)))
+            lv4 = bb.emit(relax.op.nn.flatten(lv3))
+            lv5 = bb.emit(relax.op.nn.dense(lv4, linear0_weight))
+            lv6 = bb.emit(relax.op.add(lv5, linear0_bias))
+            lv7 = bb.emit(relax.op.nn.relu(lv6))
+            lv8 = bb.emit(relax.op.nn.dense(lv7, linear1_weight))
+            lv9 = bb.emit(relax.op.add(lv8, linear1_bias))
+            lv10 = bb.emit(relax.op.nn.softmax(lv9))
+            gv = bb.emit_output(lv10)
+        bb.emit_func_output(gv)
+    return bb.get()
+
+
+print(create_model().script())
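Note: for reference, the shapes the new FInferShape functions derive through create_model(), assuming the pickled weights match the bias reshapes above (conv2d_weight (32, 1, 3, 3), linear0_weight (100, 5408), linear1_weight (10, 100)); this is a hand-derived trace, not program output:

    # x                  (4, 1, 28, 28)
    # conv2d(3x3)     -> (4, 32, 26, 26)   # 28 - 3 + 1 = 26; channels from weight dim 0
    # max_pool2d(2x2) -> (4, 32, 13, 13)   # stride defaults to kernel_size
    # flatten         -> (4, 5408)         # 32 * 13 * 13
    # dense + add     -> (4, 100)
    # dense + add     -> (4, 10)
    # softmax         -> (4, 10)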