From bf56c82c889f144a4e9079db6a06db54ff02b8da Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sat, 7 Sep 2024 08:49:06 +0800 Subject: [PATCH 1/2] reconstruct tensorrt --- .../contrib/msc/core/frontend/translate.py | 2 +- .../framework/tensorrt/frontend/translate.py | 5 +- .../framework/tensorrt/transform/pattern.py | 31 +- .../framework/tensorrt/transform/transform.py | 13 +- .../msc/core/transform/rewrite_utils.cc | 58 ++ .../msc/core/transform/rewrite_utils.h | 72 ++ .../msc/framework/tensorrt/tensorrt_opcode.cc | 6 +- .../framework/tensorrt/transform_tensorrt.cc | 668 +++++++++++------- .../test_msc/test_translate_tensorrt.py | 47 +- 9 files changed, 623 insertions(+), 279 deletions(-) create mode 100644 src/contrib/msc/core/transform/rewrite_utils.cc create mode 100644 src/contrib/msc/core/transform/rewrite_utils.h diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 63b4424524eb..cea021ade331 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -330,7 +330,7 @@ def _is_target_func(func): msc_mod = _partition_mod(mod) func_names = [var.name_hint for var, func in msc_mod.functions.items() if _is_target_func(func)] - if not trans_config.get("allow_incomplete", False): + if trans_config.get("as_complete", True): assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod) BYOCChecker().check(func_names, msc_mod[entry]) diff --git a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py index 8758fdb63079..4a02b02728de 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py +++ b/python/tvm/contrib/msc/framework/tensorrt/frontend/translate.py @@ -49,7 +49,10 @@ def transform_for_tensorrt( return tvm.transform.Sequential( [ msc_transform.SetExprName(), - trt_transform.TransformTensorRT(trans_config.get("version")), + 
trt_transform.TransformTensorRT( + version=trans_config.get("version"), + linear_to_conv=trans_config.get("linear_to_conv", False), + ), relax.transform.FoldConstant(), ] )(mod) diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py index 8eea3f7081a7..17aee690e370 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py @@ -136,12 +136,22 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] = None) -> bool: return True if isinstance(expr, relax.Tuple): return all(_check_expr(field) for field in expr.fields) - if any(i < 0 for i in expr.struct_info.shape.values): - return False - dtypes = dtypes or ("float32", "float16") - if expr.struct_info.dtype not in dtypes: - return False - return True + dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool") + + def _check(sinfo): + if not sinfo.shape or sinfo.dtype not in dtypes: + return False + unknown_dim = 0 + for s in sinfo.shape.values: + if isinstance(s, (tvm.tir.Var, tvm.tir.Any)): + unknown_dim += 1 + elif isinstance(s, tvm.tir.IntImm) and s < 0: + unknown_dim += 1 + return unknown_dim <= 1 + + if isinstance(expr.struct_info, relax.TupleStructInfo): + return all(_check(s) for s in expr.struct_info.fields) + return _check(expr.struct_info) def _basic_check(context: PatternCheckContext) -> bool: @@ -216,8 +226,7 @@ def _reshape_check(context: PatternCheckContext) -> bool: Whether the pattern is correct. 
""" - dtypes = ("float32", "float16", "int32") - if any(not _check_expr(context.annotated_expr[key], dtypes) for key in ["input_0", "out"]): + if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]): return False return True @@ -323,16 +332,18 @@ def get_patterns(target) -> List[Pattern]: "nn.avg_pool2d": ["input"], "nn.conv2d": ["input", "constant"], "nn.max_pool2d": ["input"], + "astype": ["input"], "concat": ["input"], "clip": ["input", "input", "input"], "image.resize2d": ["input", "input"], "matmul": ["input", "input"], "permute_dims": ["input"], - "strided_slice": ["input"], + "strided_slice": ["input", "input", "input", "input", "input"], + "topk": ["input"], } activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"] reduce_ops = ["max", "min", "mean", "sum"] - unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] + unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan"] elemwise_ops = [ "add", "divide", diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py index d6f15c43dacd..cf4d4b9f33ec 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/transform.py @@ -25,18 +25,25 @@ from tvm.contrib.msc.core import utils as msc_utils -def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass: +def TransformTensorRT( + version: List[int] = None, linear_to_conv: bool = False +) -> tvm.ir.transform.Pass: """Transform the Function to fit TensorRT. Parameters ---------- version: list The tensorrt version. 
+ linear_to_conv: bool + Whether to cast linear to conv2d Returns ------- ret: tvm.ir.transform.Pass """ - version = version or msc_utils.get_version(MSCFramework.TENSORRT) - return relax_api.TransformTensorRT(version) # type: ignore + config = { + "version": version or msc_utils.get_version(MSCFramework.TENSORRT), + "linear_to_conv": linear_to_conv, + } + return relax_api.TransformTensorRT(msc_utils.dump_dict(config)) # type: ignore diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc new file mode 100644 index 000000000000..20e4821e6fa7 --- /dev/null +++ b/src/contrib/msc/core/transform/rewrite_utils.cc @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/contrib/msc/core/transform/rewrite_utils.cc + */ +#include "rewrite_utils.h" + +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& expr) { + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); + return builder->Emit(expr, name); +} + +Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, + Attrs attrs) { + const auto& call = Call(op, args, attrs); + return ReEmit(builder, name, call); +} + +Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double value, + const DataType& dtype, size_t ndim) { + const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); + Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); + const auto& constant = Constant(data, NullOpt, span); + if (ndim == 0) { + return constant; + } + static const Op& reshape_op = Op::Get("relax.reshape"); + Array exp_shape(ndim, Integer(1)); + return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)}); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/transform/rewrite_utils.h b/src/contrib/msc/core/transform/rewrite_utils.h new file mode 100644 index 000000000000..2693a6ccd2eb --- /dev/null +++ b/src/contrib/msc/core/transform/rewrite_utils.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/rewrite_utils.h + * \brief Common utilities for rewrite. + */ +#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ +#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ + +#include +#include + +#include + +#include "../../../../relax/transform/utils.h" +#include "../../../../support/scalars.h" +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using Expr = tvm::RelayExpr; +using namespace tvm::relax; + +/*! + * \brief Utils for Layout. + */ +class RewriteUtils { + public: + /*! + * \brief Emit call with span name. + * \return The emitted var. + */ + TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const Expr& expr); + + /*! + * \brief Make and emit a call binding with span. + * \return The emitted var. + */ + TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr op, Array args, + Attrs attrs = Attrs()); + + /*! + * \brief Make and emit a (shaped)constant with span. + * \return The constant/reshape. 
+ */ + TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, double value, + const DataType& dtype, size_t ndim = 0); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_ diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index a080fdd77862..d90cdc35d17d 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -92,6 +92,8 @@ const String TensorRTOpCode::DType(const DataType& dtype) { dtype_enum = "DataType::kINT8"; } else if (dtype_name == "int32") { dtype_enum = "DataType::kINT32"; + } else if (dtype_name == "int64") { + dtype_enum = "DataType::kINT32"; } else if (dtype_name == "float16") { dtype_enum = "DataType::kHALF"; } else if (dtype_name == "float32") { @@ -267,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call() .op_input_arg() - .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode())) + .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(0) .op_dtype_arg(node()->OutputAt(0)->dtype); } @@ -661,7 +663,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - const String& symbol = node()->GetTypeAttr("is_asend") ? "MIN" : "MAX"; + const String& symbol = node()->GetTypeAttr("largest") ? "MAX" : "MIN"; stack_.op_call() .op_input_arg() .call_arg("TopKOperation::k" + symbol) diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 3f85309cd847..0f95f2d20622 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -22,83 +22,101 @@ * \brief Pass for transform the function to tensorrt. 
*/ +#include #include #include #include #include "../../../../relax/transform/utils.h" #include "../../../../support/scalars.h" +#include "../../core/transform/rewrite_utils.h" #include "../../core/utils.h" namespace tvm { namespace relax { using namespace tvm::contrib::msc; -const Array GetShape(const Expr& var) { - const auto& shape_opt = Downcast(GetStructInfo(var))->GetShape(); - ICHECK(shape_opt.defined()) << "Shape is not defined for " << var; - return shape_opt.value(); -} - -Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span, const String& suffix) { - const auto& name = SpanUtils::GetAttr(src_span, msc_attr::kName) + "_" + suffix; - expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); - return builder->Emit(expr, name); -} - -Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix, Expr op, - Array args, Attrs attrs = Attrs()) { - const auto& call = Call(op, args, attrs); - return EmitCall(builder, call, src_span, suffix); -} +struct TensorRTTransConfig { + // Whether to cast linear to conv + bool linear_to_conv{true}; + std::vector version{0, 0, 0}; + + void Load(dmlc::JSONReader* reader) { + std::string key; + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "linear_to_conv") { + reader->Read(&linear_to_conv); + } else if (key == "version") { + reader->Read(&version); + } else { + LOG(FATAL) << "Do not support key " << key; + } + } + } +}; -Expr MakeConstant(double value, const DataType& dtype, const String& name) { - const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); - const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, name); - return Constant(data, NullOpt, span); +const TensorRTTransConfig ParseConfig(const String& config_str) { + TensorRTTransConfig config; + if (config_str.size() > 0) { + std::istringstream is(config_str); + dmlc::JSONReader reader(&is); + reader.Read(&config); + } + return config; } using FRewriteTensorRT = 
runtime::TypedPackedFunc& new_calls, const Array& version)>; + const Map& new_calls, const String& config)>; + +const Array BroadcastShape(const Array& src_shape, + const Array& out_shape) { + size_t diff = out_shape.size() - src_shape.size(); + Array leading_shape, tailing_shape; + for (size_t i = 0; i < diff; i++) { + leading_shape.push_back(Integer(1)); + } + for (const auto& s : src_shape) { + tailing_shape.push_back(s); + leading_shape.push_back(s); + } + for (size_t i = 0; i < diff; i++) { + tailing_shape.push_back(Integer(1)); + } + if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) { + return tailing_shape; + } + ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) + << "Only support elemwise ops with leading or tailing expand"; + return leading_shape; +}; Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& shape_a = GetShape(call->args[0]); - const auto& shape_b = GetShape(call->args[1]); + const auto& shape_a = ExprUtils::GetShape(call->args[0]); + const auto& shape_b = ExprUtils::GetShape(call->args[1]); + const auto& shape_out = ExprUtils::GetShape(var); static const Op& reshape_op = Op::Get("relax.reshape"); if (shape_a.size() > shape_b.size()) { - Array exp_shape(shape_a.size(), Integer(1)); - if (shape_b.size() == 1) { - exp_shape.Set(shape_a.size() - 1, shape_b[0]); - } else if (shape_b.size() == 0) { - LOG_DEBUG << "Expand scalar argument to " << exp_shape; - } else { - LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_b; - } - const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); + const auto& exp_shape = BroadcastShape(shape_b, shape_out); + const auto& expand_b = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); - } - if (shape_a.size() < shape_b.size()) { - Array exp_shape(shape_b.size(), Integer(1)); - if (shape_a.size() == 1) { - exp_shape.Set(shape_b.size() - 1, shape_a[0]); - } else if (shape_a.size() == 0) { - LOG_DEBUG << "Expand scalar argument to " << exp_shape; - } else { - LOG_FATAL << "broadcast only support 1 dim and scalar, get " << shape_a; - } - const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); + } else if (shape_a.size() < shape_b.size()) { + const auto& exp_shape = BroadcastShape(shape_a, shape_out); + const auto& expand_a = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, + {call->args[0], ShapeExpr(exp_shape)}); return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); } return call; } Expr 
RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; if (new_calls.count(call->args[0]) && new_calls[call->args[0]]->op == Op::Get("relax.nn.conv1d")) { @@ -110,19 +128,20 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, if (conv2d->op != Op::Get("relax.nn.conv2d")) { return call; } - const auto& input_shape = GetShape(call->args[0]); - const auto& bias_shape = GetShape(call->args[1]); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& bias_shape = ExprUtils::GetShape(call->args[1]); const auto* conv_attrs = conv2d->attrs.as(); if (conv_attrs->data_layout == "NCHW") { // expand bias reshape Array exp_bias_shape{bias_shape[0], bias_shape[1], Integer(1), bias_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); - const auto& exp_bias = MakeCall(builder, call->span, "exp_bias", reshape_op, - {call->args[1], ShapeExpr(exp_bias_shape)}); + const auto& exp_bias = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_bias"), reshape_op, + {call->args[1], ShapeExpr(exp_bias_shape)}); // redirect to conv2d static const Op& add_op = Op::Get("relax.add"); - const auto& exp_add = - MakeCall(builder, call->span, "exp_add", add_op, {reshape->args[0], exp_bias}); + const auto& exp_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_add"), + add_op, {reshape->args[0], exp_bias}); // reduce output return Call(reshape_op, {exp_add, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span); @@ -130,48 +149,50 @@ Expr RewriteAdd(BlockBuilder builder, const Var& var, const Call& src_call, LOG_FATAL << "Unexpected data layout " << conv_attrs->data_layout; } } - return RewriteElemwise(builder, var, call, new_calls, version); + return RewriteElemwise(builder, var, call, new_calls, config); 
} Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& out_dtype = Downcast(GetStructInfo(var))->dtype; + const auto& out_dtype = ExprUtils::GetDataType(var); const auto* src_attrs = src_call->attrs.as(); - Expr raw_var; - if (src_attrs->keepdims) { - raw_var = EmitCall(builder, call, call->span, "raw"); - } else { - auto new_attrs = make_object(); - new_attrs->axis = src_attrs->axis; - new_attrs->keepdims = true; - raw_var = - MakeCall(builder, call->span, "keepdims", call->op, {call->args[0]}, Attrs(new_attrs)); + ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) + << "Unexpected out dtype " << out_dtype; + static const Op& topk_op = Op::Get("relax.topk"); + auto topk_attrs = make_object(); + topk_attrs->k = 1; + if (src_attrs->axis.defined()) { + topk_attrs->axis = src_attrs->axis.value()->value; } - static const Op& astype_op = Op::Get("relax.astype"); - auto cast_to_attrs = make_object(); - cast_to_attrs->dtype = DataType::Int(32); - Expr res = MakeCall(builder, call->span, "cast_to", astype_op, {raw_var}, Attrs(cast_to_attrs)); - // reshape back - if (!src_attrs->keepdims) { - const auto& output_shape = GetShape(var); - static const Op& reshape_op = Op::Get("relax.reshape"); - res = MakeCall(builder, call->span, "reshape", reshape_op, {res, ShapeExpr(output_shape)}); + topk_attrs->largest = call->op == Op::Get("relax.argmax"); + topk_attrs->ret_type = "both"; + topk_attrs->dtype = out_dtype; + // change to topk + const auto& topk = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "topk"), topk_op, + {call->args[0]}, Attrs(topk_attrs)); + const auto& get_name = ExprUtils::GetSpanName(call, ".1"); + const auto& get_item = + TupleGetItem(topk, 1, SpanUtils::CreateWithAttr(msc_attr::kName, get_name)); + if 
(src_attrs->keepdims) { + return get_item; } - auto cast_from_attrs = make_object(); - cast_from_attrs->dtype = out_dtype; - return Call(astype_op, {res}, Attrs(cast_from_attrs), call->sinfo_args, call->span); + const auto& get_item_var = builder->Emit(get_item, get_name); + static const Op& reshape_op = Op::Get("relax.reshape"); + const auto& output_shape = ExprUtils::GetShape(var); + return Call(reshape_op, {get_item_var, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, + call->span); } Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define dims - const auto& in_q_shape = GetShape(call->args[0]); - const auto& in_v_shape = GetShape(call->args[2]); + const auto& in_q_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_v_shape = ExprUtils::GetShape(call->args[2]); const auto& batch_size = in_q_shape[0]; const auto& seq_len = in_q_shape[1]; const auto& num_head = in_q_shape[2]; @@ -198,50 +219,53 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call auto permute_attrs = make_object(); Array axes{Integer(0), Integer(2), Integer(1), Integer(3)}; permute_attrs->axes = axes; - const auto& q_trans = MakeCall(builder, call->span, "q_trans", permute_dims_op, {call->args[0]}, - Attrs(permute_attrs)); - const auto& k_trans = MakeCall(builder, call->span, "k_trans", permute_dims_op, {call->args[1]}, - Attrs(permute_attrs)); - const auto& v_trans = MakeCall(builder, call->span, "v_trans", permute_dims_op, {call->args[2]}, - Attrs(permute_attrs)); + const auto& q_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), 
permute_dims_op, + {call->args[0]}, Attrs(permute_attrs)); + const auto& k_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), permute_dims_op, + {call->args[1]}, Attrs(permute_attrs)); + const auto& v_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op, + {call->args[2]}, Attrs(permute_attrs)); Array q_shape({batch_size * num_head, seq_len, head_dim}); - const auto& q_reshape = - MakeCall(builder, call->span, "q_reshape", reshape_op, {q_trans, ShapeExpr(q_shape)}); + const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"), + reshape_op, {q_trans, ShapeExpr(q_shape)}); Array k_shape({batch_size * num_head, seq_len_kv, head_dim}); - const auto& k_reshape = - MakeCall(builder, call->span, "k_reshape", reshape_op, {k_trans, ShapeExpr(k_shape)}); + const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"), + reshape_op, {k_trans, ShapeExpr(k_shape)}); Array v_shape({batch_size * num_head, seq_len_kv, head_dim_v}); - const auto& v_reshape = - MakeCall(builder, call->span, "v_reshape", reshape_op, {v_trans, ShapeExpr(v_shape)}); + const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"), + reshape_op, {v_trans, ShapeExpr(v_shape)}); auto reduce_permute_attrs = make_object(); Array v_axes{Integer(0), Integer(2), Integer(1)}; reduce_permute_attrs->axes = v_axes; // transpose for batch_matmul - const auto& k_reshape_trans = MakeCall(builder, call->span, "k_reshape_trans", permute_dims_op, - {k_reshape}, Attrs(reduce_permute_attrs)); + const auto& k_reshape_trans = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape_trans"), + permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs)); // calculate product auto matmul_attrs = make_object(); matmul_attrs->out_dtype = in_dtype; - const auto& qk_prod = MakeCall(builder, call->span, "qk_prod", matmul_op, - 
{q_reshape, k_reshape_trans}, Attrs(matmul_attrs)); + const auto& qk_prod = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op, + {q_reshape, k_reshape_trans}, Attrs(matmul_attrs)); Expr p_scale; if (src_attrs->scale.defined()) { - const auto& scale = MakeConstant(static_cast(src_attrs->scale.value()->value), in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); - Array exp_shape(3, Integer(1)); - const auto& exp_scale = - MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); - p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod, exp_scale}); + double value = static_cast(src_attrs->scale.value()->value); + const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), + value, in_dtype, 3); + p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), multiply_op, + {qk_prod, scale}); } else { - const auto& scale = - MakeConstant(static_cast(Downcast(head_dim)->value), in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); - Array exp_shape(3, Integer(1)); - const auto& exp_scale = - MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); - const auto& sqrt_scale = MakeCall(builder, call->span, "sqrt_scale", sqrt_op, {exp_scale}); - p_scale = MakeCall(builder, call->span, "p_scale", divide_op, {qk_prod, sqrt_scale}); + double value = static_cast(Downcast(head_dim)->value); + const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"), + value, in_dtype, 3); + const auto& sqrt_scale = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale}); + p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), divide_op, + {qk_prod, sqrt_scale}); } // bias @@ -249,12 +273,12 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call if 
(call->args.size() == 4) { Array exp_shape{batch_size, num_head, seq_len, seq_len_kv}; Array reduce_shape{batch_size * num_head, seq_len, seq_len_kv}; - const auto& prod_exp = - MakeCall(builder, call->span, "prod_exp", reshape_op, {prod, ShapeExpr(exp_shape)}); - const auto& prod_add = - MakeCall(builder, call->span, "prod_add", add_op, {prod_exp, call->args[3]}); - prod = MakeCall(builder, call->span, "prod_reduce", reshape_op, - {prod_add, ShapeExpr(reduce_shape)}); + const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"), + reshape_op, {prod, ShapeExpr(exp_shape)}); + const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"), + add_op, {prod_exp, call->args[3]}); + prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_reduce"), reshape_op, + {prod_add, ShapeExpr(reduce_shape)}); } // causal_mask @@ -262,7 +286,8 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call if (!src_attrs->causal_mask.defined()) { auto softmax_attrs = make_object(); softmax_attrs->axis = 2; - s_value = MakeCall(builder, call->span, "act", softmax_op, {prod}, Attrs(softmax_attrs)); + s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op, + {prod}, Attrs(softmax_attrs)); } else { const auto& causal_mask = src_attrs->causal_mask.value(); PrimValue tril_k; @@ -273,41 +298,47 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call } else { LOG_FATAL << "Unexpected causal_mask " << causal_mask; } - const auto& p_masked = MakeCall(builder, call->span, "p_masked", tril_op, {prod, tril_k}); + const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"), + tril_op, {prod, tril_k}); auto reduce_attrs = make_object(); Array axis{Integer(2)}; reduce_attrs->axis = axis; reduce_attrs->keepdims = true; - const auto& p_max = MakeCall(builder, call->span, "p_max", max_op, {prod}, 
Attrs(reduce_attrs)); - const auto& p_diff = MakeCall(builder, call->span, "p_diff", subtract_op, {p_masked, p_max}); - const auto& p_exp = MakeCall(builder, call->span, "p_exp", exp_op, {p_diff}); - const auto& p_masked_exp = - MakeCall(builder, call->span, "p_masked_exp", tril_op, {p_exp, tril_k}); + const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"), + max_op, {prod}, Attrs(reduce_attrs)); + const auto& p_diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_diff"), + subtract_op, {p_masked, p_max}); + const auto& p_exp = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), exp_op, {p_diff}); + const auto& p_masked_exp = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, {p_exp, tril_k}); const auto& p_masked_sum = - MakeCall(builder, call->span, "p_masked_sum", sum_op, {p_masked_exp}, Attrs(reduce_attrs)); - s_value = MakeCall(builder, call->span, "act", divide_op, {p_masked_exp, p_masked_sum}); + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked_sum"), sum_op, + {p_masked_exp}, Attrs(reduce_attrs)); + s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), divide_op, + {p_masked_exp, p_masked_sum}); } // final calculation - const auto& o_prod = - MakeCall(builder, call->span, "o_prod", matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); + const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"), + matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs)); Array o_shape{batch_size, num_head, seq_len, head_dim_v}; return Call(reshape_op, {o_prod, ShapeExpr(o_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); // define expand shape Array exp_shape(input_shape.size(), Integer(1)); exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -318,36 +349,43 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call static const Op& subtract_op = Op::Get("relax.subtract"); // scale factor: gamma/sqrt(var + epsilon) - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {call->args[4], eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); - const auto& scale_factor = - MakeCall(builder, call->span, "scale_factor", divide_op, {call->args[1], sqrt}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {call->args[4], eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); + const auto& scale_factor = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "scale_factor"), divide_op, {call->args[1], sqrt}); Expr res = call->args[0]; // scale if (src_attrs->scale) { - const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", reshape_op, - {scale_factor, ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_scale}); + const auto& exp_scale = + 
RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_scale"), reshape_op, + {scale_factor, ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, + {res, exp_scale}); } // offset if (src_attrs->center) { // offset factor: beta-mean*scale_factor - const auto& average = - MakeCall(builder, call->span, "average", multiply_op, {call->args[3], scale_factor}); + const auto& average = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "average"), + multiply_op, {call->args[3], scale_factor}); const auto& offset_factor = - MakeCall(builder, call->span, "offset_factor", subtract_op, {call->args[2], average}); - const auto& exp_offset = MakeCall(builder, call->span, "exp_offset", reshape_op, - {offset_factor, ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "offset", add_op, {res, exp_offset}); + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset_factor"), subtract_op, + {call->args[2], average}); + const auto& exp_offset = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_offset"), reshape_op, + {offset_factor, ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, + {res, exp_offset}); } return Tuple(Array{res}, call->span); } Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& output_shape = GetShape(var); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& output_shape = ExprUtils::GetShape(var); Expr concat_input = call->args[0]; static const Op& concat_op = Op::Get("relax.concat"); for (size_t i = 0; i < input_shape.size(); i++) { @@ -357,30 +395,33 @@ Expr RewriteBroadcastTo(BlockBuilder builder, const Var& var, const Call& src_ca Array concat_inputs(out_dim / in_dim, concat_input); auto concat_attrs = make_object(); concat_attrs->axis = Integer(i); - concat_input = MakeCall(builder, call->span, "concat_" + std::to_string(i), concat_op, - {Tuple(concat_inputs)}, Attrs(concat_attrs)); + concat_input = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "concat_" + std::to_string(i)), concat_op, + {Tuple(concat_inputs)}, Attrs(concat_attrs)); } } return concat_input; } Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto* src_attrs = src_call->attrs.as(); - const auto& input_shape = GetShape(call->args[0]); - const auto& weight_shape = GetShape(call->args[1]); - const auto& output_shape = GetShape(var); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& weight_shape = ExprUtils::GetShape(call->args[1]); + const auto& output_shape = ExprUtils::GetShape(var); if (src_attrs->data_layout == "NCW") { Array new_args; // expand inputs Array exp_input_shape{input_shape[0], input_shape[1], Integer(1), input_shape[2]}; Array exp_weight_shape{weight_shape[0], weight_shape[1], Integer(1), weight_shape[2]}; static const Op& reshape_op = Op::Get("relax.reshape"); - new_args.push_back(MakeCall(builder, call->span, "exp_input", reshape_op, - {call->args[0], ShapeExpr(exp_input_shape)})); - new_args.push_back(MakeCall(builder, call->span, "exp_weight", reshape_op, - {call->args[1], ShapeExpr(exp_weight_shape)})); + new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_input"), + reshape_op, + {call->args[0], ShapeExpr(exp_input_shape)})); + new_args.push_back(RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), + reshape_op, + {call->args[1], ShapeExpr(exp_weight_shape)})); // change to conv2d static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); auto conv_attrs = make_object(); @@ -393,8 +434,8 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, conv_attrs->kernel_layout = "OIHW"; conv_attrs->out_layout = "NCHW"; conv_attrs->out_dtype = src_attrs->out_dtype; - const auto& conv2d = - MakeCall(builder, call->span, "exp", conv2d_op, new_args, Attrs(conv_attrs)); + const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp"), + conv2d_op, new_args, Attrs(conv_attrs)); // reduce output return Call(reshape_op, {conv2d, ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, call->span); @@ -404,11 +445,80 @@ Expr 
RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call, return call; } +Expr RewriteGelu(BlockBuilder builder, const Var& var, const Call& src_call, + const Map& new_calls, const String& config) { + // 0.5 * x * (1 + erf(sqrt(0.5) * x)) + const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; + size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); + // create ops + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& erf_op = Op::Get("relax.erf"); + + const auto& factor = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "factor"), + std::sqrt(0.5), in_dtype, in_dim); + const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), + multiply_op, {factor, call->args[0]}); + const auto& erf = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "erf"), erf_op, {mul}); + const auto& one = + RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); + const auto& add = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, erf}); + const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), + multiply_op, {call->args[0], add}); + const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, + in_dtype, in_dim); + return Call(multiply_op, {half, mul2}, Attrs(), call->sinfo_args, call->span); +} + +Expr RewriteGeluTanh(BlockBuilder builder, const Var& var, const Call& src_call, + const Map& new_calls, const String& config) { + // 0.5 * x * (1 + tanh(sqrt(2/pi) * (0.044715F * pow(x, 3) + x))) + const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; + size_t in_dim = ExprUtils::GetShape(call->args[0]).size(); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); + + // create ops + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& pow_op = Op::Get("relax.power"); + static const Op& tanh_op = Op::Get("relax.tanh"); + + const auto& pow_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "pow_factor"), 3, in_dtype, in_dim); + const auto& mul_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "mul_factor"), 0.044715, in_dtype, in_dim); + const auto& pi_factor = RewriteUtils::MakeConstant( + builder, ExprUtils::GetSpanName(call, "pi_factor"), std::sqrt(2 / M_PI), in_dtype, in_dim); + + const auto& pow = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "pow"), pow_op, + {call->args[0], pow_factor}); + const auto& mul = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul"), + multiply_op, {mul_factor, pow}); + const auto& add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, + {mul, call->args[0]}); + const auto& mul2 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul2"), + multiply_op, {pi_factor, add}); + const auto& tanh = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "tanh"), tanh_op, {mul2}); + const auto& one = + RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 1, in_dtype, in_dim); + const auto& add2 = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "add"), add_op, {one, tanh}); + const auto& mul3 = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mul3"), + multiply_op, {call->args[0], add2}); + const auto& half = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "one"), 0.5, + in_dtype, in_dim); + return Call(multiply_op, {half, mul3}, Attrs(), call->sinfo_args, call->span); +} + Expr 
RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); Array group_shape = input_shape; Array exp_shape(input_shape.size(), Integer(1)); @@ -420,8 +530,8 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(axis, Integer(src_attrs->num_groups)); // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -434,53 +544,63 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call static const Op& subtract_op = Op::Get("relax.subtract"); // reshape input - const auto& reshape_in = MakeCall(builder, call->span, "reshape_in", reshape_op, - {call->args[0], ShapeExpr(group_shape)}); + const auto& reshape_in = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "reshape_in"), reshape_op, + {call->args[0], ShapeExpr(group_shape)}); // mean(input) auto mean_attrs = make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; - const auto& mean = - MakeCall(builder, call->span, "mean", mean_op, {reshape_in}, Attrs(mean_attrs)); + const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, + {reshape_in}, Attrs(mean_attrs)); // variance: mean((input-mean)*(input-mean)) - const auto& 
diff = MakeCall(builder, call->span, "diff", subtract_op, {reshape_in, mean}); - const auto& square = MakeCall(builder, call->span, "square", square_op, {diff}); - const auto& variance = - MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs)); + const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), + subtract_op, {reshape_in, mean}); + const auto& square = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); + const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), + mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = - MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); + const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), + reshape_op, {eps, ShapeExpr(exp_eps_shape)}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {variance, exp_eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); // diff/sqrt - Expr res = MakeCall(builder, call->span, "divide", divide_op, {diff, sqrt}); + Expr res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "divide"), divide_op, + {diff, sqrt}); // scale if (src_attrs->scale) { - const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "scale", multiply_op, {res, exp_gamma}); + const auto& exp_gamma = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, + {call->args[1], 
ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "scale"), multiply_op, + {res, exp_gamma}); } // offset if (src_attrs->center) { - const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - res = MakeCall(builder, call->span, "offset", add_op, {res, exp_beta}); + const auto& exp_beta = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, + {call->args[2], ShapeExpr(exp_shape)}); + res = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "offset"), add_op, + {res, exp_beta}); } // reshape output return Call(reshape_op, {res, ShapeExpr(input_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); const auto* src_attrs = src_call->attrs.as(); Array exp_shape(input_shape.size(), Integer(1)); for (const auto& a : src_attrs->axes) { @@ -488,8 +608,8 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(index, input_shape[index]); } // create eps constant - const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, - SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); + const auto& eps = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), + src_attrs->epsilon, in_dtype); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -505,30 +625,36 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call auto mean_attrs = 
make_object(); mean_attrs->axis = src_attrs->axes; mean_attrs->keepdims = true; - const auto& mean = - MakeCall(builder, call->span, "mean", mean_op, {call->args[0]}, Attrs(mean_attrs)); + const auto& mean = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "mean"), mean_op, + {call->args[0]}, Attrs(mean_attrs)); // variance: mean((input-mean)*(input-mean)) - const auto& diff = MakeCall(builder, call->span, "diff", subtract_op, {call->args[0], mean}); - const auto& square = MakeCall(builder, call->span, "square", square_op, {diff}); - const auto& variance = - MakeCall(builder, call->span, "variance", mean_op, {square}, Attrs(mean_attrs)); + const auto& diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "diff"), + subtract_op, {call->args[0], mean}); + const auto& square = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "square"), square_op, {diff}); + const auto& variance = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "variance"), + mean_op, {square}, Attrs(mean_attrs)); // sqrt(var + epsilon) Array exp_eps_shape(input_shape.size(), Integer(1)); - const auto& exp_eps = - MakeCall(builder, call->span, "exp_eps", reshape_op, {eps, ShapeExpr(exp_eps_shape)}); - const auto& eps_add = MakeCall(builder, call->span, "eps_add", add_op, {variance, exp_eps}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {eps_add}); + const auto& exp_eps = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_eps"), + reshape_op, {eps, ShapeExpr(exp_eps_shape)}); + const auto& eps_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "eps_add"), + add_op, {variance, exp_eps}); + const auto& sqrt = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, {eps_add}); // diff/sqrt Call res = Call(divide_op, {diff, sqrt}, Attrs(), call->sinfo_args, call->span); // scale if (src_attrs->scale) { - const auto& exp_gamma = MakeCall(builder, call->span, "exp_gamma", 
reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); - const auto& res_var = EmitCall(builder, res, call->span, "pre_scale"); + const auto& exp_gamma = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_gamma"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); + const auto& res_var = + RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_scale"), res); if (src_attrs->center) { res = Call(multiply_op, {res_var, exp_gamma}); } else { @@ -537,87 +663,126 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call } // offset if (src_attrs->center) { - const auto& exp_beta = MakeCall(builder, call->span, "exp_beta", reshape_op, - {call->args[2], ShapeExpr(exp_shape)}); - const auto& res_var = EmitCall(builder, res, call->span, "pre_offset"); + const auto& exp_beta = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_beta"), reshape_op, + {call->args[2], ShapeExpr(exp_shape)}); + const auto& res_var = + RewriteUtils::ReEmit(builder, ExprUtils::GetSpanName(call, "pre_offset"), res); res = Call(add_op, {res_var, exp_beta}, Attrs(), call->sinfo_args, call->span); } return res; } Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { + const auto& trt_config = ParseConfig(config); const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& shape_a = GetShape(call->args[0]); - const auto& shape_b = GetShape(call->args[1]); + const auto& shape_a = ExprUtils::GetShape(call->args[0]); + const auto& shape_b = ExprUtils::GetShape(call->args[1]); static const Op& reshape_op = Op::Get("relax.reshape"); + if (call->args[1]->IsInstance() && shape_b.size() == 2 && + trt_config.linear_to_conv) { + const auto& out_shape = ExprUtils::GetShape(var); + PrimExpr accumulate = ArrayUtils::Accumulate(shape_a, shape_a.size() - 1); + Array exp_shape{accumulate, shape_a[shape_a.size() - 1], Integer(1), Integer(1)}; + const auto& exp_in = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_in"), + reshape_op, {call->args[0], ShapeExpr(exp_shape)}); + // transpose and expand weight to OIHW + static const Op& permute_dims_op = Op::Get("relax.permute_dims"); + auto permute_attrs = make_object(); + Array axes{Integer(1), Integer(0)}; + permute_attrs->axes = axes; + const auto& trans_weight = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "trans_weight"), + permute_dims_op, {call->args[1]}, Attrs(permute_attrs)); + Array weight_shape{shape_b[1], shape_b[0], Integer(1), Integer(1)}; + const auto& exp_weight = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "exp_weight"), reshape_op, + {trans_weight, ShapeExpr(weight_shape)}); + // to conv2d + static const Op& conv2d_op = Op::Get("relax.nn.conv2d"); + auto conv_attrs = make_object(); + conv_attrs->strides = Array{Integer(1), Integer(1)}; + conv_attrs->padding = Array{Integer(0), Integer(0), Integer(0), Integer(0)}; + conv_attrs->dilation = Array{Integer(1), Integer(1)}; + conv_attrs->groups = 1; + conv_attrs->data_layout = "NCHW"; + conv_attrs->kernel_layout = "OIHW"; + conv_attrs->out_layout = "NCHW"; + conv_attrs->out_dtype = ExprUtils::GetDataType(var); + const auto& conv2d = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "conv2d"), + conv2d_op, {exp_in, exp_weight}, 
Attrs(conv_attrs)); + return Call(reshape_op, {conv2d, ShapeExpr(out_shape)}, Attrs(), call->sinfo_args, call->span); + } if (shape_a.size() > shape_b.size()) { Array exp_shape(shape_a.size(), Integer(1)); - for (size_t i = shape_b.size(); i < shape_a.size(); i++) { - exp_shape.Set(i, shape_b[i - shape_b.size()]); + size_t diff = shape_a.size() - shape_b.size(); + for (size_t i = diff; i < shape_a.size(); i++) { + exp_shape.Set(i, shape_b[i - diff]); } - const auto& expand_b = MakeCall(builder, call->span, "expand_b", reshape_op, - {call->args[1], ShapeExpr(exp_shape)}); + const auto& expand_b = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_b"), reshape_op, + {call->args[1], ShapeExpr(exp_shape)}); return Call(call->op, {call->args[0], expand_b}, call->attrs, call->sinfo_args, call->span); } if (shape_a.size() < shape_b.size()) { Array exp_shape(shape_b.size(), Integer(1)); - for (size_t i = shape_a.size(); i < shape_b.size(); i++) { - exp_shape.Set(i, shape_a[i - shape_a.size()]); + size_t diff = shape_b.size() - shape_a.size(); + for (size_t i = diff; i < shape_b.size(); i++) { + exp_shape.Set(i, shape_a[i - diff]); } - const auto& expand_a = MakeCall(builder, call->span, "expand_a", reshape_op, - {call->args[0], ShapeExpr(exp_shape)}); + const auto& expand_a = + RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "expand_a"), reshape_op, + {call->args[0], ShapeExpr(exp_shape)}); return Call(call->op, {expand_a, call->args[1]}, call->attrs, call->sinfo_args, call->span); } return call; } Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); - const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; - Array exp_shape(input_shape.size(), Integer(1)); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); + const auto& in_dtype = ExprUtils::GetDataType(call->args[0]); // create 1 constant - const auto& one = - MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, msc_attr::kName) + "_one"); + const auto& one = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "eps"), 1, + in_dtype, input_shape.size()); // create ops - static const Op& reshape_op = Op::Get("relax.reshape"); static const Op& divide_op = Op::Get("relax.divide"); static const Op& sqrt_op = Op::Get("relax.sqrt"); // expand and divide - const auto& exp_one = - MakeCall(builder, call->span, "exp_one", reshape_op, {one, ShapeExpr(exp_shape)}); - const auto& sqrt = MakeCall(builder, call->span, "sqrt", sqrt_op, {call->args[0]}); - return Call(divide_op, {exp_one, sqrt}, Attrs(), call->sinfo_args, call->span); + const auto& sqrt = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sqrt"), sqrt_op, + {call->args[0]}); + return Call(divide_op, {one, sqrt}, Attrs(), call->sinfo_args, call->span); } Expr RewriteSilu(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; // create ops static const Op& multiply_op = Op::Get("relax.multiply"); static const Op& sigmoid_op = Op::Get("relax.sigmoid"); // silu=input*sigmoid(input) - const auto& sigmoid = MakeCall(builder, call->span, "sigmoid", sigmoid_op, {call->args[0]}); + const auto& sigmoid = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "sigmoid"), + sigmoid_op, {call->args[0]}); return Call(multiply_op, {call->args[0], sigmoid}, Attrs(), call->sinfo_args, call->span); } Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; - const auto& output_shape = GetShape(var); + const auto& output_shape = ExprUtils::GetShape(var); static const Op& reshape_op = Op::Get("relax.reshape"); return Call(reshape_op, {call->args[0], ShapeExpr(output_shape)}, Attrs(), call->sinfo_args, call->span); } Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, - const Map& new_calls, const Array& version) { + const Map& new_calls, const String& config) { const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; - const auto& input_shape = GetShape(call->args[0]); + const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); size_t axis = CommonUtils::GetIndex(src_attrs->axis, input_shape.size()); std::vector split_begins, split_ends; @@ -646,9 +811,16 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, // create strided_slices Array outputs; for (size_t i = 0; i < split_begins.size(); i++) { - auto slice = strided_slice(call->args[0], Tuple(Array{PrimValue(Integer(axis))}), - Tuple(Array{PrimValue(Integer(split_begins[i]))}), - Tuple(Array{PrimValue(Integer(split_ends[i]))})); + static const Op& strided_slice_op = Op::Get("relax.strided_slice"); + const auto& axes = Tuple(Array{PrimValue(IntImm(DataType::Int(64), axis))}); + const auto& begin = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_begins[i]))}); + const auto& end = Tuple(Array{PrimValue(IntImm(DataType::Int(64), split_ends[i]))}); + const auto& strides = Tuple(Array{PrimValue(IntImm(DataType::Int(64), 1))}); + auto attrs = make_object(); + attrs->assume_inbound = true; + const auto& slice = RewriteUtils::MakeCall( + builder, ExprUtils::GetSpanName(call, "slice_" + std::to_string(i)), strided_slice_op, + {call->args[0], axes, begin, end, strides}, Attrs(attrs)); outputs.push_back(slice); } return Tuple(outputs, call->span); @@ -664,6 +836,9 @@ TVM_REGISTER_OP("relax.nn.batch_norm") TVM_REGISTER_OP("relax.nn.conv1d").set_attr("FRewriteTensorRT", RewriteConv1d); TVM_REGISTER_OP("relax.nn.group_norm") .set_attr("FRewriteTensorRT", RewriteGroupNorm); +TVM_REGISTER_OP("relax.nn.gelu").set_attr("FRewriteTensorRT", RewriteGelu); +TVM_REGISTER_OP("relax.nn.gelu_tanh") + .set_attr("FRewriteTensorRT", RewriteGeluTanh); TVM_REGISTER_OP("relax.nn.layer_norm") .set_attr("FRewriteTensorRT", RewriteLayerNorm); TVM_REGISTER_OP("relax.nn.silu").set_attr("FRewriteTensorRT", RewriteSilu); @@ -695,9 +870,9 @@ 
TVM_REGISTER_OP("relax.split").set_attr("FRewriteTensorRT", Re class TensorRTTransformer : public ExprMutator { public: - explicit TensorRTTransformer(IRModule ctx_module, const Array& version) + explicit TensorRTTransformer(IRModule ctx_module, const String& config) : ExprMutator(ctx_module) { - version_ = version; + config_ = config; } void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { @@ -707,7 +882,7 @@ class TensorRTTransformer : public ExprMutator { if (rewrite_map.count(op)) { const auto& call = GetRef(call_node); FRewriteTensorRT f = rewrite_map[op]; - const auto& new_call = f(builder_, binding->var, call, new_calls_, version_); + const auto& new_call = f(builder_, binding->var, call, new_calls_, config_); if (new_call != call) { ReEmitBinding(binding, builder_->Normalize(new_call)); new_calls_.Set(binding->var, call); @@ -721,20 +896,19 @@ class TensorRTTransformer : public ExprMutator { private: Map new_calls_; - Array version_; + String config_; }; -Function TransformTensorRT(const Function& func, const IRModule& module, - const Array& version) { - return Downcast(TensorRTTransformer(module, version).VisitExpr(func)); +Function TransformTensorRT(const Function& func, const IRModule& module, const String& config) { + return Downcast(TensorRTTransformer(module, config).VisitExpr(func)); } namespace transform { -Pass TransformTensorRT(const Array& version) { +Pass TransformTensorRT(const String& config) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return relax::TransformTensorRT(f, m, version); + return relax::TransformTensorRT(f, m, config); }; return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 74c25ceacfe8..7c8c2830995c 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ 
b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -87,7 +87,7 @@ def _is_target_func(func): NameChecker().check(func) -def verify_model(torch_model, input_info, allow_incomplete=False): +def verify_model(torch_model, input_info, **trans_config): """Build model and verify results""" graph_model = fx.symbolic_trace(torch_model) @@ -100,9 +100,7 @@ def verify_model(torch_model, input_info, allow_incomplete=False): golden = [golden] golden = [g.detach().cpu().numpy() for g in golden] # partition module for tensorrt - mod, graphs, weights = translate.partition_for_tensorrt( - mod, trans_config={"allow_incomplete": allow_incomplete} - ) + mod, graphs, weights = translate.partition_for_tensorrt(mod, trans_config=trans_config) check_names(mod) output_folder = msc_utils.msc_dir() # tranalte to tensorrt @@ -191,6 +189,8 @@ def forward(self, x, y): input_info = [([1, 3, 10, 10], "float32")] verify_model(Dense1(), input_info) verify_model(Dense2(), input_info) + verify_model(Dense1(), input_info, linear_to_conv=True) + verify_model(Dense2(), input_info, linear_to_conv=True) verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) @@ -368,10 +368,10 @@ def __init__(self): self.embedding = torch.nn.Embedding(10, 3) def forward(self, data): - return self.embedding(data) + return self.embedding(data.to(torch.int64)) - verify_model(Embedding(), [([4], "int64")], allow_incomplete=True) - verify_model(Embedding(), [([4, 5], "int64")], allow_incomplete=True) + verify_model(Embedding(), [([4], "int32")]) + verify_model(Embedding(), [([4, 5], "int32")]) @requires_tensorrt @@ -801,14 +801,14 @@ def test_argmax(): class Argmax1(Module): def forward(self, data): - return torch.argmax(data, dim=-1) + return torch.argmax(data, dim=-1).to(torch.int32) class Argmax2(Module): def forward(self, data): - return torch.argmax(data, dim=-1, keepdim=True) + return torch.argmax(data, dim=-1, keepdim=True).to(torch.int32) - verify_model(Argmax1(), [([256, 256], "float32")], 
allow_incomplete=True) - verify_model(Argmax2(), [([256, 256], "float32")], allow_incomplete=True) + verify_model(Argmax1(), [([256, 256], "float32")]) + verify_model(Argmax2(), [([256, 256], "float32")]) @requires_tensorrt @@ -817,14 +817,14 @@ def test_argmin(): class Argmin1(Module): def forward(self, data): - return torch.argmin(data, dim=-1) + return torch.argmin(data, dim=-1).to(torch.int32) class Argmin2(Module): def forward(self, data): - return torch.argmin(data, dim=-1, keepdim=True) + return torch.argmin(data, dim=-1, keepdim=True).to(torch.int32) - verify_model(Argmin1(), [([256, 256], "float32")], allow_incomplete=True) - verify_model(Argmin2(), [([256, 256], "float32")], allow_incomplete=True) + verify_model(Argmin1(), [([256, 256], "float32")]) + verify_model(Argmin2(), [([256, 256], "float32")]) @requires_tensorrt @@ -876,5 +876,22 @@ def forward(self, x, y): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) +@requires_tensorrt +def test_gelu(): + """test tensorrt translator for gelu""" + + class Gelu1(Module): + def forward(self, data): + return torch.nn.functional.gelu(data) + + class Gelu2(Module): + def forward(self, data): + return torch.nn.functional.gelu(data, approximate="tanh") + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(Gelu1(), input_info) + verify_model(Gelu2(), input_info) + + if __name__ == "__main__": tvm.testing.main() From 40989080b6df069c05c35c0b8d930051524a6133 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Sun, 8 Sep 2024 07:28:53 +0800 Subject: [PATCH 2/2] format fix --- src/contrib/msc/core/utils.cc | 19 ++++++++++++++++--- src/contrib/msc/core/utils.h | 4 +++- .../framework/tensorrt/transform_tensorrt.cc | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index c6e74d42843d..1e846b0b3a61 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -507,12 +507,25 @@ const String 
ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) { - const auto& shape_opt = Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape(); - ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr; +const Array<PrimExpr> ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) { + const auto& shape_opt = sinfo->GetShape(); + if (!shape_opt.defined()) { + return Array<PrimExpr>(); + } + if (as_int) { + Array<PrimExpr> shape; + for (const auto& s : shape_opt.value()) { + shape.push_back(s->IsInstance<IntImmNode>() ? s : Integer(-1)); + } + return shape; + } return shape_opt.value(); } +const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr, bool as_int) { + return GetShape(Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr)), as_int); +} + const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype; } diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index d7758cc23d8b..7fb9c87a99f9 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -398,7 +398,9 @@ class ExprUtils { * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr); + TVM_DLL static const Array<PrimExpr> GetShape(const relax::TensorStructInfo& sinfo, + bool as_int = true); + TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr, bool as_int = true); /*! * \brief Get dtype of expr. 
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 0f95f2d20622..542e15d06c3c 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -90,7 +90,7 @@ const Array<PrimExpr> BroadcastShape(const Array<PrimExpr>& src_shape, ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) << "Only support elemwise ops with leading or tailing expand"; return leading_shape; -}; +} Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, const Map<Expr, Call>& new_calls, const String& config) {