Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Fix dense tensorcore legalize type error when units is specified #9030

Merged
merged 2 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def _dense_legalize(attrs, inputs, arg_types):

x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x
y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y

# If units is explicitly specified, it is used to compute the output shape.
# We need to update units after padding to prevent a type error.
if attrs["units"] is not None:
new_attrs["units"] = N + dn
Comment on lines +180 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking about the semantic when units is specified but overwritten here. It looks like we are changing the expected output shape in this case, but it is fine if units is used basically internally. Is units mainly specified by users or other passes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

units is specified by frontends. For example, PT and ONNX use it, while TF frontend doesn't. Actually I wonder why no one has hit this error before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently, the dense tensorcore schedule is not used when batch size is 1. That's probably the reason this error is not frequent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see make sense. Thanks for the explanation.


out_ = relay.nn.dense(x_, y_, **new_attrs)
out = (
relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape])
Expand Down
12 changes: 6 additions & 6 deletions tests/python/relay/test_pass_legalize_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def expected():

@tvm.testing.uses_gpu
def test_legalize_dense():
def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True, units=None):
"""test legalize dense to enable tensorcore"""
M, K = data_shape
N, _ = kernel_shape
Expand All @@ -216,7 +216,7 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True
def before():
x = relay.var("x", shape=data_shape, dtype=dtype)
weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
y = relay.nn.dense(x, weight)
y = relay.nn.dense(x, weight, units)
y = relay.Function([x, weight], y)
return y

Expand All @@ -237,10 +237,7 @@ def expected():
weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk)))
else:
weight_pad = weight
y_pad = relay.nn.dense(
x_pad,
weight_pad,
)
y_pad = relay.nn.dense(x_pad, weight_pad, units=N + dn if units else None)
if dm or dn:
y = relay.strided_slice(y_pad, begin=[0, 0], end=out_shape)
else:
Expand All @@ -264,6 +261,9 @@ def expected():
_test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
_test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)

# Test if units parameter is correctly updated
_test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30)

_test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False)
_test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4")
_test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4")
Expand Down