
Commit 7b390b2

RandySheriff authored and facebook-github-bot committed
Fix torchAO shape check on fp8 tensors
Differential Revision: D83184235
1 parent b47f1a3 commit 7b390b2

File tree

1 file changed: +45 -15 lines changed


torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 45 additions & 15 deletions
@@ -190,23 +190,53 @@ def _apply_fn_to_data(self, fn):
 def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias):
     from torchao.dtypes.floatx import Float8Layout
 
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, Float8Layout)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-        and (bias is None or bias.dtype == input_tensor.dtype)
-        and (bias is None or len(bias.shape) == 1)
+    base_check = (
+        isinstance(input_tensor, AffineQuantizedTensor) and
+        isinstance(input_tensor._layout, Float8Layout) and
+        input_tensor.dtype in (torch.float16, torch.bfloat16) and
+        len(input_tensor.shape) >= 2 and
+        input_tensor.tensor_impl.scale.dtype == torch.float32 and
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor._layout, CutlassSemiSparseLayout) and
+        weight_tensor.dtype == input_tensor.dtype and
+        len(weight_tensor.shape) == 2 and
+        weight_tensor.tensor_impl.scale.dtype == torch.float32 and
+        (bias is None or bias.dtype == input_tensor.dtype) and
+        (bias is None or len(bias.shape) == 1)
     )
 
+    if base_check:
+
+        # do extra check and reshape if needed
+        input_tensor_squeezed = False
+        if len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) and \
+           len(input_tensor.tensor_impl.scale.shape) > 1 and \
+           input_tensor.tensor_impl.scale.shape[-1] == 1:
+            input_tensor.tensor_impl.scale = torch.squeeze(input_tensor.tensor_impl.scale, dim=-1)
+            input_tensor_squeezed = True
+
+        weight_tensor_squeezed = False
+        if len(weight_tensor.tensor_impl.scale.shape) == 2 and \
+           weight_tensor.tensor_impl.scale.shape[-1] == 1:
+            weight_tensor.tensor_impl.scale = torch.squeeze(weight_tensor.tensor_impl.scale, dim=-1)
+            weight_tensor_squeezed = True
+
+        extra_check = (
+            len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
+            and len(weight_tensor.tensor_impl.scale.shape) == 1
+        )
+
+        if not extra_check:  # revert if extra check failed
+            if input_tensor_squeezed:
+                input_tensor.tensor_impl.scale = torch.unsqueeze(input_tensor.tensor_impl.scale, dim=-1)
+            if weight_tensor_squeezed:
+                weight_tensor.tensor_impl.scale = torch.unsqueeze(weight_tensor.tensor_impl.scale, dim=-1)
+
+        return extra_check
+
+    else:
+        return False
+
 
 
 def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
     from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
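
In effect, the patch makes the dispatch check tolerate a trailing singleton dimension on the fp8 scale tensors: a rowwise scale may be materialized as shape (..., 1) rather than (...), so the check squeezes that dimension before testing ranks and restores it if the rank test still fails. Below is a minimal standalone sketch of that squeeze-and-revert pattern in plain PyTorch; normalize_scale is a hypothetical helper for illustration, not part of torchao.

import torch

def normalize_scale(scale: torch.Tensor, expected_ndim: int) -> torch.Tensor:
    # Drop a trailing singleton dim if the scale has exactly one rank too
    # many, mirroring the torch.squeeze(..., dim=-1) calls in the check.
    squeezed = False
    if scale.ndim == expected_ndim + 1 and scale.ndim > 1 and scale.shape[-1] == 1:
        scale = torch.squeeze(scale, dim=-1)
        squeezed = True
    # Revert if the rank still does not match (the diff's
    # "revert if extra check failed" branch).
    if scale.ndim != expected_ndim and squeezed:
        scale = torch.unsqueeze(scale, dim=-1)
    return scale

# An activation of shape (2, 4, 8) expects a rowwise scale of ndim 2;
# a scale materialized as (2, 4, 1) is squeezed to (2, 4).
scale = torch.rand(2, 4, 1, dtype=torch.float32)
print(normalize_scale(scale, expected_ndim=2).shape)  # torch.Size([2, 4])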
