Adding example for quantized tensor + tensor parallelism (#785)

* [WIP] Adding example for quantized tensor + tensor parallelism

Summary:
This PR adds an example of how a quantized tensor subclass can work with DTensor: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md

The end goal is to rewrite https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py with a plain llama2 implementation and to show that, with DTensor + AffineQuantizedTensor + torch.compile, we can get on-par performance with the custom tensor parallel implementation.
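
To make the summary concrete, here is a minimal sketch of the idea. It is not the PR's tensor_parallel.py (whose diff is not shown below); the helper name shard_colwise, the 128x1024 weight size, and the CPU/gloo setup are illustrative assumptions.

```python
# Minimal sketch: column-wise tensor parallelism for a Linear weight by slicing
# locally and wrapping the shard in a DTensor. Illustrative only.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh


def shard_colwise(weight: torch.Tensor, mesh) -> DTensor:
    # Column-wise TP shards a Linear weight along dim 0 (out_features).
    # `weight` is a plain tensor here; in the tutorial it would be a
    # MyDTypeTensor, which needs aten.slice.Tensor support like the kind added
    # in this PR for the indexing below to work.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    n_local = weight.shape[0] // world_size
    local = weight[rank * n_local : (rank + 1) * n_local, :]
    # Wrap the already-built local shard instead of calling distribute_tensor,
    # mirroring the "Use DTensor.from instead of distribute_tensor" bullet below.
    return DTensor.from_local(local, mesh, [Shard(0)], run_check=False)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # gloo so the sketch runs on CPU
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    sharded = shard_colwise(torch.randn(128, 1024), mesh)
    print(sharded.placements, sharded.to_local().shape)
```

Launched the same way as the test plan below (e.g. torchrun --standalone --nnodes=1 --nproc-per-node=4 on whatever file holds the sketch), each rank ends up holding one Shard(0) piece of the weight.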

Test Plan:
torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py

Reviewers:

Subscribers:

Tasks:

Tags:

* tensor parallel file

* Use DTensor.from instead of distribute_tensor (see the contrast sketch after this list)

* implementing aten.slice.Tensor (WIP)

* working

* some shape fix and use more quant primitive ops

* Add rowwise test

* make rowwise sharding work

* compile still not working yet

* fake tensor didn't pick up shape changes from transpose

* backend='eager'

* change transpose to non-inplace op

* add error message

* works now with torch nightly

* remove print

* ruff

* Clean up

* Fix device id
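
For context on the "Use DTensor.from instead of distribute_tensor" item, here is a small contrast sketch; it is not code from this PR, and the mesh setup and tensor sizes are illustrative.

```python
# Contrast sketch: two ways to obtain a DTensor, with illustrative sizes.
# Run under torchrun so the process-group environment variables are set.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# distribute_tensor: hand over the full weight and let DTensor chunk and
# scatter it across the mesh.
full_weight = torch.randn(128, 1024)
dt_scattered = distribute_tensor(full_weight, mesh, [Shard(0)])

# DTensor.from_local: build the local shard yourself (for example by slicing an
# already-quantized subclass tensor) and only record the layout on the wrapper.
local_shard = torch.randn(128 // dist.get_world_size(), 1024)
dt_wrapped = DTensor.from_local(local_shard, mesh, [Shard(0)], run_check=False)

print(dt_scattered.placements, dt_wrapped.to_local().shape)
```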

---------

Co-authored-by: Ke Wen <kw2501@meta.com>
jerryzh168 and kwen2501 authored Sep 23, 2024
1 parent 1d6f8e2 commit 9680c48
Showing 3 changed files with 294 additions and 35 deletions.
132 changes: 100 additions & 32 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -15,7 +15,12 @@
import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
MappingType,
quantize_affine,
dequantize_affine,
)
from torchao.dtypes.utils import (
LayoutType,
PlainLayoutType,
@@ -24,6 +29,32 @@

aten = torch.ops.aten

# TODO: move to torchao/utils.py
def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
[1, 2, 3, None, None]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r


###############################
# Base Layout Tensor Subclass #
###############################
@@ -140,10 +171,10 @@ def from_float(
layout_type: LayoutType = PlainLayoutType(),
):
mapping_type = MappingType.SYMMETRIC
block_size = input_float.shape
block_size = (1, input_float.shape[-1])
dtype = torch.int16
scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
int_data = (input_float / scale).to(torch.int8)
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
if output_dtype is None:
output_dtype = torch.get_default_dtype()
int_data, scale = self.layout_tensor.get_plain()
return int_data.to(output_dtype) * scale
transposed = False
block_size = (1, int_data.shape[-1])
if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
transposed = True
res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
if transposed:
res = res.t()
return res

def __repr__(self):
return (
@@ -203,6 +241,7 @@ def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
self.int_data = int_data
self.scale = scale
self.transposed = transposed
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["int_data", "scale"], [self.layout_type]
return ["int_data", "scale"], [self.transposed, self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
layout_type, = tensor_attributes
return cls(int_data, scale, layout_type)
transposed, layout_type, = tensor_attributes
return cls(int_data, scale, transposed, layout_type)

@classmethod
def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
extra metadata for packing etc.
"""
assert isinstance(layout_type, PlainLayoutType)
return cls(int_data, scale, layout_type)
return cls(int_data, scale, False, layout_type)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
self.transposed,
self.layout_type,
)

@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

# Tensor parallel support START
elif func in [aten._to_copy.default, aten.clone.default]:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
elif func is aten.split.Tensor:
int_data_list = func(args[0].int_data, *args[1:], **kwargs)
scale_list = func(args[0].scale, *args[1:], **kwargs)
out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
return out
elif func is aten.empty_like.default:
int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
elif dim == 1:
return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
else:
raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
elif func is aten.t.default:
return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))

# Tensor parallel support END

raise NotImplementedError(
f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
)

#####################################################
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


class M(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.linear = torch.nn.Linear(1024, 1024)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

#####################
# Factory functions #
#####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
########
# Test #
########

def test():
def main():
from torchao.utils import benchmark_model


class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(1024, 128)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

m = M()
example_inputs = (100 * torch.randn(1024, 1024),)
example_inputs = (100 * torch.randn(512, 1024),)
NUM_WARMUPS = 10
NUM_RUNS = 100

for _ in range(NUM_WARMUPS):
m(*example_inputs)
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

compiled = torch.compile(m, mode="max-autotune")
for _ in range(NUM_WARMUPS):
compiled(*example_inputs)
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))

# convert weights to quantized weights
m.linear.weight = torch.nn.Parameter(
to_my_dtype(m.linear.weight), requires_grad=False
)

for _ in range(NUM_WARMUPS):
m(*example_inputs)

print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))

m = torch.compile(m, mode="max-autotune")

for _ in range(NUM_WARMUPS):
m(*example_inputs)

# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op;
# we plan to add a custom op example in the future, which will help us get a speedup
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))

if __name__ == "__main__":
test()
main()
6 changes: 3 additions & 3 deletions tutorials/developer_api_guide/my_trainable_tensor_subclass.py
@@ -61,7 +61,7 @@ def from_float(
return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)

class _ToMyTrainableDTypeTensor(torch.autograd.Function):
"""
"""
Differentiable constructor for `MyTrainableDTypeTensor`.
"""

@@ -163,8 +163,8 @@ def _(func, types, args, kwargs):
########

class M(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(512, 1024, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
(Diff for the third changed file, presumably the new tutorials/developer_api_guide/tensor_parallel.py from the test plan, did not load and is not shown.)
