diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9d8420c69610..37946c01cb46 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): if target.kind.name == "cuda": if nvcc.have_tensorcore(target=target): if ( - (i % 16 == 0 and b % 16 == 0 and o % 16 == 0) - or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) - or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0) + ( + data.dtype in ["float16", "int8", "uint8"] + and ( + (i % 16 == 0 and b % 16 == 0 and o % 16 == 0) + or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) + or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0) + ) + ) + or ( + data.dtype in ["int4", "uint4"] + and i % 32 == 0 + and b % 8 == 0 + and o % 8 == 0 + ) + or ( + data.dtype in ["int1", "uint1"] + and i % 128 == 0 + and b % 8 == 0 + and o % 8 == 0 + ) ): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_tensorcore),