
Commit 20056b8

perform in chunks
1 parent 894312d commit 20056b8

File tree: 1 file changed (+20 −38)


torchao/dtypes/nf4tensor.py

Lines changed: 20 additions & 38 deletions
@@ -1,10 +1,10 @@
 import functools
 from dataclasses import dataclass
+import math
 from typing import Dict, Tuple
 
 import torch
 import torch.nn.functional as F
-from torch import Tensor
 
 
 aten = torch.ops.aten
@@ -15,6 +15,14 @@
 
 NF4_OPS_TABLE: Dict[Any, Any] = {}
 
+# Note: Quantize in Chunks
+# During quantization to NF4, one of the steps is to convert from the original float
+# value to the index of the nearest value in the NF4 format. This can cause a large
+# memory spike due to the intermediates of the quantization process. Instead, we
+# process the original tensor in chunks. This is a tradeoff between memory and speed;
+# this chunk size seems to strike a good balance between the two.
+CHUNK_SIZE = 1024**2
+
 
 def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
     both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
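
For context on the numbers in the note: the nearest-value lookup broadcasts every input element against all 16 NF4 code values, materializing a (numel, 16) intermediate. A back-of-the-envelope sketch of what chunking buys (assuming the intermediate lands in the input dtype, bfloat16 here; exact peaks depend on the allocator):

# Rough arithmetic only; the dtype and allocator behavior are assumptions.
CHUNK_SIZE = 1024**2
NF4_CODES = 16   # entries in the NF4 lookup table
BYTES = 2        # bfloat16

numel = 4096 * 4096  # a single 4096x4096 weight matrix
print(f"unchunked intermediate: {numel * NF4_CODES * BYTES / 2**30:.2f} GiB")      # 0.50 GiB
print(f"per-chunk intermediate: {CHUNK_SIZE * NF4_CODES * BYTES / 2**20:.0f} MiB")  # 32 MiB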
@@ -375,7 +383,7 @@ def dequantize_scalers(
 
     @staticmethod
     def convert_to_norm_float_weight(
-        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.tensor
+        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
         flattened_tensor = inpt_tensor.flatten()
@@ -393,9 +401,13 @@ def convert_to_norm_float_weight(
         scaled_blocks = blocks / scales
 
         # Returns a flattened tensor with each element quantized to nf4 index
-        quantized_blocks = NF4Tensor.quantize_tensor_nearest(
-            scaled_blocks.flatten(), nf4
-        )
+        # See Note: Quantize in Chunks
+        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device)
+        flattened = scaled_blocks.flatten()
+        for chunk_num in range(math.ceil(numel / CHUNK_SIZE)):
+            start = chunk_num * CHUNK_SIZE
+            end = min(start + CHUNK_SIZE, numel)
+            quantized_blocks[start:end] = NF4Tensor.quantize_tensor_nearest(flattened[start:end], nf4).to(torch.uint8)
 
         # Combine the quantized elements into uint8 values
         # This lays out two consecutive elements in the same byte
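
As a sanity check that chunking changes peak memory but not the result, here is a minimal standalone version of the loop above; the code table is a toy stand-in and quantize_nearest is a hypothetical substitute for NF4Tensor.quantize_tensor_nearest:

import math
import torch

CHUNK_SIZE = 1024**2
codes = torch.linspace(-1.0, 1.0, 16)  # toy stand-in for the NF4 table

def quantize_nearest(value, codes):
    # The (numel, 16) broadcasted intermediate is the memory spike being chunked away.
    return (value.unsqueeze(-1) - codes).abs().argmin(dim=-1)

x = torch.randn(3 * CHUNK_SIZE // 2)  # spans more than one chunk
out = torch.empty(x.numel(), dtype=torch.uint8)
for chunk_num in range(math.ceil(x.numel() / CHUNK_SIZE)):
    start = chunk_num * CHUNK_SIZE
    end = min(start + CHUNK_SIZE, x.numel())
    out[start:end] = quantize_nearest(x[start:end], codes).to(torch.uint8)

# Chunking only bounds the intermediate's size; the indices are identical.
assert torch.equal(out, quantize_nearest(x, codes).to(torch.uint8))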
@@ -435,7 +447,7 @@ def get_original_weight(self) -> torch.Tensor:
 
     @staticmethod
     def quantize_tensor_nearest(
-        value: torch.float16, nf4: torch.Tensor
+        value: torch.Tensor, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Quantize a float16 tensor to nf4 format to nearest and not rounded up"""
         value = value.unsqueeze(-1)  # (numel, 1)
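
The annotation fix is correct: the function receives a tensor of scaled values, not the scalar type torch.float16. To illustrate the "nearest and not rounded up" behavior, a tiny sketch assuming the elided body is the usual broadcast-and-argmin lookup:

import torch

codes = torch.tensor([-1.0, 0.0, 1.0])  # toy 3-entry table
vals = torch.tensor([0.4, 0.6, -0.51])
idx = (vals.unsqueeze(-1) - codes).abs().argmin(dim=-1)
print(idx)  # tensor([1, 2, 0]): 0.4 -> 0.0, 0.6 -> 1.0, -0.51 -> -1.0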
@@ -445,36 +457,15 @@
         return closest_nf4
 
     @staticmethod
-    # inconsistently.
-    # defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
         """Dequantize a nf4 value to bfloat16 format"""
         # return nf4.index_select(0, value)
         return nf4[value]
 
-    def unpack(
-        self,
-    ) -> Tuple[
-        int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Size
-    ]:
-        # Size]` but got `Tuple[int, int, int, Tensor, Tensor, Tensor, Tensor]`.
-        return (
-            self.block_size,
-            self.n_blocks,
-            self.scaler_block_size,
-            self.quantized_scalers,
-            self.quantization_factor,
-            self.scaler_mean,
-            self.quantized_data,
-        )
-
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"Quantized Data: {self.quantized_data}\nScalers: {self.quantized_scalers}\n"
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"NF4Tensor({self.shape}, {self.block_size})"
 
     def __tensor_flatten__(self):
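
Since dequantize is a plain table lookup (advanced indexing, equivalent to the commented-out index_select), a toy round trip through nearest-index quantization shows the expected bounded error; the uniform code table below is illustrative, not the real NF4 values:

import torch

codes = torch.linspace(-1.0, 1.0, 16)  # stand-in for the NF4 table
x = torch.randn(8).clamp(-1.0, 1.0)    # pretend: already-scaled block values
idx = (x.unsqueeze(-1) - codes).abs().argmin(dim=-1)
x_hat = codes[idx]                     # dequantize: nf4[value]
print((x - x_hat).abs().max())         # at most half a code step, (2/15)/2 ~ 0.067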
@@ -501,9 +492,6 @@ def __tensor_flatten__(self):
         ], ctx
 
     @staticmethod
-    # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
     def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 5, "Expected 5 inner tensors"
         return NF4Tensor(
@@ -567,18 +555,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
 class LinearNF4(torch.autograd.Function):
     @staticmethod
-    # inconsistently.
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
         return F.linear(input, weight.to(input.dtype))
 
     @staticmethod
-    # inconsistently.
     def backward(ctx, grad_output):
         """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)"""
         weight: NF4Tensor = ctx.nf4_weight
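
For reference, the Function is used like any other custom autograd op. A hedged usage sketch; the NF4Tensor.from_tensor construction is an assumption, since its exact signature varies across torchao versions:

import torch
from torchao.dtypes.nf4tensor import NF4Tensor, LinearNF4

# Assumption: from_tensor(tensor, block_size, scaler_block_size) constructor.
weight = NF4Tensor.from_tensor(torch.randn(128, 64, dtype=torch.bfloat16), 64, 256)
x = torch.randn(4, 64, dtype=torch.bfloat16, requires_grad=True)
out = LinearNF4.apply(x, weight)  # forward dequantizes the weight on the fly
out.sum().backward()              # backward only produces a grad for the input
print(x.grad.shape)               # torch.Size([4, 64])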
