Skip to content

Commit

Permalink
Updated with ruff check
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Nov 7, 2024
1 parent 235052f commit 7bc89bc
Show file tree
Hide file tree
Showing 17 changed files with 937 additions and 613 deletions.
27 changes: 24 additions & 3 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .nf4tensor import NF4Tensor, to_nf4

# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uintx import UInt4Tensor
from .affine_quantized_tensor import (
Expand All @@ -9,23 +10,32 @@
to_affine_quantized_fpx,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
PlainAQTTensorImpl,
)
from .affine_quantized_tensor_ops import *

from . import affine_quantized_tensor_ops
from .utils import (
Layout,
MarlinSparseLayout,
PlainLayout,
)
from .floatx import (
Float8Layout,
Float8AQTTensorImpl,
)
from .uintx import (
UintxTensor,
UintxLayout,
UintxAQTTensorImpl,
to_uintx,
_DTYPE_TO_BIT_WIDTH,
_BIT_WIDTH_TO_DTYPE,
UInt4Tensor,
SemiSparseLayout,
TensorCoreTiledLayout,
MarlinSparseLayout,
PlainAQTTensorImpl,
BlockSparseLayout,
)

__all__ = [
"NF4Tensor",
"to_nf4",
Expand All @@ -43,4 +53,15 @@
"Float8Layout",
"Float8AQTTensorImpl",
"MarlinSparseLayout",
"PlainAQTTensorImpl",
"affine_quantized_tensor_ops",
"BlockSparseLayout",
"to_uintx",
"UintxTensor",
"UintxLayout",
"UintxAQTTensorImpl",
"_DTYPE_TO_BIT_WIDTH",
"_BIT_WIDTH_TO_DTYPE",
# Entries in __all__ are resolved by attribute lookup on star-import, so the
# spelling must match the imported class name (UInt4Tensor) exactly.
"UInt4Tensor",
"PlainAQTTensorImpl",
]
148 changes: 1 addition & 147 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import Layout, PlainLayout
from torchao.quantization.quant_primitives import (
FP8_TYPES,
Expand All @@ -29,6 +27,7 @@
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Short alias for the ATen operator namespace; used for operator identity
# checks such as `func is aten.detach.default`.
aten = torch.ops.aten


##############################
# Tensor Subclass Definition #
##############################
Expand Down Expand Up @@ -445,151 +444,6 @@ def _apply_fn_to_data(self, fn):
# Module-level aliases for two AffineQuantizedTensor attributes so they can be
# referenced without the class prefix (register_layout is applied as a class
# decorator further down in this module).
register_layout = AffineQuantizedTensor.register_layout
get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor


@register_layout(PlainLayout)
class PlainAQTTensorImpl(AQTTensorImpl):
    """
    Tensor impl registered for PlainLayout on affine quantized tensors.

    The quantized values and their quantization parameters are held as three
    ordinary tensors, stored exactly as they were passed in.

    fields:
      int_data (torch.Tensor): the quantized integer data Tensor
      scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
      zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
    """

    # Disable __torch_function__ handling; the supported ops are implemented
    # in __torch_dispatch__ below.
    __torch_function__ = torch._C._disabled_torch_function_impl

    def __new__(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        _layout: Layout,
    ):
        """Build the wrapper tensor; all outer metadata is taken from ``int_data``."""
        return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            int_data.shape,
            device=int_data.device,
            layout=int_data.layout,
            dtype=int_data.dtype,
            requires_grad=False,
        )

    def __init__(
        self,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        _layout: Layout,
    ):
        """Store the three component tensors and the layout object on the instance."""
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point
        self._layout = _layout

    def __tensor_flatten__(self):
        """Return the names of the inner tensor attributes and the non-tensor context."""
        tensor_attr_names = ["int_data", "scale", "zero_point"]
        context = [self._layout]
        return tensor_attr_names, context

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        """Rebuild an instance from the pieces produced by ``__tensor_flatten__``.

        ``outer_size`` and ``outer_stride`` are accepted but not used.
        """
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        zero_point = tensor_data_dict["zero_point"]
        (_layout,) = tensor_attributes
        return cls(int_data, scale, zero_point, _layout)

    def to(self, *args, **kwargs):
        """Return a new instance whose three tensors are moved to the requested device.

        The arguments are normalized through ``self._get_to_kwargs`` (defined
        outside this class); only the resulting ``"device"`` entry is applied.
        """
        kwargs = self._get_to_kwargs(*args, **kwargs)
        moved_int_data = self.int_data.to(kwargs["device"])
        moved_scale = self.scale.to(kwargs["device"])
        moved_zero_point = self.zero_point.to(kwargs["device"])
        return self.__class__(
            moved_int_data, moved_scale, moved_zero_point, self._layout
        )

    def _apply_fn_to_data(self, fn):
        """Return a new instance with ``fn`` applied to each of the three tensors."""
        impl_cls = self.__class__
        new_int_data = fn(self.int_data)
        new_scale = fn(self.scale)
        new_zero_point = fn(self.zero_point)
        return impl_cls(new_int_data, new_scale, new_zero_point, self._layout)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """Handle the aten ops this impl supports: detach, clone, t and slice.

        Raises:
            NotImplementedError: for any other op, or for a slice along a
                dimension other than 0 or 1.
        """
        if kwargs is None:
            kwargs = {}

        if func is aten.detach.default:
            detached = args[0]._apply_fn_to_data(torch.detach)
            return return_and_correct_aliasing(func, args, kwargs, detached)

        if func is aten.clone.default:
            cloned = args[0]._apply_fn_to_data(torch.clone)
            return return_and_correct_aliasing(func, args, kwargs, cloned)

        if func is aten.t.default:
            tensor = args[0]
            # Only int_data is transposed; scale and zero_point are passed through unchanged.
            transposed = tensor.__class__(
                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout
            )
            return return_and_correct_aliasing(func, args, kwargs, transposed)

        if func is aten.slice.Tensor:
            # `self` here is the first dispatched argument, not an implicit receiver
            # (this is a classmethod); the name is kept so the assertion message
            # below stays unchanged.
            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
            if dim == 0:
                # Along dim 0 the same slice is applied to all three tensors.
                sliced = args[0]._apply_fn_to_data(
                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
                )
                return return_and_correct_aliasing(func, args, kwargs, sliced)
            if dim == 1:
                assert (
                    len(self.scale.shape) == 1
                ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
                # Along dim 1 only int_data is sliced; scale and zero_point are flattened views.
                return PlainAQTTensorImpl(
                    aten.slice.Tensor(self.int_data, dim, start, end, step),
                    self.scale.view(-1),
                    self.zero_point.view(-1),
                    self._layout,
                )
            raise NotImplementedError(
                f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
            )

        raise NotImplementedError(
            f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
        )

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(int_data, scale, zero_point)`` as stored."""
        return self.int_data, self.scale, self.zero_point

    def get_layout(self) -> Layout:
        """Return the layout object this impl was constructed with."""
        return self._layout

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        """Construct an instance from plain tensors; ``_layout`` must be a PlainLayout."""
        assert isinstance(_layout, PlainLayout)
        return cls(int_data, scale, zero_point, _layout)


#####################################################
# torch functional and aten operator implementation #
#####################################################
Expand Down
Loading

0 comments on commit 7bc89bc

Please sign in to comment.