From 971d1408fbf230014651c74524403568cb9d3b62 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 12 Apr 2024 18:03:36 +0000
Subject: [PATCH] before tests, correctness verified

---
 bin/quant.py                                     | 60 -------------------
 .../quantization/lifecycle/forward.py            | 33 ----------
 .../quantization/lifecycle/initialize.py         |  5 --
 .../quantization/observers/base.py               |  1 -
 .../quantization/observers/memoryless.py         |  3 -
 src/sparsetensors/quantization/quant_args.py     |  3 +-
 6 files changed, 2 insertions(+), 103 deletions(-)
 delete mode 100644 bin/quant.py

diff --git a/bin/quant.py b/bin/quant.py
deleted file mode 100644
index 94bfc448..00000000
--- a/bin/quant.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-from torch.nn import Linear
-
-from sparsetensors.quantization.quant_args import QuantizationArgs
-from sparsetensors.quantization.quant_scheme import QuantizationScheme
-from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
-from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
-from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
-num_bits = 8
-
-scheme = QuantizationScheme(
-    input_acivations=QuantizationArgs(num_bits=num_bits, symmetric=False),
-    weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
-    output_activations=None,
-    targets = ["*"],
-)
-
-layer = Linear(4, 4)
-print(layer)
-print(dict(layer.named_parameters()))
-
-
-initialize_module_for_quantization(layer, scheme)
-print(layer)  # should see observer under layer now
-print(0)
-print(dict(layer.named_parameters()))  # should see empty tensors for scale and zero point now
-print(1)
-
-
-set_module_for_calibration(layer)
-# do a calibration step
-layer(torch.randn(4,4))
-print(dict(layer.named_parameters()))  # scale and zero point should have updated values
-print(2)
-print("calib layers ")
-for i in range(10):
-    print("iter", i)
-    layer(torch.randn(4,4))
-print(dict(layer.named_parameters()))  # scale and zero point should have updated values again since we did another pass
-
-print(3)
-# breakpoint()
-
-
-freeze_module_quantization(layer)
-print("freeze layers ")
-for i in range(10):
-    # do more forward passes but show args are frozen
-    print("iter", i)
-    layer(torch.randn(4,4))
-print(dict(layer.named_parameters()))  # scale and zero point should not be updated now
-
-
-# # missing
-
-# # correctness
-# # quantizing an entire model
-
-
-
diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index cbb27dea..ab20e29b 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -16,11 +16,6 @@
 
 import torch
 from sparsetensors.quantization.lifecycle.status import QuantizationStatus
-
-# from sparsetensors.quantization.utils.quantization_scheme import (
-#     QuantizationArgs,
-#     QuantizationScheme,
-# )
 from sparsetensors.quantization.quant_args import QuantizationArgs
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
@@ -59,35 +54,7 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     max_q = torch.tensor(2**args.num_bits - 1)
-    columns = x.shape[1]
     Q = torch.zeros_like(x)
-    # for i1 in range(0, columns, args.block_size):
-    #     i2 = min(i1 + args.block_size, columns)
-    #     count = i2 - i1
-
-    #     W1 = x[:, i1:i2].clone()
-    #     Q1 = torch.zeros_like(W1)
-
-    #     for i in range(count):
-    #         w = W1[:, i]
-    #         breakpoint()
-    #         if args.group_size != -1:
-    #             if (i1 + i) % args.group_size == 0:
-    #                 xmin, xmax = get_qparams(
-    #                     x[:, (i1 + i) : (i1 + i + args.group_size)], args.symmetric
-    #                 )
-    #                 scale, zero = get_scale_zero_point(
-    #                     x[:, (i1 + i) : (i1 + i + args.group_size)],
-    #                     max_q,
-    #                     xmax,
-    #                     xmin,
-    #                     args.symmetric,
-    #                     args.group_size,
-    #                 )
-
-    #         q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten()
-    #         Q1[:, i] = q
-    #     Q[:, i1:i2] = Q1
     Q = quantize(x, scale, zero_point, max_q)
 
     return dequantize(Q, scale, zero_point)
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
index 6d23f4cc..5661fdbd 100644
--- a/src/sparsetensors/quantization/lifecycle/initialize.py
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -18,11 +18,6 @@
 import torch
 from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
 from sparsetensors.quantization.lifecycle.status import QuantizationStatus
-
-# from sparsetensors.quantization.utils.quantization_scheme import (
-#     QuantizationArgs,
-#     QuantizationScheme,
-# )
 from sparsetensors.quantization.quant_args import QuantizationArgs
 from sparsetensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
diff --git a/src/sparsetensors/quantization/observers/base.py b/src/sparsetensors/quantization/observers/base.py
index 00cd7561..a3184096 100644
--- a/src/sparsetensors/quantization/observers/base.py
+++ b/src/sparsetensors/quantization/observers/base.py
@@ -14,7 +14,6 @@
 
 from typing import Optional, Tuple
 
-# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
 from sparsetensors.quantization.quant_args import QuantizationArgs
 from sparsezoo.utils.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
diff --git a/src/sparsetensors/quantization/observers/memoryless.py b/src/sparsetensors/quantization/observers/memoryless.py
index faabbb5a..b69c841d 100644
--- a/src/sparsetensors/quantization/observers/memoryless.py
+++ b/src/sparsetensors/quantization/observers/memoryless.py
@@ -19,9 +19,6 @@
 
 from torch import FloatTensor, IntTensor, Tensor
 
-# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
-
-
 __all__ = ["MemorylessObserver"]
 
 
diff --git a/src/sparsetensors/quantization/quant_args.py b/src/sparsetensors/quantization/quant_args.py
index fb9e9b01..d90fe9bc 100644
--- a/src/sparsetensors/quantization/quant_args.py
+++ b/src/sparsetensors/quantization/quant_args.py
@@ -64,7 +64,8 @@ class QuantizationArgs(BaseModel):
     observer: str = Field(
         default="minmax",
         description=(
-            "The class to use to compute the quantization params - scale and zero-point'"
+            "The class to use to compute the quantization param - "
+            "scale and zero-point'"
         ),
     )
     observer_kwargs: Dict[str, Any] = Field(
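
Note (illustrative sketch, not part of the patch): with the blocked/grouped path removed, fake_quantize in forward.py reduces to a plain per-tensor round trip -- quantize onto the integer grid with a given scale and zero point, then dequantize back. The snippet below only illustrates that round trip under the usual affine formulas; fake_quantize_sketch and the min/max parameter derivation are hypothetical stand-ins, not the repository's quantize/dequantize or observer code.

# Illustrative only -- assumed affine formulas, not the repository's implementation.
import torch


def fake_quantize_sketch(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    num_bits: int = 8,
) -> torch.Tensor:
    """Quantize x onto an unsigned num_bits grid, then map it back."""
    max_q = 2**num_bits - 1
    # quantize: scale into the integer grid and clamp to the representable range
    q = torch.clamp(torch.round(x / scale + zero_point), 0, max_q)
    # dequantize: return to the original range, keeping only the rounding error
    return (q - zero_point) * scale


# Per-tensor asymmetric parameters from the observed min/max (assumed scheme).
x = torch.randn(4, 4)
scale = (x.max() - x.min()) / (2**8 - 1)
zero_point = torch.round(-x.min() / scale)
print(fake_quantize_sketch(x, scale, zero_point))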