[TensorRT] Add transpose_a/b for TensorRT batch_matmul #8607
Conversation
Thanks @ymwangg! Left some minor comments in review.
x = relay.var("x", shape=(x_shape), dtype="float32") | ||
y = relay.var("y", shape=(y_shape), dtype="float32") | ||
out = relay.nn.batch_matmul(x, y) | ||
out = relay.nn.batch_matmul( | ||
relay.transpose(x, [0, 2, 1]) if transa else x, |
I don't think you need these relay.transpose on the inputs to test the functionality of the transa/transb args.
Good point, I've changed it to use x/y_shape instead.
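For illustration, a minimal sketch of the shape-based approach (the helper name and example shapes are my own, not the PR's actual test code): the transposed layout is carried by the input shapes themselves, so no relay.transpose op is needed in the graph.

from tvm import relay

def make_batch_matmul(x_shape, y_shape, transa, transb):
    # When transa/transb is set, the shape already encodes the transposed
    # layout, so the graph contains only the batch_matmul op itself.
    x = relay.var("x", shape=x_shape, dtype="float32")
    y = relay.var("y", shape=y_shape, dtype="float32")
    return relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)

# (B, K, M) x (B, K, N) with transpose_a=True yields a (B, M, N) result.
out = make_batch_matmul((12, 64, 32), (12, 64, 96), True, False)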
LGTM. I'll leave it to @trevor-m to merge after the comments are addressed.
python/tvm/relay/frontend/onnx.py (outdated diff)
- # Transpose matrix dimensions of b.
- b = _op.transpose(b, [0, 2, 1])
  # Perform a batch matmul.
- output = _op.nn.batch_matmul(a, b)
+ output = _op.nn.batch_matmul(a, b, transpose_b=False)
Thanks @ymwangg!
Just a little concern about changing the default behavior of a framework frontend: currently the default topi schedule support for the NN format is not as strong as for the original NT one.
This may confuse people who have used the onnx frontend before or who are using it now.
To give an example, I added an extra config to the TensorFlow frontend that uses the NT format by default but provides an option to use the normal (NN) format. I think that would be better until we have prepared a strong enough topi schedule.
P.S.: Note that I've also kept the default layout for nn.batch_matmul as the original NT.
tvm/python/tvm/relay/frontend/tensorflow_ops.py, lines 1191 to 1199 in 7653972:
if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
    # Strictly convert all batch_matmul to NT format
    input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
    input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
    ret = get_relay_op("batch_matmul")(input_x, input_y)
else:
    ret = get_relay_op("batch_matmul")(
        input_x, input_y, transpose_a=adj_x, transpose_b=adj_y
    )
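For reference, a small sketch of what NT vs. NN means at the Relay level (the shapes are example values of my own):

from tvm import relay

a = relay.var("a", shape=(4, 16, 32), dtype="float32")       # (B, M, K)
b_nt = relay.var("b_nt", shape=(4, 8, 32), dtype="float32")  # (B, N, K)
b_nn = relay.var("b_nn", shape=(4, 32, 8), dtype="float32")  # (B, K, N)

# NT (the historical default, transpose_b=True): out = a @ b_nt^T
out_nt = relay.nn.batch_matmul(a, b_nt)
# NN: out = a @ b_nn, with b given in its natural layout
out_nn = relay.nn.batch_matmul(a, b_nn, transpose_b=False)
# Both produce a (4, 16, 8) result.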
@jcf94 Thanks for the pointer. I will refactor to make NN optional.
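The resulting refactor (per the commit list below) mirrors the TensorFlow switch. A hedged sketch of the shape it could take: the ONNX_DEFAULT_CONFIGS name comes from the PR's commits, but the key name and the exact code here are my assumptions.

from tvm.relay import op as _op

# Assumed config dict, analogous to TF_DEFAULT_CONFIGS above.
ONNX_DEFAULT_CONFIGS = {"use_nt_batch_matmul": True}

def convert_matmul(a, b):
    if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
        # Keep the historical NT behavior: transpose b, use the default layout.
        b = _op.transpose(b, [0, 2, 1])
        return _op.nn.batch_matmul(a, b)
    # Otherwise feed a and b directly in NN format.
    return _op.nn.batch_matmul(a, b, transpose_b=False)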
LGTM. Thanks! @ymwangg
* Add transpose support for tensorrt batch_matmul
* Address PR comment
* Refactor to add ONNX_DEFAULT_CONFIGS
This PR adds transpose_a/b for TensorRT batch_matmul and fixes a warning and a compilation error with TensorRT-8. It also removes the redundant transpose op in the onnx matmul converter. Tested with both TensorRT-7 and TensorRT-8.
cc @trevor-m @comaniac
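On the TensorRT side, the conversion amounts to forwarding the new attributes to TensorRT's matrix-multiply layer. A rough sketch of the idea using the TensorRT Python API (the actual converter in the PR lives in TVM's C++ BYOC codegen; the function and variable names here are illustrative):

import tensorrt as trt

def add_batch_matmul(network, a, b, transpose_a, transpose_b):
    # Relay's transpose_a/b attributes map directly onto TensorRT's
    # MatrixOperation flags, so no explicit transpose layer is needed.
    op_a = trt.MatrixOperation.TRANSPOSE if transpose_a else trt.MatrixOperation.NONE
    op_b = trt.MatrixOperation.TRANSPOSE if transpose_b else trt.MatrixOperation.NONE
    return network.add_matrix_multiply(a, op_a, b, op_b)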