diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c12e096e9051..e9b2d01ae022 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -50,6 +50,15 @@
 
 __all__ = ["from_onnx"]
 
+# The default configurations of the Relay ONNX frontend.
+ONNX_DEFAULT_CONFIGS = {
+    # By default, TVM converts qualified onnx `matmul` to `transpose(weight) + nn.batch_matmul_NT`.
+    # Change this flag to False to directly convert to `nn.batch_matmul`.
+    # Note that `nn.batch_matmul` with a format other than NT is experimental and may have
+    # performance issues.
+    "use_nt_batch_matmul": True,
+}
+
 
 class onnx_input:
     """Dual purpose list or dictionary access object."""
@@ -770,10 +779,14 @@ def flatten_to_nd(x, x_shape, nd=3):
             # Convert a and b into 3 dimensional tensors.
             a = flatten_to_nd(inputs[0], a_shape, 3)
             b = flatten_to_nd(inputs[1], b_shape, 3)
-            # Transpose matrix dimensions of b.
-            b = _op.transpose(b, [0, 2, 1])
-            # Perform a batch matmul.
-            output = _op.nn.batch_matmul(a, b)
+            if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
+                # Transpose matrix dimensions of b.
+                b = _op.transpose(b, [0, 2, 1])
+                # Perform an NT batch matmul.
+                output = _op.nn.batch_matmul(a, b)
+            else:
+                # Perform an NN batch matmul.
+                output = _op.nn.batch_matmul(a, b, transpose_b=False)
             # Determine the output batch dimension.
             if a_rank > b_rank:
                 out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
@@ -3916,7 +3929,9 @@ def _fix_outputs(self, op_name, outputs):
         return outputs
 
 
-def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
+def from_onnx(
+    model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
+):
     """Convert a ONNX model into an equivalent Relay Function.
 
     ONNX graphs are represented as Python Protobuf objects.
@@ -3955,6 +3970,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
         at compile time and helps in making models static if certain inputs represent
         attributes relay would traditionally consider compile-time constants.
 
+    convert_config : Optional[Dict[str, Any]]
+        Default config:
+            use_nt_batch_matmul : bool = True
+                True to convert qualified onnx `matmul` to `nn.batch_matmul` strictly in NT format
+                (transpose_a=False, transpose_b=True).
+
     Returns
     -------
     mod : tvm.IRModule
@@ -3963,6 +3984,10 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
     params : dict of str to tvm.nd.NDArray
         The parameter dict to be used by relay
     """
+    global ONNX_DEFAULT_CONFIGS
+    if convert_config is not None:
+        ONNX_DEFAULT_CONFIGS.update(convert_config)
+
     try:
         import onnx
 
diff --git a/src/runtime/contrib/tensorrt/tensorrt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h
index eb0164210dbb..5406f4c57d66 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_logger.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h
@@ -39,7 +39,7 @@ class TensorRTLogger : public nvinfer1::ILogger {
  public:
   TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {}
   explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {}
-  void log(Severity severity, const char* msg) override {
+  void log(Severity severity, const char* msg) noexcept override {
     // suppress messages with severity enum value greater than the reportable
     if (severity > reportable_severity) return;
 
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 7197172d73db..94bbae1559d9 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -835,7 +835,7 @@ class SplitOpConverter : public TensorRTOpConverter {
     std::vector<int> start(input_dims.size(), 0);
     std::vector<int> size(input_dims.begin(), input_dims.end());
     std::vector<int> strides(input_dims.size(), 1);
-    for (int i = 0; i < split_sizes.size(); ++i) {
+    for (size_t i = 0; i < split_sizes.size(); ++i) {
       start[axis] = split_starts[i];
       size[axis] = split_sizes[i];
       auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
@@ -1174,9 +1174,14 @@ class BatchMatmulOpConverter : public TensorRTOpConverter {
   BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}
 
   void Convert(TensorRTOpConverterParams* params) const {
+    auto transa = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_a")[0]);
+    auto transb = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_b")[0]);
+    nvinfer1::MatrixOperation trt_transa =
+        transa ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
+    nvinfer1::MatrixOperation trt_transb =
+        transb ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
     nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply(
-        *params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
-        *params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
+        *params->inputs.at(0).tensor, trt_transa, *params->inputs.at(1).tensor, trt_transb);
     ICHECK(matmul_layer != nullptr);
     params->outputs.push_back(matmul_layer->getOutput(0));
   }
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 3f57df5a5f4a..082ded704faa 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -474,14 +474,25 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
 
 
 def test_batch_matmul():
-    def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)):
+    def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         y = relay.var("y", shape=(y_shape), dtype="float32")
-        out = relay.nn.batch_matmul(x, y)
+        out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)
         f = relay.Function([x, y], out)
         return f, {"x": x_shape, "y": y_shape}, []
 
-    run_and_verify_func(get_graph())
+    run_and_verify_func(
+        get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True)
+    )
+    run_and_verify_func(
+        get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False)
+    )
+    run_and_verify_func(
+        get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True)
+    )
+    run_and_verify_func(
+        get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False)
+    )
 
 
 def test_bias_add():
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0100b439736d..6a5ffd3821a1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -48,7 +48,14 @@ def get_input_data_shape_dict(graph_def, input_data):
 
 
 def get_tvm_output_with_vm(
-    graph_def, input_data, target, dev, opset=None, freeze_params=False, convert_to_static=False
+    graph_def,
+    input_data,
+    target,
+    dev,
+    opset=None,
+    freeze_params=False,
+    convert_to_static=False,
+    convert_config=None,
 ):
     """Generic function to execute and get tvm output with vm executor"""
     if not isinstance(input_data, list):
@@ -56,7 +63,11 @@ def get_tvm_output_with_vm(
     _, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
     mod, params = relay.frontend.from_onnx(
-        graph_def, shape_dict, opset=opset, freeze_params=freeze_params
+        graph_def,
+        shape_dict,
+        opset=opset,
+        freeze_params=freeze_params,
+        convert_config=convert_config,
     )
 
     if convert_to_static:
@@ -78,12 +89,15 @@ def get_tvm_output(
     output_dtype="float32",
     opset=None,
     opt_level=1,
+    convert_config=None,
 ):
     """Generic function to execute and get tvm output"""
     # TODO: Resolve the issues and remove the following lines
     input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
-    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
+    mod, params = relay.frontend.from_onnx(
+        graph_def, shape_dict, opset=opset, convert_config=convert_config
+    )
 
     with tvm.transform.PassContext(opt_level=opt_level):
         graph, lib, params = relay.build(mod, target, params=params)
@@ -146,6 +160,7 @@ def verify_with_ort_with_inputs(
     atol=1e-5,
     apply_softmax=False,
     opt_level=1,
+    convert_config=None,
 ):
     if opset is not None:
         model.opset_import[0].version = opset
@@ -161,10 +176,19 @@ def verify_with_ort_with_inputs(
             opset=opset,
             freeze_params=freeze_params,
             convert_to_static=convert_to_static,
+            convert_config=convert_config,
         )
     else:
         tvm_out = get_tvm_output(
-            model, inputs, target, dev, out_shape, dtype, opset=opset, opt_level=opt_level
+            model,
+            inputs,
+            target,
+            dev,
+            out_shape,
+            dtype,
+            opset=opset,
+            opt_level=opt_level,
+            convert_config=convert_config,
         )
     if not isinstance(tvm_out, list):
         tvm_out = [tvm_out]
@@ -1179,7 +1203,7 @@ def test_matmul(target, dev):
 
 @tvm.testing.parametrize_targets
 def test_batch_matmul(target, dev):
-    def verify_batch_matmul(a_shape, b_shape, out_shape):
+    def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
         a_array = np.random.uniform(size=a_shape).astype("float32")
         b_array = np.random.uniform(size=b_shape).astype("float32")
 
@@ -1196,7 +1220,14 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
         )
         model = helper.make_model(graph, producer_name="matmul_test")
 
-        verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, target=target, dev=dev)
+        verify_with_ort_with_inputs(
+            model,
+            [a_array, b_array],
+            use_vm=True,
+            target=target,
+            dev=dev,
+            convert_config=convert_config,
+        )
 
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4))
     verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4))
@@ -1207,6 +1238,13 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
     verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
    verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
     verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
+    # Test transb=False
+    verify_batch_matmul(
+        (2, 3, 4, 3),
+        (2, 3, 3, 4),
+        (2, 3, 4, 4),
+        convert_config={"use_nt_batch_matmul": False},
+    )
 
 
 def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
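
Usage sketch (not part of the patch): with the frontend change above, a caller can toggle the
lowering through the new `convert_config` argument of `relay.frontend.from_onnx`. The model file
name, input name, and shape below are hypothetical placeholders.

    import onnx
    from tvm import relay

    # Hypothetical model and input shape, for illustration only.
    onnx_model = onnx.load("model.onnx")
    shape_dict = {"input": (1, 128, 64)}

    # Default behaviour: qualified MatMul is lowered to transpose(weight) + nn.batch_matmul (NT).
    mod_nt, params_nt = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

    # Opt out of the NT lowering: MatMul is converted directly to
    # nn.batch_matmul(..., transpose_b=False), as added in this patch.
    mod_nn, params_nn = relay.frontend.from_onnx(
        onnx_model, shape_dict, freeze_params=True, convert_config={"use_nt_batch_matmul": False}
    )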