[TensorRT] Add transpose_a/b for TensorRT batch_matmul #8607

Merged · 3 commits · Aug 5, 2021
35 changes: 30 additions & 5 deletions python/tvm/relay/frontend/onnx.py
@@ -50,6 +50,15 @@

__all__ = ["from_onnx"]

# The default configurations of Relay ONNX frontend.
ONNX_DEFAULT_CONFIGS = {
# By default, TVM converts a qualified ONNX `matmul` to `transpose(weight) + nn.batch_matmul_NT`.
# Set this flag to False to convert directly to `nn.batch_matmul`.
# Note that `nn.batch_matmul` in formats other than NT is experimental and may have
# performance issues.
"use_nt_batch_matmul": True,
}


class onnx_input:
"""Dual purpose list or dictionary access object."""
@@ -770,10 +779,14 @@ def flatten_to_nd(x, x_shape, nd=3):
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform an NT batch matmul.
output = _op.nn.batch_matmul(a, b)
else:
# Perform an NN batch matmul.
output = _op.nn.batch_matmul(a, b, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
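(Aside, not part of the diff: a quick numpy check of why the NT and NN paths above agree, for a of shape (B, M, K) and b of shape (B, K, N).)

import numpy as np

B, M, K, N = 2, 4, 3, 5
a = np.random.rand(B, M, K).astype("float32")
b = np.random.rand(B, K, N).astype("float32")

# NT path: transpose b to (B, N, K), then an NT batch matmul
# (a @ b_t^T over the last two dims), which lands back on a @ b.
b_t = b.transpose(0, 2, 1)
nt_out = a @ b_t.transpose(0, 2, 1)

# NN path: multiply directly.
nn_out = a @ b

assert np.allclose(nt_out, nn_out)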
@@ -3916,7 +3929,9 @@ def _fix_outputs(self, op_name, outputs):
return outputs


def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
def from_onnx(
model, shape=None, dtype="float32", opset=None, freeze_params=False, convert_config=None
):
"""Convert a ONNX model into an equivalent Relay Function.

ONNX graphs are represented as Python Protobuf objects.
@@ -3955,6 +3970,12 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.

convert_config : Optional[Dict[str, Any]]
Default config:
use_nt_batch_matmul : bool = True
Whether to convert qualified ONNX `matmul` ops to `nn.batch_matmul` strictly in NT format
(transpose_a=False, transpose_b=True).

Returns
-------
mod : tvm.IRModule
@@ -3963,6 +3984,10 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):
params : dict of str to tvm.nd.NDArray
The parameter dict to be used by relay
"""
global ONNX_DEFAULT_CONFIGS
if convert_config is not None:
ONNX_DEFAULT_CONFIGS.update(convert_config)

try:
import onnx

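For context, a minimal usage sketch of the new convert_config knob (not part of this diff; the model path and input shape below are hypothetical):

import onnx
from tvm import relay

# Hypothetical model file and input shape -- substitute your own.
onnx_model = onnx.load("model.onnx")
shape_dict = {"x": (12, 128, 64)}

# Convert qualified matmuls directly to nn.batch_matmul (NN format)
# instead of the default transpose(weight) + NT batch_matmul.
mod, params = relay.frontend.from_onnx(
    onnx_model, shape_dict, convert_config={"use_nt_batch_matmul": False}
)

Because convert_config updates the module-level ONNX_DEFAULT_CONFIGS dict, the setting persists for any later from_onnx calls in the same process.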
2 changes: 1 addition & 1 deletion src/runtime/contrib/tensorrt/tensorrt_logger.h
@@ -39,7 +39,7 @@ class TensorRTLogger : public nvinfer1::ILogger {
public:
TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {}
explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {}
void log(Severity severity, const char* msg) override {
void log(Severity severity, const char* msg) noexcept override {
// suppress messages with severity enum value greater than the reportable
if (severity > reportable_severity) return;

11 changes: 8 additions & 3 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -835,7 +835,7 @@ class SplitOpConverter : public TensorRTOpConverter {
std::vector<int> start(input_dims.size(), 0);
std::vector<int> size(input_dims.begin(), input_dims.end());
std::vector<int> strides(input_dims.size(), 1);
for (int i = 0; i < split_sizes.size(); ++i) {
for (size_t i = 0; i < split_sizes.size(); ++i) {
start[axis] = split_starts[i];
size[axis] = split_sizes[i];
auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
@@ -1174,9 +1174,14 @@ class BatchMatmulOpConverter : public TensorRTOpConverter {
BatchMatmulOpConverter() : TensorRTOpConverter({kTensor, kTensor}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto transa = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_a")[0]);
auto transb = std::stoi(params->node.GetAttr<std::vector<std::string>>("transpose_b")[0]);
nvinfer1::MatrixOperation trt_transa =
transa ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation trt_transb =
transb ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply(
*params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
*params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kTRANSPOSE);
*params->inputs.at(0).tensor, trt_transa, *params->inputs.at(1).tensor, trt_transb);
ICHECK(matmul_layer != nullptr);
params->outputs.push_back(matmul_layer->getOutput(0));
}
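For reference, a numpy sketch (an aside, not part of this diff) of the relay.nn.batch_matmul semantics the converter now maps onto TensorRT: the transpose flags apply to the last two dimensions of each 3-D operand, which is exactly what nvinfer1::MatrixOperation::kTRANSPOSE requests.

import numpy as np

def batch_matmul_ref(a, b, transpose_a=False, transpose_b=True):
    # Mirror relay.nn.batch_matmul on 3-D tensors: optionally transpose
    # the last two dims of each operand, then batch-multiply.
    if transpose_a:
        a = a.transpose(0, 2, 1)
    if transpose_b:
        b = b.transpose(0, 2, 1)
    return a @ b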
17 changes: 14 additions & 3 deletions tests/python/contrib/test_tensorrt.py
@@ -474,14 +474,25 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):


def test_batch_matmul():
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64)):
def get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True):
x = relay.var("x", shape=(x_shape), dtype="float32")
y = relay.var("y", shape=(y_shape), dtype="float32")
out = relay.nn.batch_matmul(x, y)
out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)
f = relay.Function([x, y], out)
return f, {"x": x_shape, "y": y_shape}, []

run_and_verify_func(get_graph())
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 128, 64), transa=True, transb=True)
)
run_and_verify_func(
get_graph(x_shape=(12, 64, 128), y_shape=(12, 64, 128), transa=True, transb=False)
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 128, 64), transa=False, transb=True)
)
run_and_verify_func(
get_graph(x_shape=(12, 128, 64), y_shape=(12, 64, 128), transa=False, transb=False)
)


def test_bias_add():
50 changes: 44 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
@@ -48,15 +48,26 @@ def get_input_data_shape_dict(graph_def, input_data):


def get_tvm_output_with_vm(
graph_def, input_data, target, dev, opset=None, freeze_params=False, convert_to_static=False
graph_def,
input_data,
target,
dev,
opset=None,
freeze_params=False,
convert_to_static=False,
convert_config=None,
):
"""Generic function to execute and get tvm output with vm executor"""
if not isinstance(input_data, list):
input_data = [input_data]
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(
graph_def, shape_dict, opset=opset, freeze_params=freeze_params
graph_def,
shape_dict,
opset=opset,
freeze_params=freeze_params,
convert_config=convert_config,
)

if convert_to_static:
@@ -78,12 +89,15 @@ def get_tvm_output(
output_dtype="float32",
opset=None,
opt_level=1,
convert_config=None,
):
"""Generic function to execute and get tvm output"""
# TODO: Resolve the issues and remove the following lines
input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
mod, params = relay.frontend.from_onnx(
graph_def, shape_dict, opset=opset, convert_config=convert_config
)

with tvm.transform.PassContext(opt_level=opt_level):
graph, lib, params = relay.build(mod, target, params=params)
@@ -146,6 +160,7 @@ def verify_with_ort_with_inputs(
atol=1e-5,
apply_softmax=False,
opt_level=1,
convert_config=None,
):
if opset is not None:
model.opset_import[0].version = opset
@@ -161,10 +176,19 @@
opset=opset,
freeze_params=freeze_params,
convert_to_static=convert_to_static,
convert_config=convert_config,
)
else:
tvm_out = get_tvm_output(
model, inputs, target, dev, out_shape, dtype, opset=opset, opt_level=opt_level
model,
inputs,
target,
dev,
out_shape,
dtype,
opset=opset,
opt_level=opt_level,
convert_config=convert_config,
)
if not isinstance(tvm_out, list):
tvm_out = [tvm_out]
@@ -1179,7 +1203,7 @@ def test_matmul(target, dev):

@tvm.testing.parametrize_targets
def test_batch_matmul(target, dev):
def verify_batch_matmul(a_shape, b_shape, out_shape):
def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")

@@ -1196,7 +1220,14 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
)

model = helper.make_model(graph, producer_name="matmul_test")
verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, target=target, dev=dev)
verify_with_ort_with_inputs(
model,
[a_array, b_array],
use_vm=True,
target=target,
dev=dev,
convert_config=convert_config,
)

verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4))
verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4))
@@ -1207,6 +1238,13 @@ def verify_batch_matmul(a_shape, b_shape, out_shape):
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32))
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16))
# Test transb=False
verify_batch_matmul(
(2, 3, 4, 3),
(2, 3, 3, 4),
(2, 3, 4, 4),
convert_config={"use_nt_batch_matmul": False},
)


def verify_simple_dynamic_model(a_shape, b_shape, target, dev):