1
1
import functools
2
2
from dataclasses import dataclass
3
+ import math
3
4
from typing import Dict , Tuple
4
5
5
6
import torch
6
7
import torch .nn .functional as F
7
- from torch import Tensor
8
8
9
9
10
10
aten = torch .ops .aten
15
15
16
16
NF4_OPS_TABLE : Dict [Any , Any ] = {}
17
17
18
+ # Note: Quantize in Chunks
19
+ # During quantization to NF4, one of the steps is to convert from the original float number
20
+ # to the index of the nearest value in the NF4 format. This can cause a large memory spike
21
+ # due to intermediates of the quantization process. Instead, we process the original
22
+ # tensor in chunks. This is a tradeoff between memory and speed. This number seems to
23
+ # strike a good balance between memory and speed
24
+ CHUNK_SIZE = 1024 ** 2
25
+
18
26
19
27
def same_metadata (a : "NF4Tensor" , b : "NF4Tensor" ):
20
28
both_nf4 = isinstance (a , NF4Tensor ) and isinstance (b , NF4Tensor )
@@ -375,7 +383,7 @@ def dequantize_scalers(
375
383
376
384
@staticmethod
377
385
def convert_to_norm_float_weight (
378
- inpt_tensor : torch .Tensor , n_blocks : int , block_size : int , nf4 : torch .tensor
386
+ inpt_tensor : torch .Tensor , n_blocks : int , block_size : int , nf4 : torch .Tensor
379
387
) -> torch .Tensor :
380
388
"""Convert a tensor to the normalized float weight format"""
381
389
flattened_tensor = inpt_tensor .flatten ()
@@ -393,9 +401,13 @@ def convert_to_norm_float_weight(
393
401
scaled_blocks = blocks / scales
394
402
395
403
# Returns a flattened tensor with each element quantized to nf4 index
396
- quantized_blocks = NF4Tensor .quantize_tensor_nearest (
397
- scaled_blocks .flatten (), nf4
398
- )
404
+ # See Note: Quantize in Chunks
405
+ quantized_blocks = torch .empty (numel , dtype = torch .uint8 , device = inpt_tensor .device )
406
+ flattened = scaled_blocks .flatten ()
407
+ for chunk_num in range (math .ceil (numel / CHUNK_SIZE )):
408
+ start = chunk_num * CHUNK_SIZE
409
+ end = min (start + CHUNK_SIZE , numel )
410
+ quantized_blocks [start :end ] = NF4Tensor .quantize_tensor_nearest (flattened [start :end ], nf4 ).to (torch .uint8 )
399
411
400
412
# Combine the quantized elements into uint8 values
401
413
# This lays out two consecutive elements in the same byte
@@ -435,7 +447,7 @@ def get_original_weight(self) -> torch.Tensor:
435
447
436
448
@staticmethod
437
449
def quantize_tensor_nearest (
438
- value : torch .float16 , nf4 : torch .Tensor
450
+ value : torch .Tensor , nf4 : torch .Tensor
439
451
) -> torch .Tensor :
440
452
"""Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
441
453
value = value .unsqueeze (- 1 ) # (numel, 1)
@@ -445,36 +457,15 @@ def quantize_tensor_nearest(
445
457
return closest_nf4
446
458
447
459
@staticmethod
448
-
449
- # inconsistently.
450
-
451
- # defined in `torch._C.TensorBase`.
452
460
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
    """Map NF4 code indices back to their entries in the ``nf4`` lookup table.

    Args:
        value: Integer tensor of indices into ``nf4``.
        nf4: Lookup table tensor, indexed along its first dimension.

    Returns:
        ``nf4[value]``; the result has the dtype of ``nf4``.
    """
    # PyTorch interprets a uint8 index tensor as a (deprecated) boolean mask
    # and rejects int8/int16 indices, so widen narrow integer indices to int64
    # to guarantee an element-wise table lookup.
    if value.dtype in (torch.uint8, torch.int8, torch.int16):
        value = value.to(torch.long)
    return nf4[value]
456
464
457
- def unpack (
458
- self ,
459
- ) -> Tuple [
460
- int , int , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Size
461
- ]:
462
-
463
- # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
464
- return (
465
- self .block_size ,
466
- self .n_blocks ,
467
- self .scaler_block_size ,
468
- self .quantized_scalers ,
469
- self .quantization_factor ,
470
- self .scaler_mean ,
471
- self .quantized_data ,
472
- )
473
-
474
- def __repr__ (self ):
465
+ def __repr__ (self ) -> str :
475
466
return f"Quantized Data: { self .quantized_data } \n Scalers: { self .quantized_scalers } \n "
476
467
477
- def __str__ (self ):
468
+ def __str__ (self ) -> str :
478
469
return f"NF4Tensor({ self .shape } , { self .block_size } )"
479
470
480
471
def __tensor_flatten__ (self ):
@@ -501,9 +492,6 @@ def __tensor_flatten__(self):
501
492
], ctx
502
493
503
494
@staticmethod
504
-
505
- # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
506
-
507
495
def __tensor_unflatten__ (inner_tensors : Dict , metadata , outer_size , outer_stride ):
508
496
assert len (inner_tensors ) == 5 , "Expected 5 inner tensors"
509
497
return NF4Tensor (
@@ -567,18 +555,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
567
555
568
556
class LinearNF4 (torch .autograd .Function ):
569
557
@staticmethod
570
-
571
- # inconsistently.
572
-
573
558
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
    """Run the linear layer and keep the NF4 weight on ``ctx`` for backward.

    Args:
        ctx: Autograd context object; receives the ``nf4_weight`` attribute.
        input: Activation tensor passed to ``F.linear``.
        weight: NF4 weight; converted with ``.to(input.dtype)`` before use.

    Returns:
        The result of ``F.linear`` on ``input`` and the converted weight.
    """
    # Stored as a plain attribute on ctx so backward can read it back.
    ctx.nf4_weight = weight
    # Cast to the activation dtype first (presumably this dequantizes the
    # NF4 data -- TODO confirm against NF4Tensor.to).
    weight_in_input_dtype = weight.to(input.dtype)
    return F.linear(input, weight_in_input_dtype)
577
562
578
563
@staticmethod
579
-
580
- # inconsistently.
581
-
582
564
def backward (ctx , grad_output ):
583
565
"""The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
584
566
weight : NF4Tensor = ctx .nf4_weight
0 commit comments