@@ -24,7 +24,9 @@ def _n_ones(n: int) -> int:
 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
 
 
-def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _f32_to_floatx_unpacked(
+    x: Tensor, ebits: int, mbits: int, fake_quantize: bool = False
+) -> Tensor:
     """Convert FP32 numbers to sub-byte floating point numbers with the given
     number of exponent and mantissa bits.
 
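As a concrete illustration of the new flag (a usage sketch, not part of this commit): assuming the function is called from inside this module and using fp4_e2m1 (ebits=2, mbits=1) from the docstring's example encodings, fake_quantize=False keeps the original behavior and returns uint8 codes, while fake_quantize=True skips the uint8 casts so the codes stay in int32. The input values below are illustrative.

import torch

# Illustrative fp32 inputs; fp4_e2m1 uses ebits=2, mbits=1.
x = torch.tensor([0.0, 0.4, 1.5, -3.0], dtype=torch.float32)

# Default path: unchanged, sub-byte codes come back as uint8.
codes_u8 = _f32_to_floatx_unpacked(x, ebits=2, mbits=1)
assert codes_u8.dtype == torch.uint8

# New path: the intermediate uint8 casts are skipped, codes stay int32.
codes_i32 = _f32_to_floatx_unpacked(x, ebits=2, mbits=1, fake_quantize=True)
assert codes_i32.dtype == torch.int32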
@@ -105,7 +107,8 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     denormal_x = x + denorm_mask_float
     denormal_x = denormal_x.view(torch.int32)
     denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
+    if not fake_quantize:
+        denormal_x = denormal_x.to(torch.uint8)
 
     #
     # branch 3: stay in normal range, adjust the exponent and round
@@ -120,31 +123,41 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     normal_x += mant_odd
     # take the bits!
     normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
+    if not fake_quantize:
+        normal_x = normal_x.to(torch.uint8)
 
     #
     # combine the branches
     #
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    if fake_quantize:
+        x = torch.full_like(x, max_int, dtype=torch.int32)
+    else:
+        x = torch.full_like(x, max_int, dtype=torch.uint8)
     x = torch.where(denormal_mask, denormal_x, x)
     x = torch.where(normal_mask, normal_x, x)
 
     # add sign back
     sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
+    if not fake_quantize:
+        sign_lp = sign_lp.to(torch.uint8)
     # Right shift of a negative signed integer can fill the least significant
     # bits with either 1s or 0s, depending on the implementation. Since PyTorch
     # doesn't have an uint32 dtype, we mask out these bits to get just the
     # f4 sign bit
     sign_lp = sign_lp & sign_mask
     x = x | sign_lp
 
-    return x.to(torch.uint8)
+    if fake_quantize:
+        return x
+    else:
+        return x.to(torch.uint8)
 
 
 # TODO(future): check if LUT for everything is faster than bit shifting,
 # especially for fp4 (only 2^4=16 unique values).
-def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _floatx_unpacked_to_f32(
+    x: Tensor, ebits: int, mbits: int, fake_quantize: bool = False
+) -> Tensor:
     """Convert sub-byte floating point numbers with the given number of exponent
     and mantissa bits to FP32.
 
@@ -154,7 +167,8 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
       fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    assert x.dtype == torch.uint8
+    if not fake_quantize:
+        assert x.dtype == torch.uint8
     assert 1 + ebits + mbits <= 8
 
     sign_mask = 1 << (ebits + mbits)
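Read together, the two gated functions point at a fake-quantize (quantize-dequantize) round trip that never materializes uint8. The sketch below assumes that intent, and that the dequantization path accepts the int32 codes unchanged beyond the relaxed assert shown here, using fp6_e3m2 (ebits=3, mbits=2) as an example encoding.

import torch

x = torch.randn(8, dtype=torch.float32)

# Encode to fp6_e3m2 codes kept in int32, then decode straight back to fp32.
codes = _f32_to_floatx_unpacked(x, ebits=3, mbits=2, fake_quantize=True)
x_fq = _floatx_unpacked_to_f32(codes, ebits=3, mbits=2, fake_quantize=True)

# Expectation (assumption): x_fq is fp32 snapped onto the fp6_e3m2 grid,
# which is the value a QAT-style fake-quantize step would feed forward.
print(x_fq.dtype)  # torch.float32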