@@ -598,8 +598,22 @@ def _(func, types, args, kwargs):
598598 assert original_shape [- 1 ] == size [- 1 ], (
599599 f"Only support reshaping when last dimension matches, requested: reshaping from { original_shape } to { size } "
600600 )
601- qdata = self .qdata .reshape (* size )
602- scale = self .scale .reshape (* size )
601+ # TODO(andrew): Not strictly required for unsloth fp8 RL, but it fixes a
602+ # bug regardless; this could be split out into a separate change.
603+ # Example input shapes:
604+ # self.shape = [6, 363, 4096]
605+ # self.scale.shape = [6, 363, 1]
606+ # self.block_size = [1, 1, 4096]
607+ # size = [-1, 4096]
608+ #
609+ # Example output shapes:
610+ # self.shape = [2178, 4096]
611+ # self.scale.shape = [2178, 1]
612+ # self.block_size = [1, 4096]
613+ new_dim0 = original_shape [0 ] * original_shape [1 ]
614+ assert size [0 ] == new_dim0 or size [0 ] == - 1
615+ qdata = self .qdata .reshape (new_dim0 , - 1 )
616+ scale = self .scale .reshape (new_dim0 , - 1 )
603617 block_size = self .block_size .copy ()
604618 block_size = [block_size [0 ] * block_size [1 ], block_size [2 ]]
605619 elif len (original_shape ) == 2 and len (size ) == 3 :
0 commit comments