From b491b912c8018bcbe496b296d36e233ec87cfef5 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Thu, 22 Oct 2020 08:46:35 -0700 Subject: [PATCH] [FIX] Fix cublas batch matmul (#6715) * Update batch_matmul.py Update batch_matmul.py * fix --- python/tvm/topi/cuda/batch_matmul.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index bb060b3ad8a7b..ee94420066dd2 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -138,7 +138,7 @@ def _callback(op): return s -def batch_matmul_cublas(x, y): +def batch_matmul_cublas(x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -150,6 +150,9 @@ def batch_matmul_cublas(x, y): y : tvm.te.Tensor 3-D with shape [batch, N, K] + out_shape : None + The output shape + Returns ------- output : tvm.te.Tensor