Skip to content

Commit 86e14a1

Browse files
committed
Fix Float8Tensor view bug
1 parent d63cbe9 commit 86e14a1

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -598,8 +598,22 @@ def _(func, types, args, kwargs):
598598
assert original_shape[-1] == size[-1], (
599599
f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
600600
)
601-
qdata = self.qdata.reshape(*size)
602-
scale = self.scale.reshape(*size)
601+
# TODO(andrew): This is technically not needed for unsloth fp8 RL
602+
# but fixes a bug nonetheless, can do this separately
603+
# Example input shapes:
604+
# self.shape = [6, 363, 4096]
605+
# self.scale.shape = [6, 363, 1]
606+
# self.block_size = [1, 1, 4096]
607+
# size = [-1, 4096]
608+
#
609+
# Example output shapes:
610+
# self.shape = [2178, 4096]
611+
# self.scale.shape = [2178, 1]
612+
# self.block_size = [1, 4096]
613+
new_dim0 = original_shape[0] * original_shape[1]
614+
assert size[0] == new_dim0 or size[0] == -1
615+
qdata = self.qdata.reshape(new_dim0, -1)
616+
scale = self.scale.reshape(new_dim0, -1)
603617
block_size = self.block_size.copy()
604618
block_size = [block_size[0] * block_size[1], block_size[2]]
605619
elif len(original_shape) == 2 and len(size) == 3:

0 commit comments

Comments
 (0)