Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziyu Guo committed Dec 31, 2020
1 parent cde4f4b commit 7e61564
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,14 +677,28 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
)
if target.kind.name == "cuda":
if nvcc.have_tensorcore(target=target):
if ((data.dtype in ["float32"] and i % 4 == 0 and b % 8 == 0 and o % 8 == 0) or
(data.dtype in ["float16", "int8", "uint8"] and
((i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0))) or
(data.dtype in ["int4", "uint4"] and
i % 32 == 0 and b % 8 == 0 and o % 8 == 0) or
(data.dtype in ["int1", "uint1"] and i % 128 == 0 and b % 8 == 0 and o % 8 == 0)
if (
(data.dtype in ["float32"] and i % 4 == 0 and b % 8 == 0 and o % 8 == 0)
or (
data.dtype in ["float16", "int8", "uint8"]
and (
(i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
)
)
or (
data.dtype in ["int4", "uint4"]
and i % 32 == 0
and b % 8 == 0
and o % 8 == 0
)
or (
data.dtype in ["int1", "uint1"]
and i % 128 == 0
and b % 8 == 0
and o % 8 == 0
)
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),
Expand Down

0 comments on commit 7e61564

Please sign in to comment.