From 6258fae6d1e9ab77b8065d4ffb81a5033665e0cc Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 1 Jan 2021 16:07:41 -0800 Subject: [PATCH] [Fix] Tensor core type issue for dense (#7187) * fix tc type issue for dense * fix lint * rm float 32 Co-authored-by: Leyuan Wang --- python/tvm/relay/op/strategy/cuda.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 9d8420c69610..37946c01cb46 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): if target.kind.name == "cuda": if nvcc.have_tensorcore(target=target): if ( - (i % 16 == 0 and b % 16 == 0 and o % 16 == 0) - or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) - or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0) + ( + data.dtype in ["float16", "int8", "uint8"] + and ( + (i % 16 == 0 and b % 16 == 0 and o % 16 == 0) + or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) + or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0) + ) + ) + or ( + data.dtype in ["int4", "uint4"] + and i % 32 == 0 + and b % 8 == 0 + and o % 8 == 0 + ) + or ( + data.dtype in ["int1", "uint1"] + and i % 128 == 0 + and b % 8 == 0 + and o % 8 == 0 + ) ): strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_tensorcore),