diff --git a/python/hidet/graph/ops/definitions/matmul/resolve.py b/python/hidet/graph/ops/definitions/matmul/resolve.py index dbd2f3a57..d16b84bbe 100644 --- a/python/hidet/graph/ops/definitions/matmul/resolve.py +++ b/python/hidet/graph/ops/definitions/matmul/resolve.py @@ -173,6 +173,9 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]: if not (a.dtype == dtypes.float16 and b.dtype == dtypes.float16 and a.shape[-1] % 8 == b.shape[-1] % 8 == 0): return None + if hidet.cuda.compute_capability() < (8, 0): + return None + parallel_k = self.get_config('parallel_k', default='default') # 'default', 'search', 2, 4, ... if isinstance(parallel_k, str): if parallel_k == 'default':