diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 0486f71b526c..6cda346e5068 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( wrap_compute_dense(topi.rocm.dense_rocblas), - wrap_topi_schedule(topi.rocm.dense_rocblas), + wrap_topi_schedule(topi.rocm.schedule_dense_rocblas), name="dense_rocblas.rocm", plevel=15) return strategy diff --git a/topi/python/topi/rocm/dense.py b/topi/python/topi/rocm/dense.py index 097120da88d6..989cc2aed7c3 100644 --- a/topi/python/topi/rocm/dense.py +++ b/topi/python/topi/rocm/dense.py @@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None): output : tvm.te.Tensor 2-D with shape [batch, out_dim] """ + if out_dtype is None: + out_dtype = data.dtype assert out_dtype == data.dtype, "Mixed precision not supported." matmul = rocblas.matmul(data, weight, False, True) batch, in_dim = data.shape