Pre-commit fixes
i-colbert committed Mar 30, 2023
1 parent 1f8b949 commit f822ee2
Showing 7 changed files with 173 additions and 98 deletions.
21 changes: 11 additions & 10 deletions src/brevitas/core/quant/int.py
@@ -222,15 +222,15 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):

def __init__(
self,
decoupled_int_quant: Module,
pre_scaling_impl: Module,
scaling_impl: Module,
int_scaling_impl: Module,
pre_zero_point_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module,
):
self,
decoupled_int_quant: Module,
pre_scaling_impl: Module,
scaling_impl: Module,
int_scaling_impl: Module,
pre_zero_point_impl: Module,
zero_point_impl: Module,
bit_width_impl: Module,
):
super().__init__(
decoupled_int_quant,
pre_scaling_impl,
@@ -243,7 +243,8 @@ def __init__(
# TODO - check to make sure the pre-scaling module takes the input bit-width and sign

@brevitas.jit.script_method
def forward(self, x: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
def forward(self, x: Tensor, input_bit_width: Tensor,
input_is_signed: bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
pre_threshold = self.pre_scaling_impl(x, input_bit_width, input_is_signed)
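The hunk above threads the input bit-width and sign through to the pre-scaling module while keeping the usual bit-width and integer-scaling plumbing. A rough plain-PyTorch sketch of the decoupled idea (the pre-clipping threshold drives rounding and clamping, the regular threshold drives dequantization); the helper name and the symmetric signed integer range are illustrative assumptions, not the actual Brevitas implementation:

import torch

def decoupled_int_quant_sketch(x: torch.Tensor, pre_threshold: torch.Tensor,
                               threshold: torch.Tensor, bit_width: float):
    int_scale = 2. ** (bit_width - 1.) - 1.    # symmetric signed integer range (assumed)
    pre_scale = pre_threshold / int_scale      # used only for rounding/clamping
    scale = threshold / int_scale              # used to map integers back to real values
    x_int = torch.clamp(torch.round(x / pre_scale), -int_scale - 1., int_scale)
    return x_int * scale, scale
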
40 changes: 20 additions & 20 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -74,8 +74,7 @@ def __init__(

if len(tracked_parameter_list) > 1:
raise NotImplementedError(
"Error: pre-clipping scales do not currently support multiple tracked quantizers."
)
"Error: pre-clipping scales do not currently support multiple tracked quantizers.")
assert len(tracked_parameter_list) == 1

# Initialize the weight norm parameter vector from the tracked parameter itself
@@ -152,17 +151,18 @@ class AccumulatorAwareParameterPreScaling(ParameterPreScalingWeightNorm):
Returns:
Tensor: scaling factor wrapped in a float torch.Tensor.
"""

def __init__(
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None,
) -> None:
self,
scaling_impl: Module,
normalize_stats_impl: Module,
accumulator_bit_width_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None,
) -> None:
super().__init__(
scaling_impl,
normalize_stats_impl,
@@ -181,12 +181,12 @@ def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: b
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
input_is_signed = float(input_is_signed) # 1. if signed else 0.
# This is the minimum of the two maximum magnitudes that P could take, which are -2^{P-1}
# and 2^{P-1}-1. Note that evaluating to -2^{P-1} would mean there is a possibility of overflow
# on the positive side of this range.
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
# This is the maximum possible magnitude that the input data could take. When the data is signed,
# this is 2^{N-1}. When the data is unsigned, this is 2^N - 1. We use a slightly looser bound here
# to simplify our derivations on the export validation.
@@ -197,10 +197,10 @@ def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
value = d_w / g # calculating final pre-clipping scaling factor
return value
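The comments in get_upper_bound_on_l1_norm spell out the bound in closed form, but the lines that compute the input magnitude and return the result are elided above. A minimal numeric sketch under the assumptions stated in those comments (2^{N-1} for signed inputs, the looser 2^{N} for unsigned ones):

def l1_norm_upper_bound(accumulator_bit_width: int, input_bit_width: int, input_is_signed: bool) -> float:
    max_accumulator_mag = 2. ** (accumulator_bit_width - 1) - 1.             # 2^{P-1} - 1
    max_input_mag = 2. ** (input_bit_width - 1 if input_is_signed else input_bit_width)
    return max_accumulator_mag / max_input_mag                               # the quantity labelled T / s above

# Example: a 16-bit accumulator with signed 8-bit inputs gives (2 ** 15 - 1) / 2 ** 7 ≈ 255.99.
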
3 changes: 1 addition & 2 deletions src/brevitas/nn/utils.py
@@ -1,7 +1,6 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause


from typing import Optional

import torch
@@ -113,4 +112,4 @@ def calculate_min_accumulator_bit_width(
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow
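The tail of calculate_min_accumulator_bit_width shown above is a closed-form expression; a small plain-Python sketch (alpha is derived from the weights and input bit-width earlier in the function, outside this hunk, so it is taken as given here, and math.ceil stands in for the straight-through ceil_ste):

import math

def min_accumulator_bit_width_from_alpha(alpha: float) -> int:
    phi = math.log2(1. + 2. ** -alpha)      # phi(alpha) from the hunk above
    return math.ceil(alpha + phi + 1.)      # smallest accumulator width without overflow risk

# Example: alpha = 14.2 -> ceil(14.2 + log2(1 + 2 ** -14.2) + 1.) = 16 bits.
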
5 changes: 3 additions & 2 deletions src/brevitas/proxy/parameter_quant.py
@@ -112,7 +112,6 @@ def forward(self, x: torch.Tensor) -> QuantTensor:
return QuantTensor(x, training=self.training)



class DecoupledWeightQuantWithInputProxyFromInjector(DecoupledWeightQuantProxyFromInjector):

@property
@@ -134,7 +133,9 @@ def pre_scale(self):
def pre_zero_point(self):
raise NotImplementedError

def forward(self, x: torch.Tensor, input_bit_width: torch.Tensor, input_is_signed: bool) -> QuantTensor:
def forward(
self, x: torch.Tensor, input_bit_width: torch.Tensor,
input_is_signed: bool) -> QuantTensor:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
9 changes: 4 additions & 5 deletions src/brevitas/quant/base.py
@@ -68,8 +68,7 @@
'WeightPerTensorFloatDecoupledL2Param',
'WeightPerChannelFloatDecoupled',
'WeightNormPerChannelFloatDecoupled',
'AccumulatorAwareWeightQuant',
]
'AccumulatorAwareWeightQuant',]


class MaxStatsScaling(ExtendedInjector):
@@ -374,6 +373,6 @@ def accumulator_bit_width_impl(accumulator_bit_width):
tensor_quant = DecoupledRescalingIntQuantWithInput
pre_scaling_impl = AccumulatorAwareParameterPreScaling
pre_scaling_min_val = 1e-8
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
accumulator_bit_width = 32 # default maximum accumulator width is 32 bits
normalize_stats_impl = L1Norm # required to align with derivations in paper
float_to_int_impl = RoundToZeroSte # required to ensure no upwards rounding violates constraints
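These class-level defaults can be overridden per layer; the test fixture later in this commit passes weight_accumulator_bit_width through the layer constructor. A minimal usage sketch along those lines (the import paths and the Int8ActPerTensorFloat input quantizer are assumptions, and the layer shapes are arbitrary):

from brevitas.nn import QuantConv2d
from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant, Int8ActPerTensorFloat

conv = QuantConv2d(
    in_channels=3,
    out_channels=8,
    kernel_size=3,
    weight_quant=Int8AccumulatorAwareWeightQuant,
    input_quant=Int8ActPerTensorFloat,          # A2Q expects a quantized input (see the pytest.skip below)
    weight_accumulator_bit_width=16)            # overrides the 32-bit default above
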
160 changes: 115 additions & 45 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -43,8 +43,7 @@
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_a2q': Int8AccumulatorAwareWeightQuant,
}
'quant_a2q': Int8AccumulatorAwareWeightQuant,}

IO_QUANTIZER = {
'None': None,
@@ -69,27 +68,30 @@
QuantConv1d,
QuantConv2d,
QuantConvTranspose1d,
QuantConvTranspose2d,
]
QuantConvTranspose2d,]

ACC_BIT_WIDTHS = [
8,
9,
10,
12,
16,
24,
32
]
ACC_BIT_WIDTHS = [8, 9, 10, 12, 16, 24, 32]

def build_case_model(weight_quantizer, bias_quantizer, io_quantizer, return_quant_tensor, module, case_id, input_quantized, is_training, accumulator_bit_width = 32):

def build_case_model(
weight_quantizer,
bias_quantizer,
io_quantizer,
return_quant_tensor,
module,
case_id,
input_quantized,
is_training,
accumulator_bit_width=32):

k, weight_quantizer = weight_quantizer
_, bias_quantizer = bias_quantizer
_, io_quantizer = io_quantizer

if io_quantizer is None and not input_quantized and k == 'quant_a2q':
pytest.skip("A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor.")
pytest.skip(
"A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor."
)

impl = module.__name__
if impl == 'QuantLinear':
@@ -109,8 +111,7 @@ def __init__(self):
output_quant=io_quantizer,
bias_quant=bias_quantizer,
return_quant_tensor=return_quant_tensor,
weight_accumulator_bit_width=accumulator_bit_width
)
weight_accumulator_bit_width=accumulator_bit_width)
self.conv.weight.data.uniform_(-0.01, 0.01)

def forward(self, x):
@@ -134,38 +135,107 @@ def forward(self, x):
quant_inp = torch.randn(in_size)
return module, quant_inp

@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WBIOL_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WBIOL_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
@pytest_cases.parametrize('is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
def case_model(weight_quantizer, bias_quantizer, io_quantizer, return_quant_tensor, module, request, input_quantized, is_training):

@pytest_cases.parametrize(
'input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize(
'bias_quantizer',
BIAS_QUANTIZER.items(),
ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize(
'io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize(
'weight_quantizer',
WBIOL_WEIGHT_QUANTIZER.items(),
ids=[f'weight_quant${c}' for c, _ in WBIOL_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize(
'return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize(
'module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
@pytest_cases.parametrize(
'is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
def case_model(
weight_quantizer,
bias_quantizer,
io_quantizer,
return_quant_tensor,
module,
request,
input_quantized,
is_training):
set_case_id(request.node.callspec.id, case_model)
case_id = get_case_id(case_model)
return build_case_model(weight_quantizer, bias_quantizer, io_quantizer, return_quant_tensor, module, case_id, input_quantized, is_training)

@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
@pytest_cases.parametrize('is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
@pytest_cases.parametrize('accumulator_bit_width', ACC_BIT_WIDTHS, ids=[f'accumulator_bit_width${bw}' for bw in ACC_BIT_WIDTHS])
def case_model_a2q(bias_quantizer, io_quantizer, return_quant_tensor, module, request, input_quantized, is_training, accumulator_bit_width):
case_id = get_case_id(case_model)
return build_case_model(
weight_quantizer,
bias_quantizer,
io_quantizer,
return_quant_tensor,
module,
case_id,
input_quantized,
is_training)


@pytest_cases.parametrize(
'input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize(
'bias_quantizer',
BIAS_QUANTIZER.items(),
ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize(
'io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize(
'return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize(
'module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
@pytest_cases.parametrize(
'is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
@pytest_cases.parametrize(
'accumulator_bit_width',
ACC_BIT_WIDTHS,
ids=[f'accumulator_bit_width${bw}' for bw in ACC_BIT_WIDTHS])
def case_model_a2q(
bias_quantizer,
io_quantizer,
return_quant_tensor,
module,
request,
input_quantized,
is_training,
accumulator_bit_width):
set_case_id(request.node.callspec.id, case_model_a2q)
case_id = get_case_id(case_model_a2q)
case_id = get_case_id(case_model_a2q)
# forcing test to only use accumulator-aware weight quantizer
weight_quantizer = ('quant_a2q', Int8AccumulatorAwareWeightQuant)
return build_case_model(weight_quantizer, bias_quantizer, io_quantizer, return_quant_tensor, module, case_id, input_quantized, is_training, accumulator_bit_width=accumulator_bit_width)

@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('bidirectional', [True, False, 'shared_input_hidden'], ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']])
return build_case_model(
weight_quantizer,
bias_quantizer,
io_quantizer,
return_quant_tensor,
module,
case_id,
input_quantized,
is_training,
accumulator_bit_width=accumulator_bit_width)


@pytest_cases.parametrize(
'io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize(
'input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize(
'bias_quantizer',
BIAS_QUANTIZER.items(),
ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize(
'weight_quantizer',
LSTM_WEIGHT_QUANTIZER.items(),
ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize(
'return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize(
'bidirectional', [True, False, 'shared_input_hidden'],
ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']])
@pytest_cases.parametrize('cifg', [True, False])
@pytest_cases.parametrize('num_layers', [1, 2], ids=[f'num_layers${f}' for f in [1, 2]])
def case_quant_lstm(