Commit bee1fad

draft tests before meeting
horheynm committed Apr 15, 2024
1 parent 557d119 commit bee1fad
Showing 5 changed files with 297 additions and 4 deletions.
54 changes: 54 additions & 0 deletions bin/quant.py
@@ -0,0 +1,54 @@
import torch
from torch.nn import Linear

from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
num_bits = 8

scheme = QuantizationScheme(
    input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
    output_activations=None,
    targets=["*"],
)

layer = Linear(4, 4)
print(layer)
print(dict(layer.named_parameters()))


initialize_module_for_quantization(layer, scheme)
print(layer) # should see observer under layer now
print(0)
print(dict(layer.named_parameters())) # should see empty tensors for scale and zero point now
print(1)


set_module_for_calibration(layer)
# do a calibration step
layer(torch.randn(4,4))
print(dict(layer.named_parameters())) # scale and zero point should have updated values
print(2)
print("calib layers ")
for i in range(10):
    print("iter", i)
    layer(torch.randn(4, 4))
    print(dict(layer.named_parameters()))  # scale and zero point should have updated values again since we did another pass

print(3)
# breakpoint()


freeze_module_quantization(layer)
print("freeze layers ")
for i in range(10):
    # do more forward passes but show args are frozen
    print("iter", i)
    layer(torch.randn(4, 4))
    print(dict(layer.named_parameters()))  # scale and zero point should not be updated now


# # missing
5 changes: 1 addition & 4 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -21,7 +21,7 @@
from torch.nn import Module


__all__ = ["wrap_module_forward_quantized"]
__all__ = ["wrap_module_forward_quantized","quantize","dequantize","fake_quantize"]


def quantize(
@@ -67,7 +67,6 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        input_ = args[0]
-
        if scheme.input_activations is not None:
            # calibrate and (fake) quantize input activations when applicable
            input_ = _maybe_calibrate_or_quantize(
@@ -113,8 +112,6 @@ def _maybe_calibrate_or_quantize(
    # zero_point = getattr(module, f"{base_name}_zero_point").data
    zero_point = getattr(module, f"{base_name}_zero_point")

-    print(scale, zero_point)
-
    if module.quantization_status == QuantizationStatus.CALIBRATION:
        # get observer and get new quant params from observation
        observer = getattr(module, f"{base_name}_observer")
46 changes: 46 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_calibration.py
@@ -0,0 +1,46 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import pytest
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear


@pytest.fixture(scope="module")
def create_quantization_scheme():
def quantization_scheme(
targets: List[str],
weights: Optional[QuantizationArgs] = None,
input_activations: Optional[QuantizationArgs] = None,
output_activations: Optional[QuantizationArgs] = None,
):
return QuantizationScheme(
targets=targets,
weights=weights,
input_activations=input_activations,
output_activations=output_activations,
)

return quantization_scheme


def test_set_module_for_calibration(create_quantization_scheme):
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
    )

    layer = Linear(4, 4)
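Note: the draft test above stops after constructing the layer. A minimal sketch of how it might continue — an assumption, not part of this commit, reusing only the lifecycle helpers that appear in the other files of this change (initialize_module_for_quantization, set_module_for_calibration, QuantizationStatus) — could be:

    # sketch only; assumes a scheme with no weight/activation args is accepted by initialize
    initialize_module_for_quantization(layer, quantization_scheme)
    assert layer.quantization_status == QuantizationStatus.INITIALIZED

    set_module_for_calibration(layer)
    assert layer.quantization_status == QuantizationStatus.CALIBRATION

The corresponding imports (as used in tests/sparsetensors/quantization/lifecycle/test_end_to_end.py) would also need to be added at the top of the file.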
132 changes: 132 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_end_to_end.py
@@ -0,0 +1,132 @@
import torch
from torch.nn import Linear

from typing import Optional, List
import pytest
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from sparsetensors.quantization.lifecycle.initialize import initialize_module_for_quantization
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration
from sparsetensors.quantization.lifecycle.frozen import freeze_module_quantization
from sparsetensors.quantization.lifecycle.status import QuantizationStatus


@pytest.fixture(scope="module")
def create_quantization_scheme():
    def quantization_scheme(
        targets: List[str],
        weights: Optional[QuantizationArgs] = None,
        input_activations: Optional[QuantizationArgs] = None,
        output_activations: Optional[QuantizationArgs] = None,
    ):
        return QuantizationScheme(
            targets=targets,
            weights=weights,
            input_activations=input_activations,
            output_activations=output_activations,
        )

    return quantization_scheme


def test_lifecycle(create_quantization_scheme):
    num_bits = 8

    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    )

    layer = Linear(4, 4)
    layer.weight.data *= 100

    # updated layer keys check
    expected_layer_keys = {"weight", "bias"}
    for key in layer.state_dict().keys():
        expected_layer_keys.remove(key)
    assert len(expected_layer_keys) == 0


    initialize_module_for_quantization(layer, quantization_scheme)
    expected_layer_keys = {
        "input_scale",
        "input_zero_point",
        "weight_scale",
        "weight_zero_point",
        "weight",
        "bias",
    }
    for key in layer.state_dict().keys():
        expected_layer_keys.remove(key)
    assert len(expected_layer_keys) == 0

    assert hasattr(layer, "quantization_scheme")
    assert hasattr(layer, "quantization_status")
    assert layer.quantization_status == QuantizationStatus.INITIALIZED

    set_module_for_calibration(layer)
    assert layer.quantization_status == QuantizationStatus.CALIBRATION

    # do a calibration step
    print(dict(layer.named_parameters()))  # scale and zero point still hold their initialized values
    original_tensor = layer.weight.data
    original_input_zero_point = layer.input_zero_point
    original_input_scale = layer.input_scale
    original_weight_scale = layer.weight_scale
    original_weight_zero_point = layer.weight_zero_point

    print()  # spacer in the debug output

    layer(torch.randn(4, 4))

    # zero-points and scale
    updated_tensor = layer.weight.data
    updated_input_zero_point = layer.input_zero_point
    updated_input_scale = layer.input_scale
    updated_weight_scale = layer.weight_scale
    updated_weight_zero_point = layer.weight_zero_point

    print(original_tensor, updated_tensor)
    print(original_input_zero_point, updated_input_zero_point)
    print(original_input_scale, updated_input_scale)
    print(original_weight_scale, updated_weight_scale)
    print(original_weight_zero_point, updated_weight_zero_point)


    # breakpoint()

    print(dict(layer.named_parameters()))  # scale and zero point should have updated values
    print(2)
print("calib layers ")
for i in range(10):
print("iter", i)
layer(torch.randn(4,4))
print(dict(layer.named_parameters())) # scale and zero point should have updated values again since we did another pass

    print(3)
    # breakpoint()

    freeze_module_quantization(layer)
    print("freeze layers ")
    for i in range(10):
        # do more forward passes but show args are frozen
        print("iter", i)
        layer(torch.randn(4, 4))
        print(dict(layer.named_parameters()))  # scale and zero point should not be updated now


# # missing
64 changes: 64 additions & 0 deletions tests/sparsetensors/quantization/lifecycle/test_forward.py
@@ -0,0 +1,64 @@
from typing import List, Optional

import pytest
from sparsetensors.quantization.lifecycle.initialize import (
initialize_module_for_quantization,
)
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear

from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized


@pytest.fixture(scope="module")
def create_quantization_scheme():
    def quantization_scheme(
        targets: List[str],
        weights: Optional[QuantizationArgs] = None,
        input_activations: Optional[QuantizationArgs] = None,
        output_activations: Optional[QuantizationArgs] = None,
    ):
        return QuantizationScheme(
            targets=targets,
            weights=weights,
            input_activations=input_activations,
            output_activations=output_activations,
        )

    return quantization_scheme


def test_wrap_module_forward_quantized__forward_overwrite(create_quantization_scheme):
    num_bits = 8
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    )
    layer = Linear(4, 4)

    func_forward = layer.forward.__func__

    # check that the forward call is overwritten
    wrap_module_forward_quantized(layer, quantization_scheme)

    assert not func_forward == layer.forward.__func__


def test_wrap_module_forward_quantized__weight_data(create_quantization_scheme):
    num_bits = 8
    quantization_scheme = create_quantization_scheme(
        targets=["*"],
        weights=QuantizationArgs(num_bits=num_bits, symmetric=True),
        input_activations=QuantizationArgs(num_bits=num_bits, symmetric=False),
    )
    layer = Linear(4, 4)
    layer.weight.data *= 100

    data = layer.weight.data

    wrap_module_forward_quantized(layer, quantization_scheme)
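The second test also ends without an assertion. One possible continuation — an assumption about the intent, not part of this commit — is to snapshot the weights and confirm that wrapping the forward pass by itself leaves them untouched; this would also require adding `import torch` to the file's imports:

    # hypothetical ending for the draft test above; wrapping should only replace `forward`,
    # so the stored weight tensor is expected to be unchanged afterwards
    data = layer.weight.data.clone()

    wrap_module_forward_quantized(layer, quantization_scheme)

    assert torch.equal(data, layer.weight.data)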

