diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index a36b66c8f0dd..ab36b8bddd02 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -266,7 +266,7 @@ def _func_wrapper(expr): def reduce_annotate_fn(attrs, args, op_name): """Helper for reduce operations.""" - if not attrs.axis or len(attrs.axis) == 0: + if get_tensorrt_use_implicit_batch_mode() and (not attrs.axis or len(attrs.axis) == 0): logger.info("%s: cannot reduce to scalar.", op_name) return False if attrs.exclude: @@ -317,10 +317,9 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable for arg in args ] - # RelayVM + TRT doesn't support scalar addition yet. - for shape in shapes: - if len(shape) < 1: - return False + # Scalars require explicit batch mode. + if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]): + return False if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") @@ -328,6 +327,8 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable if ( not get_tensorrt_use_implicit_batch_mode() and (isinstance(args[0], Constant) or isinstance(args[1], Constant)) + and len(shapes[0]) > 0 + and len(shapes[1]) > 0 and shapes[0][0] == shapes[1][0] and shapes[0][0] != 1 and (len(shapes[0]) > 3 or len(shapes[1]) > 3) @@ -552,6 +553,19 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable return True +@_register_external_dynamic_check_func("split") +def split_annotate_fn(expr): + """Check if split is supported by TensorRT.""" + + if any([x.checked_type.dtype != "float32" for x in expr.args]): + logger.info("Only float32 inputs are supported for TensorRT.") + return False + if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0: + logger.info("split: can't modify batch dimension.") + return False + return True + + @_register_external_dynamic_check_func("nn.conv2d_transpose") def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv2d_transpose is supported by TensorRT.""" @@ -870,6 +884,11 @@ def visit_call(self, call): "nn.conv3d_transpose", "nn.dense", "nn.batch_matmul", + "sum", + "prod", + "max", + "min", + "mean", ] ) if isinstance(call.op, tvm.tir.op.Op): @@ -968,6 +987,7 @@ def visit_call(self, call): # Create new pruned module new_mod = tvm.IRModule(mod.functions, mod.type_definitions) new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"]) + new_mod = transform.RemoveUnusedFunctions()(new_mod) return new_mod diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index e121b6010ad8..d83a9003229c 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -99,6 +99,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer { SetPadNodeAttribute(node, cn); } else if (name == "strided_slice") { SetStridedSliceNodeAttribute(node, cn); + } else if (name == "split") { + SetSplitNodeAttribute(node, cn); } else { SetCallNodeAttribute(node, cn); } @@ -172,6 +174,35 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer { node->SetAttr("strides", strides_attr); } + void SetSplitNodeAttribute(std::shared_ptr node, const CallNode* cn) { + const auto* split_attr = cn->attrs.as(); + ICHECK(split_attr); + + std::vector indices_or_sections; + std::vector mode; + std::vector axis = {std::to_string(split_attr->axis)}; + if (const IntImmNode* sections = split_attr->indices_or_sections.as()) { + mode.emplace_back("sections"); + indices_or_sections.emplace_back(std::to_string(sections->value)); + } else { + mode.emplace_back("indices"); + auto indices = Downcast>(split_attr->indices_or_sections); + for (const auto& i : indices) { + indices_or_sections.emplace_back(std::to_string(i->value)); + } + } + + std::vector indices_or_sections_attr; + std::vector mode_attr; + std::vector axis_attr; + indices_or_sections_attr.emplace_back(indices_or_sections); + mode_attr.emplace_back(mode); + axis_attr.emplace_back(axis); + node->SetAttr("indices_or_sections", indices_or_sections_attr); + node->SetAttr("mode", mode_attr); + node->SetAttr("axis", axis_attr); + } + void SaveGlobalAttributes(std::shared_ptr node) { auto ctx = transform::PassContext::Current(); auto cfg = ctx->GetConfig("relay.ext.tensorrt.options"); diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 04b1e838ee8e..9b108fac67c2 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -723,6 +723,53 @@ class ConcatOpConverter : public TensorRTOpConverter { } }; +class SplitOpConverter : public TensorRTOpConverter { + public: + SplitOpConverter() : TensorRTOpConverter({kTensor}) {} + + void Convert(TensorRTOpConverterParams* params) const { + auto input = params->inputs.at(0).tensor; + auto input_dims = TrtDimsToVector(input->getDimensions()); + const int original_axis = std::stoi(params->node.GetAttr>("axis")[0]); + const int axis = ConvertAxis(params, original_axis, input_dims.size()); + auto indices_or_sections = + params->node.GetAttr>("indices_or_sections"); + auto mode = params->node.GetAttr>("mode")[0]; + + std::vector split_starts; + std::vector split_sizes; + if (mode == "sections") { + int sections = std::stoi(indices_or_sections[0]); + int size = input_dims[axis] / sections; + for (int i = 0; i < sections; i++) { + split_starts.push_back(i * size); + split_sizes.push_back(size); + } + } else { + int last_index = 0; + for (size_t i = 0; i < indices_or_sections.size(); ++i) { + int index = std::stoi(indices_or_sections[i]); + split_starts.push_back(last_index); + split_sizes.push_back(index - last_index); + last_index = index; + } + split_starts.push_back(last_index); + split_sizes.push_back(input_dims[axis] - last_index); + } + + std::vector start(input_dims.size(), 0); + std::vector size(input_dims.begin(), input_dims.end()); + std::vector strides(input_dims.size(), 1); + for (int i = 0; i < split_sizes.size(); ++i) { + start[axis] = split_starts[i]; + size[axis] = split_sizes[i]; + auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), + VectorToTrtDims(size), VectorToTrtDims(strides)); + params->outputs.push_back(slice_layer->getOutput(0)); + } + } +}; + class BiasAddOpConverter : public TensorRTOpConverter { public: BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} @@ -970,9 +1017,17 @@ class ReduceOpConverter : public TensorRTOpConverter { // TODO(trevmorr): Support reduce to scalar. ICHECK_GT(str_axis.size(), 0); uint32_t reduce_axes = 0; - for (size_t i = 0; i < str_axis.size(); ++i) { - const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims); - reduce_axes |= 1 << axis; + + if (str_axis.size() == 1 && str_axis[0].length() == 0) { + // Reduce to scalar + for (int i = 0; i < input->getDimensions().nbDims; ++i) { + reduce_axes |= 1 << i; + } + } else { + for (size_t i = 0; i < str_axis.size(); ++i) { + const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims); + reduce_axes |= 1 << axis; + } } auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims); params->outputs.push_back(reduce_layer->getOutput(0)); @@ -1072,6 +1127,7 @@ GetOpConverters() { map->emplace("expand_dims", std::make_shared()); map->emplace("squeeze", std::make_shared()); map->emplace("concatenate", std::make_shared()); + map->emplace("split", std::make_shared()); map->emplace("nn.conv2d_transpose", std::make_shared()); map->emplace("transpose", std::make_shared()); map->emplace("layout_transform", std::make_shared()); diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 21031c67863f..7efa5bf73186 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -185,7 +185,8 @@ class TensorRTRuntime : public JSONRuntimeBase { * do nothing. */ void BuildEngine() { - batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; + batch_size_ = + data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0]; if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return; DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_ << " with batch size " << batch_size_; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 2bef7be65938..c6b714ccbc8b 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -605,6 +605,19 @@ def get_graph(input_shapes, axis): run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1)) +def test_split(): + def get_graph(x_shape, indices_or_sections, axis): + x = relay.var("x", shape=(x_shape), dtype="float32") + out = relay.split(x, indices_or_sections=indices_or_sections, axis=axis) + f = relay.Function([x], out.astuple()) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 16), indices_or_sections=2, axis=1)) + run_and_verify_func(get_graph((1, 16), indices_or_sections=4, axis=1)) + run_and_verify_func(get_graph((1, 16), indices_or_sections=[8], axis=1)) + run_and_verify_func(get_graph((1, 16), indices_or_sections=[2, 3, 6, 10, 14], axis=1)) + + def test_conv2d_transpose(): def get_graph( x_shape=(1, 32, 8, 8),