Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Apr 20, 2024
1 parent d656b93 commit 5c4fe2b
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 48 deletions.
92 changes: 92 additions & 0 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io
from collections import OrderedDict
import torchao
from typing import Tuple, Union


bnb_available = False
Expand Down Expand Up @@ -222,7 +223,98 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


class TestFSDPOps(TestCase):
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]):
num_chunks = 2
nf4_tensor = to_nf4(torch.randn(input_size))
chunks = list(torch.chunk(nf4_tensor, num_chunks))
self.assertEqual(len(chunks), num_chunks)
if isinstance(input_size, int):
expected_size0 = input_size // num_chunks
else:
expected_size0 = input_size[0] // num_chunks
for chunk in chunks:
self.assertEqual(chunk.size(0), expected_size0)

@parametrize("input_size", [511 * 512, (511 * 512,), (511, 512), (512, 512, 512)])
def test_torch_chunk_invalid(self, input_size: Union[Tuple[int], int]):
num_chunks = 2
with self.assertRaises(AssertionError):
nf4_tensor = to_nf4(torch.randn(input_size))
torch.chunk(nf4_tensor, num_chunks)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
nf4_tensor_zeros = nf4_tensor.new_zeros(input_size)
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
inner_tensor = getattr(nf4_tensor_zeros, attr)
self.assertEqual(torch.count_nonzero(inner_tensor), 0)
expected_size = input_size if not isinstance(input_size, int) else (input_size, )
self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size))

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
if isinstance(input_size, int):
new_size = input_size + 1
elif len(input_size) == 1:
new_size = (input_size[0] + 1, )
else:
new_size = (input_size[0] + 1, input_size[1])
nf4_tensor = to_nf4(torch.randn(input_size))
with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\(NF4Tensor\) with new size"):
nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
end_idx = input_size if isinstance(input_size, int) else input_size[0]
sliced_tensor = nf4_tensor[:end_idx]
self.assertEqual(nf4_tensor.size(), sliced_tensor.size())
attrs, _ = sliced_tensor.__tensor_flatten__()
for attr in attrs:
orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
self.assertEqual(getattr(sliced_tensor, attr).untyped_storage().data_ptr(), orig_storage)

def test_tensor_slice_1d_invalid(self):
nf4_tensor = to_nf4(torch.randn(512 * 512))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with step"):
nf4_tensor[..., ::2]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end "):
nf4_tensor[:2]

def test_tensor_slice_2d_invalid(self):
nf4_tensor = to_nf4(torch.randn((512, 512)))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with dim"):
nf4_tensor[:, :511]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
nf4_tensor[1:]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end"):
nf4_tensor[:2]

@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
viewed_tensor = nf4_tensor.view(-1)
self.asssertEqual(viewed_tensor.dim(), 1)
self.asssertEqual(viewed_tensor.numel(), math.prod(input_size))
attrs, _ = sliced_tensor.__tensor_flatten__()
for attr in attrs:
orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
inner_tensor = getattr(sliced_tensor, attr)
self.asssertEqual(inner_tensor.dim(), 1)
self.assertEqual(inner_tensor.untyped_storage().data_ptr(), orig_storage)


# def test_tensor_as_strided(self):
# pass


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)

if __name__ == "__main__":
run_tests()
120 changes: 72 additions & 48 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Dict, Tuple
import math
import sys

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -82,36 +83,46 @@ def nf4_detach(aten_op, args, kwargs=None):
]
)
def nf4_split(aten_op, args, kwargs=None):
assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args"
# TODO: assert on dim-0 sharding. how to get dim from torch.chunk?
num_chunks = args[0].size(0) // args[1]

assert args[0].quantized_scalers.numel() % num_chunks == 0, f"NF4Tensor.quantized_scalers.numel() not divisible by {num_chunks}"
quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs)
assert args[0].quantization_factor.numel() % num_chunks == 0, f"NF4Tensor.quantization_factor.numel() not divisible by {num_chunks}"
quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs)
assert args[0].quantized_data.numel() % num_chunks == 0, f"NF4Tensor.quantized_data.numel() not divisible by {num_chunks}"
quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs)

assert len(args) == 2, "only support 2d because of tensor meta"
if len(args) == 3 and args[2] != 0:
raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})")
outer_wrapper = args[0]
num_chunks = outer_wrapper.size(0) // args[1]

assert outer_wrapper.quantized_scalers.numel() % num_chunks == 0, f"NF4Tensor.quantized_scalers.numel() not divisible by {num_chunks}"
quantized_scalers_chunks = aten_op(outer_wrapper.quantized_scalers, outer_wrapper.quantized_scalers.numel() // num_chunks, **kwargs)
assert outer_wrapper.quantization_factor.numel() % num_chunks == 0, f"NF4Tensor.quantization_factor.numel() not divisible by {num_chunks}"
quantization_factor_chunks = aten_op(outer_wrapper.quantization_factor, outer_wrapper.quantization_factor.numel() // num_chunks, **kwargs)
assert outer_wrapper.quantized_data.numel() % num_chunks == 0, f"NF4Tensor.quantized_data.numel() not divisible by {num_chunks}"
quantized_data_chunks = aten_op(outer_wrapper.quantized_data, outer_wrapper.quantized_data.numel() // num_chunks, **kwargs)

orig_dim = outer_wrapper.dim()

if orig_dim == 1:
chunked_size = (outer_wrapper.size(0) // num_chunks, )
elif orig_dim == 2:
chunked_size = (outer_wrapper.size(0) // num_chunks, outer_wrapper.size(1))
else:
chunked_size = ()
raise NotImplementedError(f"aten.split(NF4Tensor) wherer NF4Tensor.dim() = {orig_dim}")

return [
NF4Tensor(
SubclassTensorArgs(
(args[0].size(0) // num_chunks, args[0].size(1)),
args[0].stride(),
args[0].storage_offset(),
args[0].dtype,
args[0].device,
args[0].requires_grad,
chunked_size,
outer_wrapper.stride(),
outer_wrapper.storage_offset(),
outer_wrapper.dtype,
outer_wrapper.device,
outer_wrapper.requires_grad,
),
args[0].block_size,
args[0].n_blocks,
args[0].scaler_block_size,
outer_wrapper.block_size,
outer_wrapper.n_blocks,
outer_wrapper.scaler_block_size,
quantized_scalers,
quantization_factor,
args[0].scaler_mean,
outer_wrapper.scaler_mean,
quantized_data,
args[0].nf4,
outer_wrapper.nf4,
) for quantized_scalers, quantization_factor, quantized_data in zip(
quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks
)
Expand All @@ -123,37 +134,44 @@ def nf4_split(aten_op, args, kwargs=None):
]
)
def nf4_new_zeros(aten_op, args, kwargs=None):
assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D"
assert args[0].numel() % math.prod(args[1]) == 0
ratio = args[0].numel() // math.prod(args[1])
outer_wrapper = args[0]
new_size = args[1]
new_size_dim = len(new_size)
if (not new_size_dim in [1, 2]) or outer_wrapper.numel() % math.prod(new_size) != 0:
raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
ratio = outer_wrapper.numel() // math.prod(new_size)

assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}"
quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs)
assert outer_wrapper.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}"
quantized_scalers_new_zeros = aten_op(outer_wrapper.quantized_scalers, [outer_wrapper.quantized_scalers.size(0) // ratio], **kwargs)

assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}"
quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs)
assert outer_wrapper.quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}"
quantization_factor_new_zeros = aten_op(outer_wrapper.quantization_factor, [outer_wrapper.quantization_factor.size(0) // ratio], **kwargs)

assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}"
quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs)
assert outer_wrapper.quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}"
quantized_data_new_zeros = aten_op(outer_wrapper.quantized_data, [outer_wrapper.quantized_data.size(0) // ratio], **kwargs)

if new_size_dim == 1:
new_size = (new_size[0], )
elif new_size_dim == 2:
new_size = (new_size[0], new_size[1])

return NF4Tensor(
SubclassTensorArgs(
(args[1][0], args[1][1]),
args[0].stride(),
args[0].storage_offset(),
args[0].dtype,
args[0].device,
args[0].requires_grad,
new_size,
outer_wrapper.stride(),
outer_wrapper.storage_offset(),
outer_wrapper.dtype,
outer_wrapper.device,
outer_wrapper.requires_grad,
),
args[0].block_size,
args[0].n_blocks,
args[0].scaler_block_size,
outer_wrapper.block_size,
outer_wrapper.n_blocks,
outer_wrapper.scaler_block_size,
quantized_scalers_new_zeros,
quantization_factor_new_zeros,
args[0].scaler_mean,
outer_wrapper.scaler_mean,
quantized_data_new_zeros,
args[0].nf4,
outer_wrapper.nf4,
)

@implements(
Expand All @@ -162,10 +180,16 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
]
)
def nf4_slice(aten_op, args, kwargs=None):
assert len(args) == 4
assert args[1] == 0, f"only support dim=0 but got dim={args[1]}"
assert args[2] == 0, f"only support start=0 but got start={args[2]}"
assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}"
if len(args) == 5:
raise NotImplementedError(f"aten.slice(NF4Tensor) with step={args[4]}")
if not args[1] == 0:
raise NotImplementedError(f"aten.slice(NF4Tensor) with dim={args[1]}")
if not args[2] == 0:
raise NotImplementedError(f"aten.slice(NF4Tensor) with start={args[2]}")
# for tensor 512 x 512, tensor[:, :512] dispatch to
# aten.slice(dim = 0, end=sys.maxsize)
if not args[3] in [args[0].size(0), sys.maxsize]:
raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}")
return NF4Tensor(
SubclassTensorArgs(
args[0].size(),
Expand All @@ -191,7 +215,7 @@ def nf4_slice(aten_op, args, kwargs=None):
]
)
def nf4_view(aten_op, args, kwargs=None):
assert len(args) == 2, args[1] == -1
assert len(args) == 2 and len(args[1]) == 1 and args[1][0] == -1
quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs)
quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs)
quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
Expand Down

0 comments on commit 5c4fe2b

Please sign in to comment.