diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index e3b25e3c3..8ed9709d3 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -15,6 +15,7 @@ import io from collections import OrderedDict import torchao +from typing import Tuple, Union bnb_available = False @@ -222,7 +223,98 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype): out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) +class TestFSDPOps(TestCase): + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + nf4_tensor = to_nf4(torch.randn(input_size)) + chunks = list(torch.chunk(nf4_tensor, num_chunks)) + self.assertEqual(len(chunks), num_chunks) + if isinstance(input_size, int): + expected_size0 = input_size // num_chunks + else: + expected_size0 = input_size[0] // num_chunks + for chunk in chunks: + self.assertEqual(chunk.size(0), expected_size0) + + @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512), (512, 512, 512)]) + def test_torch_chunk_invalid(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + with self.assertRaises(AssertionError): + nf4_tensor = to_nf4(torch.randn(input_size)) + torch.chunk(nf4_tensor, num_chunks) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + nf4_tensor_zeros = nf4_tensor.new_zeros(input_size) + for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: + inner_tensor = getattr(nf4_tensor_zeros, attr) + self.assertEqual(torch.count_nonzero(inner_tensor), 0) + expected_size = input_size if not isinstance(input_size, int) else (input_size, ) + self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size)) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_invalid(self, 
input_size: Union[Tuple[int], int]): + if isinstance(input_size, int): + new_size = input_size + 1 + elif len(input_size) == 1: + new_size = (input_size[0] + 1, ) + else: + new_size = (input_size[0] + 1, input_size[1]) + nf4_tensor = to_nf4(torch.randn(input_size)) + with self.assertRaisesRegex(NotImplementedError, r"aten.new_zeros\(NF4Tensor\) with new size"): + nf4_tensor_zeros = nf4_tensor.new_zeros(new_size) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + end_idx = input_size if isinstance(input_size, int) else input_size[0] + sliced_tensor = nf4_tensor[:end_idx] + self.assertEqual(nf4_tensor.size(), sliced_tensor.size()) + attrs, _ = sliced_tensor.__tensor_flatten__() + for attr in attrs: + orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr() + self.assertEqual(getattr(sliced_tensor, attr).untyped_storage().data_ptr(), orig_storage) + + def test_tensor_slice_1d_invalid(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with step"): + nf4_tensor[..., ::2] + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with end"): + nf4_tensor[:2] + + def test_tensor_slice_2d_invalid(self): + nf4_tensor = to_nf4(torch.randn((512, 512))) + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with dim"): + nf4_tensor[:, :511] + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, r"aten.slice\(NF4Tensor\) with end"): + nf4_tensor[:2] + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = 
to_nf4(torch.randn(input_size)) + viewed_tensor = nf4_tensor.view(-1) + self.assertEqual(viewed_tensor.dim(), 1) + self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) + attrs, _ = viewed_tensor.__tensor_flatten__() + for attr in attrs: + orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr() + inner_tensor = getattr(viewed_tensor, attr) + self.assertEqual(inner_tensor.dim(), 1) + self.assertEqual(inner_tensor.untyped_storage().data_ptr(), orig_storage) + + + # def test_tensor_as_strided(self): + # pass + + instantiate_parametrized_tests(TestNF4Linear) +instantiate_parametrized_tests(TestFSDPOps) if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index a74106802..500360659 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Dict, Tuple import math +import sys import torch import torch.nn.functional as F @@ -82,36 +83,46 @@ def nf4_detach(aten_op, args, kwargs=None): ] ) def nf4_split(aten_op, args, kwargs=None): - assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args" - # TODO: assert on dim-0 sharding. how to get dim from torch.chunk? 
- num_chunks = args[0].size(0) // args[1] - - assert args[0].quantized_scalers.numel() % num_chunks == 0, f"NF4Tensor.quantized_scalers.numel() not divisible by {num_chunks}" - quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs) - assert args[0].quantization_factor.numel() % num_chunks == 0, f"NF4Tensor.quantization_factor.numel() not divisible by {num_chunks}" - quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs) - assert args[0].quantized_data.numel() % num_chunks == 0, f"NF4Tensor.quantized_data.numel() not divisible by {num_chunks}" - quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs) - - assert len(args) == 2, "only support 2d because of tensor meta" + if len(args) == 3 and args[2] != 0: + raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})") + outer_wrapper = args[0] + num_chunks = outer_wrapper.size(0) // args[1] + + assert outer_wrapper.quantized_scalers.numel() % num_chunks == 0, f"NF4Tensor.quantized_scalers.numel() not divisible by {num_chunks}" + quantized_scalers_chunks = aten_op(outer_wrapper.quantized_scalers, outer_wrapper.quantized_scalers.numel() // num_chunks, **kwargs) + assert outer_wrapper.quantization_factor.numel() % num_chunks == 0, f"NF4Tensor.quantization_factor.numel() not divisible by {num_chunks}" + quantization_factor_chunks = aten_op(outer_wrapper.quantization_factor, outer_wrapper.quantization_factor.numel() // num_chunks, **kwargs) + assert outer_wrapper.quantized_data.numel() % num_chunks == 0, f"NF4Tensor.quantized_data.numel() not divisible by {num_chunks}" + quantized_data_chunks = aten_op(outer_wrapper.quantized_data, outer_wrapper.quantized_data.numel() // num_chunks, **kwargs) + + orig_dim = outer_wrapper.dim() + + if orig_dim == 1: + chunked_size = (outer_wrapper.size(0) // num_chunks, ) + elif orig_dim == 2: 
+ chunked_size = (outer_wrapper.size(0) // num_chunks, outer_wrapper.size(1)) + else: + chunked_size = () + raise NotImplementedError(f"aten.split(NF4Tensor) where NF4Tensor.dim() = {orig_dim}") + return [ NF4Tensor( SubclassTensorArgs( - (args[0].size(0) // num_chunks, args[0].size(1)), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, + chunked_size, + outer_wrapper.stride(), + outer_wrapper.storage_offset(), + outer_wrapper.dtype, + outer_wrapper.device, + outer_wrapper.requires_grad, ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, + outer_wrapper.block_size, + outer_wrapper.n_blocks, + outer_wrapper.scaler_block_size, quantized_scalers, quantization_factor, - args[0].scaler_mean, + outer_wrapper.scaler_mean, quantized_data, - args[0].nf4, + outer_wrapper.nf4, ) for quantized_scalers, quantization_factor, quantized_data in zip( quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks ) @@ -123,37 +134,44 @@ def nf4_split(aten_op, args, kwargs=None): ] ) def nf4_new_zeros(aten_op, args, kwargs=None): - assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D" - assert args[0].numel() % math.prod(args[1]) == 0 - ratio = args[0].numel() // math.prod(args[1]) + outer_wrapper = args[0] + new_size = args[1] + new_size_dim = len(new_size) + if (not new_size_dim in [1, 2]) or outer_wrapper.numel() % math.prod(new_size) != 0: + raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") + ratio = outer_wrapper.numel() // math.prod(new_size) - assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" - quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs) + assert outer_wrapper.quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" + quantized_scalers_new_zeros 
= aten_op(outer_wrapper.quantized_scalers, [outer_wrapper.quantized_scalers.size(0) // ratio], **kwargs) - assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" - quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs) + assert outer_wrapper.quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" + quantization_factor_new_zeros = aten_op(outer_wrapper.quantization_factor, [outer_wrapper.quantization_factor.size(0) // ratio], **kwargs) - assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}" - quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs) + assert outer_wrapper.quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}" + quantized_data_new_zeros = aten_op(outer_wrapper.quantized_data, [outer_wrapper.quantized_data.size(0) // ratio], **kwargs) + if new_size_dim == 1: + new_size = (new_size[0], ) + elif new_size_dim == 2: + new_size = (new_size[0], new_size[1]) return NF4Tensor( SubclassTensorArgs( - (args[1][0], args[1][1]), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, + new_size, + outer_wrapper.stride(), + outer_wrapper.storage_offset(), + outer_wrapper.dtype, + outer_wrapper.device, + outer_wrapper.requires_grad, ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, + outer_wrapper.block_size, + outer_wrapper.n_blocks, + outer_wrapper.scaler_block_size, quantized_scalers_new_zeros, quantization_factor_new_zeros, - args[0].scaler_mean, + outer_wrapper.scaler_mean, quantized_data_new_zeros, - args[0].nf4, + outer_wrapper.nf4, ) @implements( @@ -162,10 +180,16 @@ def nf4_new_zeros(aten_op, args, kwargs=None): ] ) def nf4_slice(aten_op, args, kwargs=None): - 
assert len(args) == 4 - assert args[1] == 0, f"only support dim=0 but got dim={args[1]}" - assert args[2] == 0, f"only support start=0 but got start={args[2]}" - assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}" + if len(args) == 5: + raise NotImplementedError(f"aten.slice(NF4Tensor) with step={args[4]}") + if not args[1] == 0: + raise NotImplementedError(f"aten.slice(NF4Tensor) with dim={args[1]}") + if not args[2] == 0: + raise NotImplementedError(f"aten.slice(NF4Tensor) with start={args[2]}") + # for tensor 512 x 512, tensor[:, :512] dispatch to + # aten.slice(dim = 0, end=sys.maxsize) + if not args[3] in [args[0].size(0), sys.maxsize]: + raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}") return NF4Tensor( SubclassTensorArgs( args[0].size(), @@ -191,7 +215,7 @@ def nf4_slice(aten_op, args, kwargs=None): ] ) def nf4_view(aten_op, args, kwargs=None): - assert len(args) == 2, args[1] == -1 + assert len(args) == 2 and len(args[1]) == 1 and args[1][0] == -1 quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs) quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs) quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)