Skip to content

Commit

Permalink
[Fix] Tensor core type issue for dense (#7187)
Browse files Browse the repository at this point in the history
* fix tc type issue for dense

* fix lint

* rm float 32

Co-authored-by: Leyuan Wang <ziyu.guo@bytedance.com>
  • Loading branch information
Laurawly and Leyuan Wang authored Jan 2, 2021
1 parent c02c9c5 commit 6258fae
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
if target.kind.name == "cuda":
if nvcc.have_tensorcore(target=target):
if (
(i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
(
data.dtype in ["float16", "int8", "uint8"]
and (
(i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
)
)
or (
data.dtype in ["int4", "uint4"]
and i % 32 == 0
and b % 8 == 0
and o % 8 == 0
)
or (
data.dtype in ["int1", "uint1"]
and i % 128 == 0
and b % 8 == 0
and o % 8 == 0
)
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
Expand Down

0 comments on commit 6258fae

Please sign in to comment.