diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 44336073d842..acd4f4740b2d 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -23,7 +23,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem -from tvm.relay.expr_functor import ExprMutator +from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -173,7 +173,7 @@ def check_dynamism(args, op_name): """ for arg in args: if isinstance(arg, (Call, Var, Constant, TupleGetItem)): - for dim_shape in arg.checked_type.shape: + for dim_shape in arg.checked_type.shape[1:]: if isinstance(dim_shape, tvm.tir.expr.Any): return True elif isinstance(arg, Tuple): @@ -198,6 +198,21 @@ def _func_wrapper(expr): if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False + if op_name == "multiply": + shapes = [ + [ + int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 + for x in arg.checked_type.shape + ] + for arg in args + ] + # Batched multiply operations don't work in implicit batch mode. The following shapes + # have been excluded because they occur in PT MaskRCNN model. The long term solution is + # to switch to explicit batch mode after performance regressions are solved. + if all( + [list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes] + ): + return False return checker(attrs, args, op_name) return _func_wrapper @@ -292,19 +307,26 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable """Check if add is supported by TensorRT.""" args = expr.args + + shapes = [ + [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape] + for arg in args + ] + # RelayVM + TRT doesn't support scalar addition yet. - for arg in args: - if not arg.checked_type.shape: + for shape in shapes: + if len(shape) < 1: return False + if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False if ( not get_tensorrt_use_implicit_batch_mode() and (isinstance(args[0], Constant) or isinstance(args[1], Constant)) - and args[0].checked_type.shape[0] == args[1].checked_type.shape[0] - and args[0].checked_type.shape[0] != 1 - and (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3) + and shapes[0][0] == shapes[1][0] + and shapes[0][0] != 1 + and (len(shapes[0]) > 3 or len(shapes[1]) > 3) ): logger.info("add: bug in TRT with adding batched constants.") return False @@ -592,11 +614,35 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable logger.info("reshape: new shape dims must be explicit.") return False if get_tensorrt_use_implicit_batch_mode(): - shape = list(map(int, args[0].checked_type.shape)) - new_shape = list(map(int, attrs.newshape)) + shape = args[0].checked_type.shape + new_shape = attrs.newshape if len(new_shape) == 0 or len(shape) == 0: logger.info("reshape: Can't reshape to or from scalar.") return False + + dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape]) + + if dynamic_reshape: + # Make sure that the batch dim is unmodified. + if int(new_shape[0]) < 0: + for shape_val, new_shape_val in enumerate(shape[1:], new_shape[1:]): + if not ( + isinstance(shape_val, int) + and isinstance(new_shape_val, int) + and int(shape_val) == int(new_shape_val) + ): + return False + elif int(new_shape[0]) > 0: + if not ( + isinstance(shape[0], int) + and isinstance(new_shape[0], int) + and int(shape[0]) == int(new_shape[0]) + ): + return False + return True + shape = list(map(int, shape)) + new_shape = list(map(int, new_shape)) + # TRT cannot modify batch dimension. original_volume = np.prod(shape) # First, resolve 0. @@ -607,6 +653,7 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable for i, value in enumerate(new_shape): if value == -1: new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1]) + # Remove batch dimension and see if volumes match if shape[0] != new_shape[0]: logger.info("reshape: can't modify batch dimension.") return False @@ -795,6 +842,41 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable return True +class IsComputeIntensiveGraph(ExprVisitor): + """ + Visits the Graph recursively and checks if it contains compute heavy ops like convolutions and + its transpose, dense and batch mat-mul. + """ + + def __init__(self): + ExprVisitor.__init__(self) + self.is_compute_intensive = False + + def visit_call(self, call): + compute_intensive_ops = set( + [ + "nn.conv2d", + "nn.conv2d_transpose", + "nn.conv3d", + "nn.conv3d_transpose", + "nn.dense", + "nn.batch_matmul", + ] + ) + if isinstance(call.op, tvm.tir.op.Op): + if str(call.op) in compute_intensive_ops: + self.is_compute_intensive = True + + return super().visit_call(call) + + def is_graph_compute_intensive(self, subgraph) -> bool: + """ + This function recursively visits the graph and checks if it's compute intensive" + """ + self.visit(subgraph) + return self.is_compute_intensive + + def is_valid_subgraph(params, body): """Final check on whether the subgraph is valid and should be offloaded to TensorRT.""" # Remove invalid subgraphs for implicit batch mode. @@ -802,24 +884,31 @@ def is_valid_subgraph(params, body): input_batch_sizes = [] for var in params: # In implicit batch mode, all inputs must have same batch size + # TODO: (codeislife99) : Fix different dynamic batch size inputs + if isinstance(var.checked_type, relay.TupleType): for tupe_type in var.checked_type.fields: # Scalar inputs not allowed if len(tupe_type.shape) == 0: logger.info("tensorrt: scalar inputs not supported") return False - input_batch_sizes.append(int(tupe_type.shape[0])) + + if not isinstance(tupe_type.shape[0], tvm.tir.expr.Any): + input_batch_sizes.append(int(tupe_type.shape[0])) else: # Scalar inputs not allowed if len(var.checked_type.shape) == 0: logger.info("tensorrt: scalar inputs not supported") return False - input_batch_sizes.append(int(var.checked_type.shape[0])) + if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any): + input_batch_sizes.append(int(var.checked_type.shape[0])) if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1: logger.info("tensorrt: inputs have different batch sizes") return False - # Remove subgraphs with no multiply-accumulates - if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0: + if ( + get_tensorrt_remove_no_mac_subgraphs() + and not IsComputeIntensiveGraph().is_graph_compute_intensive(body) + ): return False return True @@ -880,6 +969,8 @@ class RemoveDropout(ExprMutator): def visit_tuple_getitem(self, op): visit = super().visit_tuple_getitem(op) + if visit.index != 0: + return visit if ( isinstance(visit.tuple_value, Call) and visit.tuple_value.op.name == "nn.dropout" diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4426642e8e18..ccb8611b7a3c 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -160,8 +160,7 @@ inline std::vector GetIntShape(const Array& shape) { std::vector ret; for (const auto& dim : shape) { const int64_t* pval = tir::as_const_int(dim); - ICHECK(pval) << "Expect integer, but received: " << dim->GetTypeKey(); - ret.push_back(*pval); + ret.push_back(pval ? *pval : -1); } return ret; } diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 445010321668..3f87f8d00ee6 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -41,6 +41,13 @@ namespace tvm { namespace runtime { namespace contrib { +struct PairHash { + template + std::size_t operator()(const std::pair& pair) const { + return std::hash()(pair.first) ^ std::hash()(pair.second); + } +}; + using namespace tvm::runtime::json; class TensorRTRuntime : public JSONRuntimeBase { @@ -105,12 +112,13 @@ class TensorRTRuntime : public JSONRuntimeBase { /*! \brief Run inference using built engine. */ void Run() override { BuildEngine(); - auto& engine_and_context = trt_engine_cache_.at(symbol_name_); + batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; + if (batch_size_ == 0) return; + auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_)); auto engine = engine_and_context.engine; auto context = engine_and_context.context; auto& device_buffers = engine_and_context.device_buffers; std::vector bindings(engine->getNbBindings(), nullptr); - for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { @@ -169,10 +177,11 @@ class TensorRTRuntime : public JSONRuntimeBase { * do nothing. */ void BuildEngine() { - if (trt_engine_cache_.count(symbol_name_)) return; - DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_; + batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; + if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return; + DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_ + << " with batch size " << batch_size_; const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false); - batch_size_ = GetBatchSize(); TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_, use_fp16, batch_size_); @@ -203,8 +212,9 @@ class TensorRTRuntime : public JSONRuntimeBase { } // Build engine. - trt_engine_cache_[symbol_name_] = builder.BuildEngine(); - DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_; + trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine(); + DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_ + << " with batch size " << batch_size_; CacheEngineToDisk(); } @@ -240,7 +250,8 @@ class TensorRTRuntime : public JSONRuntimeBase { helper.DeclareField("inputs", &engine_and_context.inputs); helper.DeclareField("outputs", &engine_and_context.outputs); helper.ReadAllFields(&reader); - trt_engine_cache_[symbol_name_] = engine_and_context; + const int batch_size = 1; + trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context; return true; } @@ -248,13 +259,15 @@ class TensorRTRuntime : public JSONRuntimeBase { * directory so it can be loaded later. */ void CacheEngineToDisk() { + batch_size_ = data_entry_[input_var_eid_[0]]->shape[0]; std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string("")); if (cache_dir.empty()) return; std::string key = GetSubgraphKey(); std::string path = cache_dir + "/" + key + ".plan"; DLOG(INFO) << "Caching TensorRT engine to " << path; // Serialize engine to disk - nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize(); + nvinfer1::IHostMemory* serialized_engine = + trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize(); SaveBinaryToFile(path, std::string(static_cast(serialized_engine->data()), serialized_engine->size())); serialized_engine->destroy(); @@ -262,8 +275,10 @@ class TensorRTRuntime : public JSONRuntimeBase { std::ostringstream os; dmlc::JSONWriter writer(&os); writer.BeginObject(); - writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs); - writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs); + writer.WriteObjectKeyValue("inputs", + trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs); + writer.WriteObjectKeyValue( + "outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs); writer.EndObject(); std::string meta_path = cache_dir + "/" + key + ".meta"; SaveBinaryToFile(meta_path, os.str()); @@ -290,7 +305,8 @@ class TensorRTRuntime : public JSONRuntimeBase { } /*! \brief Map of function name to TRT engine if built already. */ - std::unordered_map trt_engine_cache_; + std::unordered_map, TensorRTEngineAndContext, PairHash> + trt_engine_cache_; /*! \brief TensorRT logger. */ TensorRTLogger logger_; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 8b61323a71ad..10c311a6d363 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -21,11 +21,15 @@ import tvm import tvm.relay.testing + from tvm import relay from tvm.relay.op.contrib import tensorrt from tvm.contrib import graph_runtime, utils from tvm.runtime.vm import VirtualMachine from tvm.relay import Any, GlobalVar, transform +from typing import Dict, Tuple, Union +from tvm.contrib.download import download +from tvm.relay.op.contrib import tensorrt def skip_codegen_test(): @@ -1034,5 +1038,186 @@ def set_func_attr(func, compile_name, symbol_name): tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True) +def test_tensorrt_dynamic_batch(): + if skip_codegen_test(): + return + + batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2] + x_shape = (relay.Any(), 1, 8, 8) + x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") + result_dict = {} + for use_trt in [True, False]: + x = relay.var("x", shape=x_shape, dtype="float32") + out = relay.nn.relu(x) + f = relay.Function([x], out) + mod = tvm.IRModule() + mod["main"] = f + if use_trt: + mod = relay.tensorrt.EnableTrt(mod) + + if not skip_runtime_test(): + with relay.build_config(opt_level=3): + relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") + + for i, batch_size in enumerate(batches_to_test): + result_dict[(i, use_trt)] = relay_exec.evaluate()(x_data[:batch_size, ...]) + + if not skip_runtime_test(): + for i in range(len(batches_to_test)): + assert_result_matches(result_dict[(i, True)], result_dict[(i, False)]) + + +def test_tensorrt_dynamic_batch_conv(): + if skip_codegen_test(): + return + batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2] + x_shape = (relay.Any(), 32, 8, 8) + x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32") + k_shape = (16, 32, 3, 3) + params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")} + result_dict = {} + for use_trt in [True, False]: + x = relay.var("x", shape=x_shape, dtype="float32") + kernel = relay.var("kernel", shape=k_shape, dtype="float32") + out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1) + f = relay.Function([x, kernel], out) + mod = tvm.IRModule() + mod["main"] = f + if use_trt: + mod = tensorrt.partition_for_tensorrt(mod, params) + + if not skip_runtime_test(): + with relay.build_config(opt_level=3): + relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm") + + for i, batch_size in enumerate(batches_to_test): + result_dict[(i, use_trt)] = relay_exec.evaluate()( + x=x_data[:batch_size, ...], **params + ) + + if not skip_runtime_test(): + for i in range(len(batches_to_test)): + assert_result_matches(result_dict[(i, True)], result_dict[(i, False)]) + + +def test_maskrcnn_resnet50() -> None: + """ + This function tests the working of pytorch maskrcnn with resnet50 as backbone with + VM and VM + TRT. Since the order of compiled model outputs is a bit different from + original pytorch model, it uses a custom logic for comparison check. + """ + if skip_codegen_test(): + return + + import torch + import torchvision + + def convert_traced_model_to_vm_trt( + traced_module: torch.jit.TopLevelTracedModule, np_sample_input: np.ndarray, target: str + ) -> tvm.runtime.vm.Executable: + """ + This function converts a traced pytorch model to VM + TRT. + """ + input_shape = np_sample_input.shape + input_name = "input0" + shape_list = [(input_name, input_shape)] + mod, params = relay.frontend.from_pytorch(traced_module, shape_list) + mod, config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True) + with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): + vm_trt_exec = relay.vm.compile(mod, target=target, params=params) + + return vm_trt_exec + + class TraceWrapper(torch.nn.Module): + """ + This class is a wrapper over the torch module to convert the outputs into traceable form + """ + + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + self.model = model + + def forward( + self, inp: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.model(inp) + return out[0]["boxes"], out[0]["scores"], out[0]["labels"], out[0]["masks"] + + def get_traced_maskrcnn_model(np_sample_input: np.ndarray) -> torch.jit.TopLevelTracedModule: + """ + This function takes a sample input and returns the traced maskrcnn model + """ + model_func = torchvision.models.detection.maskrcnn_resnet50_fpn + model = TraceWrapper(model_func(pretrained=True)) + model.eval() + inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=np_sample_input.shape)) + + with torch.no_grad(): + out = model(inp) + traced_module = torch.jit.trace(model, inp) + traced_module.eval() + + return traced_module + + def get_maskrcnn_input(in_size: int) -> np.ndarray: + """ + This function gets a real image with multiple objects of interest and returns it. + """ + input_shape = (1, 3, in_size, in_size) + img_path = "test_street_small.jpg" + img_url = ( + "https://raw.githubusercontent.com/dmlc/web-data/" + "master/gluoncv/detection/street_small.jpg" + ) + download(img_url, img_path) + import cv2 + + img = cv2.imread(img_path).astype("float32") + img = cv2.resize(img, (in_size, in_size)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img / 255.0, [2, 0, 1]) + img = np.expand_dims(img, axis=0) + + return img + + in_size = 300 + np_sample_input = get_maskrcnn_input(in_size) + traced_module = get_traced_maskrcnn_model(np_sample_input) + vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm") + + if skip_runtime_test(): + return + + ctx = tvm.cpu() + vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, ctx) + vm.set_input("main", **{"input0": np_sample_input}) + tvm_res = vm.run() + + # Descending sort by scores and get the high confidence indices. In this example 9 is chosen, + # because this image has 9 boxes over 0.9 confidence + num_high_confidence_boxes = 9 + tvm_indices = np.argsort(-1 * tvm_res[1].asnumpy())[:num_high_confidence_boxes] + + with torch.no_grad(): + out = traced_module(torch.Tensor(np_sample_input)) + # Descending sort by scores and get the high confidence indices + pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes] + + tol = [1e-1, 5e-3, 1e-5, 4e-1] # [Box Tol, Score Tol, Label Tol, Mask Tol] + # Because of certain ops, there are certain minor differences in TVM outputs and PT outputs, + # This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around + # this is to test it on an entire dataset and compare mAP with the original model. + # However, since that is not practically possible on CI, the following compromise is made. + # These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g: + # 0.1 pixel difference of a box in a 300X300 image wont make any change. + for i, tol_val in zip(range(4), tol): + np.testing.assert_allclose( + tvm_res[i].asnumpy()[tvm_indices], + out[i].numpy()[pt_indices], + rtol=tol_val, + atol=tol_val, + ) + + if __name__ == "__main__": pytest.main([__file__])