diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index 50bcafd9f9a7..080ddf28b7c2 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -176,6 +176,12 @@ def _dense_legalize(attrs, inputs, arg_types): x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y + + # If units is explicitly specified, it is used to compute the output shape. + # We need to update units after padding to prevent a type error. + if attrs["units"] is not None: + new_attrs["units"] = N + dn + out_ = relay.nn.dense(x_, y_, **new_attrs) out = ( relay.strided_slice(out_, begin=[0, 0], end=[x.value for x in output_tensor.shape]) diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index bcd69f7253ef..97860630dea5 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -206,7 +206,7 @@ def expected(): @tvm.testing.uses_gpu def test_legalize_dense(): - def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True): + def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True, units=None): """test legalize dense to enable tensorcore""" M, K = data_shape N, _ = kernel_shape @@ -216,7 +216,7 @@ def _test_legalize_dense(data_shape, kernel_shape, pad_shape, dtype, do_pad=True def before(): x = relay.var("x", shape=data_shape, dtype=dtype) weight = relay.var("weight", shape=kernel_shape, dtype=dtype) - y = relay.nn.dense(x, weight) + y = relay.nn.dense(x, weight, units) y = relay.Function([x, weight], y) return y @@ -237,10 +237,7 @@ def expected(): weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk))) else: weight_pad = weight - y_pad = relay.nn.dense( - x_pad, - weight_pad, - ) + y_pad = relay.nn.dense(x_pad, weight_pad, units=N + dn if units else None) if dm or dn: y = relay.strided_slice(y_pad, begin=[0, 0], end=out_shape) else: @@ -264,6 +261,9 @@ def expected(): _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype) _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False) + # Test if units parameter is correctly updated + _test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30) + _test_legalize_dense((8, 32), (32, 32), (0, 0, 0), "int4", False) _test_legalize_dense((7, 32), (32, 32), (1, 0, 0), "int4") _test_legalize_dense((8, 31), (32, 31), (0, 1, 0), "int4")