Commit 4132cbe

maleksan85 authored
Making check for output match in original types. It saves some memory. (vllm-project#135)
Co-authored-by: maleksan85 <maleksan@amd.com>
1 parent d5bf9bc commit 4132cbe

1 file changed, +2 -2 lines changed

gradlib/gradlib/GemmTuner.py

Lines changed: 2 additions & 2 deletions
@@ -76,8 +76,8 @@ def check_gemm_ref(self, libtype, solidx):
                                          self.outdtype)
         elif libtype == 'rocblas':
             c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx)
-        if torch.allclose(c.to(torch.float32),
-                          ref.to(torch.float32),
+        if torch.allclose(c.to(self.outdtype),
+                          ref.to(self.outdtype),
                           atol=self.atol,
                           rtol=self.rtol):
             return True
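
The memory saving claimed in the commit message can be estimated with some back-of-the-envelope arithmetic. The sketch below is illustrative only: it uses plain Python rather than PyTorch, the helper `cast_copy_bytes` and the tensor shape are hypothetical, and it assumes both `c` and `ref` are already stored in a 16-bit output dtype, so that casting to `self.outdtype` allocates nothing while casting to `torch.float32` allocates a full copy of each tensor.

```python
# Bytes per element for a few common dtypes (standard sizes).
BYTES_PER_ELEMENT = {"float16": 2, "bfloat16": 2, "float32": 4}


def cast_copy_bytes(shape, src_dtype, dst_dtype):
    """Bytes allocated by casting one tensor of `shape` to dst_dtype.

    Assumes a cast to the dtype the tensor already has allocates nothing.
    This helper is hypothetical and only models the cast, not torch.allclose.
    """
    if src_dtype == dst_dtype:
        return 0
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * BYTES_PER_ELEMENT[dst_dtype]


# Hypothetical GEMM output shape (M x N); not taken from the commit.
shape = (4096, 8192)

# Old check: both c and ref are upcast to float32 before the comparison.
old_bytes = 2 * cast_copy_bytes(shape, "float16", "float32")
# New check: both are compared in the output dtype, assumed here to be float16.
new_bytes = 2 * cast_copy_bytes(shape, "float16", "float16")

print(old_bytes // 2**20, "MiB of temporary copies before the change")
print(new_bytes // 2**20, "MiB of temporary copies after the change")
```

Under these assumptions the two float32 upcasts cost 256 MiB of temporary storage for a single check, and the new form costs none. The likely trade-off is that the tolerance comparison itself then runs in the lower-precision type.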
