import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
- from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
+ from torchao.prototype.custom_fp_utils import (
+     _f32_to_fpx_unpacked,
+     _fpx_unpacked_to_f32,
+     _n_ones,
+ )
from torchao.dtypes.utils import (
    LayoutType,
)


def _pack(x: Tensor, n_bits: int) -> Tensor:
-     return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
+     return reduce(
+         torch.bitwise_or,
+         [
+             x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
+             for i in range(8 // n_bits)
+         ],
+     )


def _unpack(x: Tensor, n_bits: int) -> Tensor:
-     return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
+     return torch.stack(
+         [
+             (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
+             for i in range(8 // n_bits)
+         ],
+         dim=-1,
+     ).flatten(-2)


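For orientation, a minimal round-trip sketch of the two helpers above (standalone and illustrative; it assumes the _pack/_unpack defined in this file and uses the 2-bit case): _pack folds every 8 // n_bits consecutive sub-byte values into one uint8, and _unpack splits them back out.

import torch

vals = torch.tensor([[0, 1, 2, 3, 3, 2, 1, 0]], dtype=torch.uint8)  # eight 2-bit values
packed = _pack(vals, n_bits=2)        # shape (1, 2): four values packed per byte
restored = _unpack(packed, n_bits=2)  # shape (1, 8)
assert torch.equal(restored, vals)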
# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -35,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:

    if not undo:
        bit_order = {
-             1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
-                 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
+             1: [
+                 1,
+                 5,
+                 9,
+                 13,
+                 17,
+                 21,
+                 25,
+                 29,
+                 3,
+                 7,
+                 11,
+                 15,
+                 19,
+                 23,
+                 27,
+                 31,
+                 0,
+                 4,
+                 8,
+                 12,
+                 16,
+                 20,
+                 24,
+                 28,
+                 2,
+                 6,
+                 10,
+                 14,
+                 18,
+                 22,
+                 26,
+                 30,
+             ],
            2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
            4: [1, 5, 3, 7, 0, 4, 2, 6],
        }[n_bits]
@@ -45,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
        # this is inverse of the above, obtained by running
        # [v.index(i) for i in range(len(v))]
        bit_order = {
-             1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
-                 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
+             1: [
+                 16,
+                 0,
+                 24,
+                 8,
+                 17,
+                 1,
+                 25,
+                 9,
+                 18,
+                 2,
+                 26,
+                 10,
+                 19,
+                 3,
+                 27,
+                 11,
+                 20,
+                 4,
+                 28,
+                 12,
+                 21,
+                 5,
+                 29,
+                 13,
+                 22,
+                 6,
+                 30,
+                 14,
+                 23,
+                 7,
+                 31,
+                 15,
+             ],
            2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
            4: [4, 0, 6, 2, 5, 1, 7, 3],
        }[n_bits]
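The undo tables can be checked against the forward tables exactly as the comment above says; a standalone sketch for the 2-bit case, with the values copied from this function:

forward = [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]  # n_bits=2, forward order
inverse = [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7]  # n_bits=2, undo order
assert [forward.index(i) for i in range(len(forward))] == inverse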
@@ -82,8 +162,12 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
            tensor_ybit = _pack(tensor_ybit, y)

-             tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)  # Pass 2 from original code
-             tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)  # Pass 3 from original code
+             tensor_ybit = (
+                 tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+             )  # Pass 2 from original code
+             tensor_ybit = _bit_interleave(
+                 tensor_ybit.flatten(), y
+             )  # Pass 3 from original code
            fragments.append(tensor_ybit)
            used_bits += y

@@ -125,7 +209,9 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te

    # workaround: global lookup table
    exp_bias = _ONES_TABLE[ebits - 1]
-     max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
+     max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
+         _ONES_TABLE[mbits + 1] / (2**mbits)
+     )

    tensor = tensor.float()
    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
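As a worked check of the max_normal expression (assuming _ONES_TABLE[n] holds the value with n one-bits, i.e. 2**n - 1, as the _n_ones import suggests), fp6 with ebits=3 and mbits=2 recovers the familiar E3M2 maximum of 28; the per-row scale on the following line then maps each row's absolute maximum onto that value:

ebits, mbits = 3, 2                                 # fp6 (E3M2)
exp_bias = (1 << (ebits - 1)) - 1                   # _ONES_TABLE[ebits - 1] = 3
max_exp = ((1 << ebits) - 1) - exp_bias             # _ONES_TABLE[ebits] - exp_bias = 4
max_mant = ((1 << (mbits + 1)) - 1) / (1 << mbits)  # _ONES_TABLE[mbits + 1] / 2**mbits = 1.75
max_normal = 2**max_exp * max_mant                  # 16 * 1.75 == 28.0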
@@ -151,8 +237,10 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = tensor[offset : offset + size_ybit]
            offset += size_ybit

-             tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
-             tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
+             tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)  # undo Pass 3
+             tensor_ybit = (
+                 tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
+             )  # undo Pass 2

            tensor_ybit = _unpack(tensor_ybit.flatten(), y)
            tensor_ybit = tensor_ybit << (nbits - used_bits - y)
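End to end, the pack and unpack halves are meant to round-trip through the scaled helpers whose signatures appear in the surrounding hunk headers; a hedged sketch (the sizes are illustrative and chosen as multiples of 64 to match the tensor-core tiling):

weight = torch.randn(256, 64)
fp6_packed, scale = to_scaled_tc_fpx(weight, ebits=3, mbits=2)            # uint8 payload + per-row scale
restored = from_scaled_tc_fpx(fp6_packed, ebits=3, mbits=2, scale=scale)
assert restored.shape == weight.shape
# restored approximates weight up to fp6 quantization error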
@@ -223,7 +311,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 7,
        28672: 7,
-         57344: 7
+         57344: 7,
    },
    {  # tokens: [65:128]
        3072: 9,
@@ -234,7 +322,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 7,
        28672: 7,
-         57344: 6
+         57344: 6,
    },
    {  # tokens: [129:192]
        3072: 6,
@@ -245,7 +333,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 5,
        14336: 5,
        28672: 5,
-         57344: 4
+         57344: 4,
    },
    {  # tokens: [193:256]
        3072: 9,
@@ -256,7 +344,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 4,
        14336: 8,
        28672: 6,
-         57344: 4
+         57344: 4,
    },
    {  # tokens: [257:320]
        3072: 7,
@@ -267,7 +355,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 3,
        28672: 3,
-         57344: 4
+         57344: 4,
    },
    {  # tokens: [321:384]
        3072: 3,
@@ -278,7 +366,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 8,
        14336: 3,
        28672: 4,
-         57344: 3
+         57344: 3,
    },
    {  # tokens: [385:448]
        3072: 5,
@@ -289,7 +377,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 3,
        14336: 1,
        28672: 1,
-         57344: 3
+         57344: 3,
    },
    {  # tokens: [449:512]
        3072: 2,
@@ -300,7 +388,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 2,
        14336: 6,
        28672: 4,
-         57344: 1
+         57344: 1,
    },
    {  # tokens: [513:576]
        3072: 2,
@@ -311,7 +399,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 3,
        14336: 3,
        28672: 1,
-         57344: 1
+         57344: 1,
    },
    {  # tokens: [577:640]
        3072: 5,
@@ -322,7 +410,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 1,
        28672: 1,
-         57344: 1
+         57344: 1,
    },
    {  # tokens: [641:704]
        3072: 3,
@@ -333,7 +421,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 2,
        14336: 1,
        28672: 1,
-         57344: 1
+         57344: 1,
    },
    {  # tokens: [705:768]
        3072: 3,
@@ -344,20 +432,22 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te
        10240: 1,
        14336: 1,
        28672: 1,
-         57344: 1
-     }
+         57344: 1,
+     },
]
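This table reads as one dict per 64-token bucket (the buckets visible in these hunks run from [65:128] up to [705:768]), keyed by the weight's output dimension and giving a tuned split-K factor. A hedged sketch of that kind of lookup; the helper name and the fallback of 1 are illustrative only and not taken from this diff:

def _choose_split_k(split_k_map, n_tokens: int, out_dim: int) -> int:
    # split_k_map is the list of dicts above; each dict covers a 64-token bucket
    if n_tokens > len(split_k_map) * 64:
        return 1  # outside the tuned range: fall back to no split-K
    return split_k_map[(n_tokens - 1) // 64].get(out_dim, 1)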


# quantization api integrations

+
@dataclass(frozen=True)
class FpxTensorCoreLayoutType(LayoutType):
-     """Layout type for FpxTensorCoreAQTLayout
-     """
+     """Layout type for FpxTensorCoreAQTLayout"""
+
    ebits: int
    mbits: int

+
@register_layout_cls(FpxTensorCoreLayoutType)
class FpxTensorCoreAQTLayout(AQTLayout):
    """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b),
@@ -381,6 +471,7 @@ class FpxTensorCoreAQTLayout(AQTLayout):
    it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor
    FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit)
    """
+
    def __new__(
        cls,
        packed_fpx_data: torch.Tensor,
@@ -389,11 +480,16 @@ def __new__(
    ):
        assert packed_fpx_data.ndim == 2
        assert packed_fpx_data.dtype == torch.uint8
-         shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8)
+         shape = (
+             packed_fpx_data.shape[0],
+             packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8,
+         )
        kwargs = {}
        kwargs["device"] = packed_fpx_data.device
        kwargs["layout"] = (
-             kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout
+             kwargs.get("layout")
+             if kwargs.get("layout", False)
+             else packed_fpx_data.layout
        )
        kwargs["dtype"] = packed_fpx_data.dtype
        kwargs["requires_grad"] = False
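The shape arithmetic in __new__ simply inverts the packing ratio: every 8 logical sub-byte values occupy nbits = 1 + ebits + mbits stored bytes. A tiny standalone check with illustrative fp6 numbers:

nbits = 1 + 3 + 2                          # fp6: sign + exponent + mantissa bits
logical_cols = 64
packed_cols = logical_cols // 8 * nbits    # 48 uint8 columns actually stored
assert packed_cols // nbits * 8 == logical_cols  # what __new__ recovers for shape[1]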
@@ -416,12 +512,17 @@ def __tensor_flatten__(self):
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
-         packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"]
-         layout_type, = tensor_attributes
+         packed_fpx_data, scale = (
+             tensor_data_dict["packed_fpx_data"],
+             tensor_data_dict["scale"],
+         )
+         (layout_type,) = tensor_attributes
        return cls(packed_fpx_data, scale, layout_type)

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
-         unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits)
+         unpacked_fpx_data = unpack_tc_fpx(
+             self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits
+         )
        return unpacked_fpx_data, self.scale

    @classmethod
@@ -440,7 +541,9 @@ def from_plain(
        bit, M is mantissa bit
        """
        assert isinstance(layout_type, FpxTensorCoreLayoutType)
-         packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits)
+         packed_fpx_data = pack_tc_fpx(
+             unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits
+         )
        return cls(packed_fpx_data, scale, layout_type)
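A hedged round-trip sketch through this layout; the call shape of from_plain is assumed from the get_plain return values above (unpacked data plus scale), and the sizes are illustrative, kept at multiples of 64 for the tensor-core packing:

ebits, mbits = 3, 2                                   # fp6
layout_type = FpxTensorCoreLayoutType(ebits=ebits, mbits=mbits)
unpacked = torch.randint(0, 1 << (1 + ebits + mbits), (256, 64), dtype=torch.uint8)
scale = torch.ones(256)
fpx_layout = FpxTensorCoreAQTLayout.from_plain(unpacked, scale, layout_type)  # assumed signature
assert fpx_layout.packed_fpx_data.shape == (256, 64 // 8 * 6)  # (M, N // 8 * nbit)
restored, restored_scale = fpx_layout.get_plain()
assert restored.shape == (256, 64)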

    def __repr__(self):
@@ -478,7 +581,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            )
        elif func is aten._to_copy.default:
            return return_and_correct_aliasing(
-                 func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
+                 func,
+                 args,
+                 kwargs,
+                 args[0]._apply_fn_to_data(
+                     lambda x: x.to(device=kwargs.pop("device", None))
+                 ),
            )

        raise NotImplementedError(