Supporting tensor parallelism for int8 weight only quant (#939)
* [WIP] Supporting tensor parallelism for int8 weight only quant

Summary:
Following https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/tensor_parallel.py,
we can support tensor parallelism for int8 weight-only quant; this is needed
for torchchat.
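
Example (a minimal sketch modeled on the new test below, not a public API:
it assumes a 2-GPU torchrun launch and does the DTensor sharding by hand,
omitting process-group/device boilerplate):

    import torch
    from torch.distributed._tensor import DTensor, Replicate, Shard
    from torch.distributed.device_mesh import init_device_mesh
    from torchao.quantization import quantize_, int8_weight_only

    # torchrun --nproc_per_node=2 example.py
    mesh = init_device_mesh("cuda", (2,))
    rank = mesh.get_local_rank()

    # int8 weight-only quantize a plain linear
    m = torch.nn.Linear(1024, 2048, bias=False, device="cuda")
    quantize_(m, int8_weight_only())

    # column-wise shard: slice the quantized weight along dim 0
    # (this commit adds the aten.slice/aten.view support that makes
    # DTensor.from_local work on AffineQuantizedTensor)
    n_local = m.weight.size(0) // mesh.size()
    local = m.weight[rank * n_local : (rank + 1) * n_local, :]
    m.weight = torch.nn.Parameter(
        DTensor.from_local(local, mesh, [Shard(0)]), requires_grad=False
    )

    # replicated input -> output sharded on the feature dim
    x = DTensor.from_local(torch.randn(128, 1024, device="cuda"), mesh, [Replicate()])
    y = m(x)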

Test Plan:
python test/dtypes/test_affine_quantized_tensor_parallel.py

Reviewers:

Subscribers:

Tasks:

Tags:

* implement tp for aqt

* fixes

* import fix

* remove cpu test

* fix

* fix

* fix test

* device

* change transpose impl

* Skip compiled TP test for torch version < 2.5

* version util

* fix

* fix version

---------

Co-authored-by: Ke Wen <kw2501@meta.com>
jerryzh168 and kwen2501 committed Sep 27, 2024
1 parent 63cb7a9 commit 72d2518
Showing 8 changed files with 210 additions and 32 deletions.
1 change: 1 addition & 0 deletions test/dtypes/test_affine_quantized.py
@@ -135,5 +135,6 @@ def test_print_quantized_module(self, apply_quant):

common_utils.instantiate_parametrized_tests(TestAffineQuantized)


if __name__ == "__main__":
    run_tests()
12 changes: 12 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -0,0 +1,12 @@
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torchao.quantization import int8_weight_only

class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
    pass


copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")

if __name__ == "__main__":
    run_tests()
2 changes: 2 additions & 0 deletions torchao/__init__.py
@@ -33,11 +33,13 @@
    quantize_,
)
from . import dtypes
from . import testing

__all__ = [
    "dtypes",
    "autoquant",
    "quantize_",
    "testing",
]

# test-pytorchbot
52 changes: 49 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -38,7 +38,8 @@
    find_multiple,
    TorchAOBaseTensor,
    TORCH_VERSION_AT_LEAST_2_5,
-    _is_float8_type
+    _is_float8_type,
+    fill_defaults,
)
import logging

@@ -603,13 +604,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

-        if func is aten.t.default:
+        elif func is aten.t.default:
            tensor = args[0]
            new = tensor.__class__(
-                tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type
+                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type
            )
            return return_and_correct_aliasing(func, args, kwargs, new)

+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+
        raise NotImplementedError(
            f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
        )
@@ -1776,6 +1789,39 @@ def _(func, types, args, kwargs):
    )
    return return_and_correct_aliasing(func, args, kwargs, new)

@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    assert step == 1
    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
    if end >= self.shape[dim]:
        end = self.shape[dim]
    shape = list(self.shape)
    shape[dim] = end - start
    block_size = self.block_size
    assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
    # with slice, some shape dimension might be smaller than block_size dimension, so
    # we need to make sure there is no overflow
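    # (e.g. int8 weight-only uses per-row block_size (1, N); slicing a
    # (2048, 1024) weight down to (2048, 512) on dim 1 clamps block_size to (1, 512))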
    block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
    new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
    return return_and_correct_aliasing(func, args, kwargs, new)


# this is needed for DTensor.from_local() and for flattening tensor
@implements(aten.view.default)
def _(func, types, args, kwargs):
    self, shape = args

    if tuple(self.shape) == tuple(shape):
        return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())

    if len(shape) == 1 and shape[0] == -1:
        assert len(self.block_size) == 2 and self.block_size[0] == 1
        block_size = (self.block_size[1],)
        return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())

    raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]")

to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
Expand Down
115 changes: 115 additions & 0 deletions torchao/testing/utils.py
@@ -3,11 +3,14 @@
import copy
import torch
import torchao
import os

from torch.testing._internal import common_utils
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

"""
How to use:
@@ -213,10 +216,122 @@ def test_linear_compile(self, device, dtype):
        lp_res = torch.compile(l)(hp_act_tensor)
        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
    NUM_DEVICES,
)

class TorchAOTensorParallelTestCase(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

TENSOR_SUBCLASS = AffineQuantizedTensor
QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_KWARGS = {}

    @staticmethod
    def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
        """
        Shard linear layer of the model in column-wise fashion
        """
        # Column-wise is w.r.t. A^T, so for A it is row-wise.
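        # (nn.Linear computes y = x @ A^T, so sharding A along dim 0
        # shards the output features of y)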
        # Number of rows per rank
        orig_weight = m.linear.weight
        n_local_rows = orig_weight.size(0) // mesh.size()
        rank = mesh.get_local_rank()
        local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
        # Construct DTensor from local shard
        dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
        # Replace parameter in module
        m.linear.weight = torch.nn.Parameter(
            dtensor, requires_grad=False
        )
        return m

    @staticmethod
    def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
        """
        Shard linear layer of the model in row-wise fashion
        """
        # Row-wise is w.r.t. A^T, so for A it is column-wise.
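        # (nn.Linear computes y = x @ A^T, so sharding A along dim 1
        # shards the reduction dim; each rank computes a partial sum of y)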
        # Number of columns per rank
        orig_weight = m.linear.weight
        n_local_cols = orig_weight.size(1) // mesh.size()
        rank = mesh.get_local_rank()
        local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
        # Construct DTensor from local shard
        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
        # Replace parameter in module
        m.linear.weight = torch.nn.Parameter(
            dtensor, requires_grad=False
        )
        return m

    def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
        """
        Quantize the model
        """
        quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
        return m

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tp(self, dtype):
        device = "cuda"
        # To make sure different ranks create the same module
        torch.manual_seed(5)

        class M(torch.nn.Module):
            def __init__(self, in_features, out_features, **kwargs) -> None:
                super().__init__(**kwargs)
                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

        # Get rank and device
        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")

        # Original model
        proj_up = M(1024, 2048).to(device).to(dtype)
        proj_dn = M(2048, 1024).to(device).to(dtype)
        example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
        y = proj_dn(proj_up(example_input))

        # Quantize the model
        up_quant = self.quantize(proj_up)
        dn_quant = self.quantize(proj_dn)
        y_q = dn_quant(up_quant(example_input))

        mesh = self.build_device_mesh()
        # Shard the models
        up_dist = self.colwise_shard(up_quant, mesh)
        dn_dist = self.rowwise_shard(dn_quant, mesh)

        # We need to turn inputs into DTensor form as well -- just a format change
        input_dtensor = DTensor.from_local(
            example_input, mesh, [Replicate()]
        )

        y_d = dn_dist(up_dist(input_dtensor))

        if not TORCH_VERSION_AT_LEAST_2_5:
            # Need torch 2.5 to support compiled tensor parallelism
            return

        up_compiled = torch.compile(up_dist)
        y_up = up_compiled(input_dtensor)
        dn_compiled = torch.compile(dn_dist)
        y_dn = dn_compiled(y_up)

common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase)

if __name__ == "__main__":
    unittest.main()
24 changes: 24 additions & 0 deletions torchao/utils.py
@@ -493,6 +493,30 @@ def _get_to_kwargs(self, *args, **kwargs):
        }
        return kwargs

def fill_defaults(args, n, defaults_tail):
    """
    __torch_dispatch__ doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed); but usually it is
    convenient to pad out the arguments list with defaults.  This function
    helps you do that.
    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list
    Example:
        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]
    """
    if n - len(defaults_tail) > len(args):
        raise RuntimeError("not enough defaults to fill arguments")
    r = list(args)
    for i in range(len(args), n):
        r.append(defaults_tail[i - n + len(defaults_tail)])
    return r


## Deprecated, will be deleted in the future
def _torch_version_at_least(min_version):
33 changes: 5 additions & 28 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -25,36 +25,13 @@
    LayoutType,
    PlainLayoutType,
)
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import (
+    TorchAOBaseTensor,
+    fill_defaults,
+)

aten = torch.ops.aten

-# TODO: move to torchao/utils.py
-def fill_defaults(args, n, defaults_tail):
-    ...  # (entire definition removed; fill_defaults now lives in torchao/utils.py, see above)


###############################
# Base Layout Tensor Subclass #
###############################
@@ -327,7 +304,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
                )
            elif dim == 1:
-                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type)
            else:
                raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
        elif func is aten.t.default:
3 changes: 2 additions & 1 deletion tutorials/developer_api_guide/tensor_parallel.py
@@ -5,7 +5,8 @@
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
from torch.utils._python_dispatch import return_and_correct_aliasing
-from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
+from my_dtype_tensor_subclass import MyDTypeTensor
+from torchao.utils import fill_defaults

# a tensor subclass that supports tensor parallelism with DTensor
class MyDTypeTensorTP(MyDTypeTensor):
