Add developer guide code to tutorials (#588)
Summary:
Moved the notebook https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA from #391 to the `tutorials` folder
so that the code can be executed while we develop new APIs/utils and kept up to date.

Test Plan:
python tutorials/developer_api_guide.py

regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Aug 2, 2024
1 parent 08024c6 commit 0844de3
Showing 8 changed files with 483 additions and 92 deletions.
147 changes: 79 additions & 68 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -22,71 +22,17 @@
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
PlainLayoutType,
is_device,
)
from typing import ClassVar
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

aten = torch.ops.aten

@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
pass

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.layout_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
)

###############################
# Base Layout Tensor Subclass #
###############################
class AQTLayout(torch.Tensor):
"""
Base class for the layout tensor for `AffineQuantizedTensor`
@@ -126,6 +72,10 @@ def _get_to_kwargs(self, *args, **kwargs):
}
return kwargs

##############################
# Tensor Subclass Definition #
##############################

class AffineQuantizedTensor(torch.Tensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
@@ -337,7 +287,6 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)


implements = classmethod(_implements)
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
@@ -353,14 +302,46 @@ def _apply_fn_to_data(self, fn):
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

implements = AffineQuantizedTensor.implements

######################################################
# LayoutType and Layout Tensor Subclass Registration #
######################################################

def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
"""
@@ -487,7 +468,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)

def get_plain(self):
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# the identity matrix to get the original dense matrix. This is slow though.
cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
int_data_expanded = torch._cslt_sparse_mm(self.int_data,
@@ -507,7 +488,7 @@ def from_plain(
assert isinstance(layout_type, SemiSparseLayoutType)
int_data_compressed = torch._cslt_compress(int_data)
return cls(int_data_compressed, scale, zero_point, layout_type)


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
@@ -654,6 +635,34 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_layout_type(self) -> LayoutType:
return self.layout_type

#####################################################
# torch functional and aten operator implementation #
#####################################################

def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min is None or aqt.quant_min == -128 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_int8_reduced_range(aqt):
return (
aqt.layout_tensor.dtype == torch.int8 and
aqt.quant_min == -127 and
aqt.quant_max is None or aqt.quant_max == 127
)

def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
# TODO: use torch.uint4
return (
aqt.layout_tensor.dtype == torch.int32 and
aqt.quant_min is None or aqt.quant_min == 0 and
aqt.quant_max is None or aqt.quant_max == 15
)

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
"""
Quantized version of F.linear operator
@@ -811,8 +820,10 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
raise NotImplementedError("No specialized dispatch found for quantized linear op")


implements = AffineQuantizedTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
@@ -831,7 +842,7 @@ def _(func, types, *args, **kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -870,21 +881,21 @@ def _(func, types, *args, **kwargs):
return func(input_tensor, weight_tensor)

@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements([aten._to_copy.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
Expand All @@ -893,7 +904,7 @@ def _(func, types, *args, **kwargs):
)

@implements([aten.t.default])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
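The `AffineQuantizedTensor` docstring and the `_aqt_is_int8`/`_aqt_is_uint4` helpers in the diff above all refer to an affine mapping between floating point and integer values. For orientation, here is a minimal sketch of that mapping, assuming a per-tensor `scale` and `zero_point` and an int8 target range; it is illustrative only, not the torchao implementation (the library's own affine quantize/dequantize primitives additionally handle block sizes, dtypes, and edge cases).

```python
import torch

def affine_quantize(x: torch.Tensor, scale: float, zero_point: int,
                    quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    # quantized = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return q.clamp(quant_min, quant_max).to(torch.int8)

def affine_dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # dequantized = (quantized - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

x = torch.tensor([-1.0, 0.0, 0.25, 1.0])
q = affine_quantize(x, scale=0.01, zero_point=0)
print(q)                              # tensor([-100,    0,   25,  100], dtype=torch.int8)
print(affine_dequantize(q, 0.01, 0))  # recovers x up to rounding error
```

The `quant_min`/`quant_max` checks in `_aqt_is_int8` and `_aqt_is_uint4` are asking which target range such a mapping was configured with.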
16 changes: 12 additions & 4 deletions torchao/dtypes/utils.py
@@ -32,8 +32,8 @@ def _(func, types, args, kwargs):
def decorator(func):
for op in aten_ops_or_torch_fns:
@functools.wraps(op)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
def wrapper(f, types, args, kwargs):
return func(f, types, args, kwargs)

cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
return func
@@ -50,7 +50,7 @@ class MyTensor(torch.Tensor):
kwargs = {} if kwargs is None else kwargs
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
@@ -65,7 +65,7 @@ class MyTensor(torch.Tensor):
"""
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")

@@ -87,6 +87,14 @@ def __repr__(self):
def extra_repr(self) -> str:
return ""

"""
Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
"""
@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
pass


"""
layout tensor constructor registration for different tensor subclasses
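The `torchao/dtypes/utils.py` change above is the crux of this commit's API cleanup: handlers registered through `_implements` are now invoked as `handler(func, types, args, kwargs)` instead of `handler(func, types, *args, **kwargs)`, so every decorated function in the other files unpacks `args` itself. Below is a toy, self-contained sketch of that convention; the `_TABLE`, `implements`, and `dispatch` names are stand-ins for illustration, not the torchao internals.

```python
import torch

_TABLE = {}  # toy stand-in for cls._ATEN_OP_OR_TORCH_FN_TABLE

def implements(torch_fns):
    """Toy version of `_implements`: map each op/function to a handler."""
    def decorator(handler):
        for op in torch_fns:
            _TABLE[op] = handler
        return handler
    return decorator

@implements([torch.nn.functional.linear])
def _(func, types, args, kwargs):
    # new convention: args/kwargs arrive as a plain tuple/dict and the
    # handler unpacks them itself
    input_tensor, weight_tensor = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return func(input_tensor, weight_tensor, bias)

def dispatch(func, types, args=(), kwargs=None):
    # mirrors _dispatch__torch_function__: look up the handler and pass
    # args/kwargs through without expanding them
    kwargs = {} if kwargs is None else kwargs
    if func in _TABLE:
        return _TABLE[func](func, types, args, kwargs)
    return func(*args, **kwargs)

x, w = torch.randn(2, 4), torch.randn(3, 4)
print(dispatch(torch.nn.functional.linear, (torch.Tensor,), (x, w)).shape)  # torch.Size([2, 3])
```

Keeping `args`/`kwargs` unexpanded gives every handler the same four-argument shape as the `__torch_function__`/`__torch_dispatch__` protocols themselves.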
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -89,7 +89,7 @@ def __repr__(self):


@OptimState4bit.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

@@ -116,14 +116,14 @@ def _(func, types, *args, **kwargs):


@OptimState4bit.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args

if tuple(x.shape) == tuple(shape):
@@ -142,7 +142,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -75,7 +75,7 @@ def __repr__(self):


@OptimState8bit.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

@@ -98,14 +98,14 @@ def _(func, types, *args, **kwargs):


@OptimState8bit.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)

@@ -117,7 +117,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -81,7 +81,7 @@ def __repr__(self):


@OptimStateFp8.implements(aten.copy_.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
dst = args[0]
src = args[1]

Expand All @@ -102,14 +102,14 @@ def _(func, types, *args, **kwargs):


@OptimStateFp8.implements(aten.lerp.Scalar)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimStateFp8.implements(aten.view.default)
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x, shape = args
return OptimStateFp8(x.codes.view(shape), x.scale)

@@ -121,7 +121,7 @@ def _(func, types, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, types, *args, **kwargs):
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimStateFp8):
raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
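The three optimizer-state subclasses above (`OptimState4bit`, `OptimState8bit`, `OptimStateFp8`) share the same `aten.lerp.Scalar` handling: dequantize any quantized operands, then re-run the original op on plain tensors. Here is a rough, self-contained sketch of that pattern under the new `(func, types, args, kwargs)` convention; `FakeQuantState` is an invented stand-in (the real subclasses store quantized codes plus block-wise scales).

```python
import torch

class FakeQuantState:
    """Invented stand-in for an optimizer-state subclass: int8 codes + one scale."""
    def __init__(self, codes: torch.Tensor, scale: float):
        self.codes = codes    # quantized payload
        self.scale = scale    # per-tensor scale (the real subclasses are block-wise)

    def dequantize(self) -> torch.Tensor:
        return self.codes.float() * self.scale

def lerp_handler(func, types, args, kwargs):
    # same shape as the `@OptimState8bit.implements(aten.lerp.Scalar)` handler above:
    # swap quantized operands for dense tensors, then defer to the original op
    args = [a.dequantize() if isinstance(a, FakeQuantState) else a for a in args]
    return func(*args, **kwargs)

start = FakeQuantState(torch.tensor([10, -20, 30], dtype=torch.int8), scale=0.1)
end = torch.tensor([2.0, 2.0, 2.0])
print(lerp_handler(torch.lerp, (), [start, end, 0.5], {}))  # tensor([1.5000, 0.0000, 2.5000])
```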