Add support for AQTStorage and PlainAQTStorage
Summary:
Today `AffineQuantizedTensor` has a hardcoded storage format of `int_data`, `scale`, and `zero_point` tensors, but this does not work if we want to support
packed weights. In this PR we add support for hiding the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses, all
of which inherit from the base storage type `AQTStorage` (affine quantized tensor storage).

This PR only adds support for a plain storage tensor (`PlainAQTStorage`) that stores the `int_data`, `scale`, and `zero_point` tensors directly
(see the sketch below). In the next PR we'll also support storing packed weights (the result of `torch.ops.aten._convert_weight_to_int4pack`)
in a different type of `AQTStorage`.
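
As a rough sketch of the idea (not the actual torchao implementation; the real classes also implement tensor flattening and `__torch_dispatch__`, which are omitted here), the plain storage subclass just wraps the three unpacked tensors:

    import torch

    class AQTStorage(torch.Tensor):
        # Illustrative base type for affine quantized tensor storage formats.
        pass

    class PlainAQTStorage(AQTStorage):
        # Plain storage format: keeps int_data, scale and zero_point unpacked.

        @staticmethod
        def __new__(cls, int_data, scale, zero_point):
            # Wrapper subclass whose shape/dtype/device mirror the unpacked int_data.
            return torch.Tensor._make_wrapper_subclass(
                cls, int_data.shape, dtype=int_data.dtype, device=int_data.device
            )

        def __init__(self, int_data, scale, zero_point):
            self.int_data = int_data
            self.scale = scale
            self.zero_point = zero_point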

`AffineQuantizedTensor` will have the following (illustrated in the dispatch sketch below):
- storage_tensor: AQTStorage (can store data in different storage formats)
- storage_layout: str (a string that identifies the type of storage_tensor we have; can be used in dispatch)
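
For example, a hypothetical dispatch helper (not part of this PR's API; names are made up and per-tensor quantization is assumed for simplicity) could branch on storage_layout instead of assuming unpacked fields:

    import torch
    import torch.nn.functional as F

    def quantized_linear_sketch(x, storage_tensor, storage_layout):
        # Dispatch on the storage layout string rather than on hardcoded fields.
        if storage_layout == "plain":
            # PlainAQTStorage path: dequantize from the unpacked tensors.
            w = (storage_tensor.int_data.to(x.dtype)
                 - storage_tensor.zero_point.to(x.dtype)) * storage_tensor.scale.to(x.dtype)
            return F.linear(x, w)
        # A packed layout (e.g. "tensor_core_tiled") would take a different,
        # kernel-specific path once the follow-up PR lands.
        raise NotImplementedError(f"unsupported storage_layout: {storage_layout}")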

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 29, 2024
1 parent 90b5e17 commit 8353c20
Showing 2 changed files with 259 additions and 53 deletions.
10 changes: 5 additions & 5 deletions test/quantization/test_quant_api.py
@@ -106,8 +106,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self, batch_size=1):
-        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -482,10 +482,10 @@ def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled")
 
         m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -562,7 +562,7 @@ def get_per_token_block_size(x):
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
 
         def apply_weight_quant(weight):
             block_size = get_weight_block_size(weight)
