Skip to content

Commit

Permalink
Refactor int8 weight only quant to use quantize
Browse files Browse the repository at this point in the history
Summary:
Similar to pytorch#294 we replaced the implementation
of int8 weight only quant to used the newly added `quantize` function, as a part of
the unification effort for affine quantization

Test Plan:
1. unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf

elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368

2. integration test:

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Reference: elapsed_time:  1.355208740234375  milliseconds
After refactor: elapsed_time:  1.32778857421875  milliseconds

code diff (gist): https://gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): https://www.internalfb.com/phabricator/paste/view/P1387333845

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 31, 2024
1 parent cd1ebc8 commit d058ace
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 24 deletions.
74 changes: 64 additions & 10 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,40 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

def _ref_change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight

filter_fn = kwargs.pop("filter_fn", _is_linear)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
filter_fn,
)

def _ref_change_linear_weights_to_int4_woqtensors(model, **kwargs):
"""
The deprecated implementation for int4 weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int4WeightOnlyQuantizedLinearWeight

filter_fn = kwargs.pop("filter_fn", _is_linear)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
filter_fn,
)

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
Expand Down Expand Up @@ -489,7 +523,7 @@ def test_quantized_tensor_subclass_int4(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
def test_quantized_tensor_subclass_int8_wo(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
Expand All @@ -501,12 +535,12 @@ def test_quantized_tensor_subclass_int8(self):

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(m_copy)
_ref_change_linear_weights_to_int8_woqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)

torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
self.assertTrue(torch.equal(res, ref))


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
Expand Down Expand Up @@ -545,20 +579,20 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):

def _test_quantized_tensor_subclass_perf(self, api, ref_api, kwargs=None):
if kwargs is None:
kwargs = {}

m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m)
api(m, **kwargs)

# reference
_ref_change_linear_weights_to_int8_dqtensors(m_ref)
ref_api(m_ref, **kwargs)

res = m(*example_inputs)
ref = m_ref(*example_inputs)
Expand All @@ -583,7 +617,27 @@ def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 weight only quant implementation")
def test_quantized_tensor_subclass_int8_wo_quant_perf(self):
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int4 weight only quant implementation")
def test_quantized_tensor_subclass_int4_wo_quant_perf(self):
kwargs = {"groupsize": 32}
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)

if __name__ == "__main__":
unittest.main()
9 changes: 5 additions & 4 deletions torchao/dtypes/aqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,29 +574,30 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
scale_and_zero = weight_qtensor.layout_tensor.scale_and_zero
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scale_and_zero)
elif (
is_cpu and
weight_is_int8 and
len(weight_qtensor.shape) == 2 and
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
weight_qtensor.layout == "plain"
):
# TODO: enable cpu and mps efficient path
# per channel int8 weight only quantizated mm
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
scale = weight_qtensor.layout_tensor.scale
orig_dtype = input_tensor.dtype
y = (
torch.mm(
input_tensor.reshape(-1, input_tensor.shape[-1]),
w_vals_int8_t.to(input_tensor.dtype),
)
* weight_qtensor.scale
* scale
)
y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y.to(orig_dtype)
return y.to(orig_dtype)

# is_cpu and is_mps only, some issue with is_contiguous() currently
# return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
Expand Down
17 changes: 11 additions & 6 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,18 @@ def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
Converts all linear weight tensors to the
`Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
effectively applying the same form of quantization
as apply_dynamic_quant while not modifying the linear modules.
as apply_weight_only_int8_quant while not modifying the linear modules.
"""
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)

if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int8wo_quant(), filter_fn)
unwrap_tensor_subclass(model, filter_fn)
else:
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
_is_linear if filter_fn is None else filter_fn,
)


def change_linear_weights_to_int4_woqtensors(model, **kwargs):
Expand Down
17 changes: 13 additions & 4 deletions tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,27 @@
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

## Quantization code - start
# int8 act, int8 weight dynamic quantization
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True

# int8 weight only quantization
# torchao.quantization.change_linear_weights_to_int8_woqtensors(model)
## Quantization code - end


## compilation configs
torch._dynamo.config.automatic_dynamic_shapes = False
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
## compilation configs end

model = torch.compile(model, mode='max-autotune', fullgraph=True)

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
benchmark_model(model, 5, input_tensor)
benchmark_model(model, 20, input_tensor)
# benchmark
print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
print("elapsed_time: ", benchmark_model(model, 1000, input_tensor), " milliseconds")
# Create a trace
profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)

0 comments on commit d058ace

Please sign in to comment.