diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 194631953..bfcfbb58f 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -12,6 +12,8 @@ from brevitas import config from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue from .stats_wrapper import SCALAR_SHAPE @@ -64,7 +66,7 @@ def forward(self, x: Tensor): if self.stats_reduce_dim is None: # k is 1-indexed, so round away from zero k = int(math.floor(.01 * self.q * x.numel() + 0.5)) - result = x.abs().view(-1).kthvalue(k).values + result = kthvalue(x.abs().view(-1), k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." @@ -72,7 +74,7 @@ def forward(self, x: Tensor): dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1) # k is 1-indexed, so round away from zero k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5)) - result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + result = kthvalue(x.abs(), k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] return result @@ -97,7 +99,7 @@ def forward(self, x: Tensor) -> Tensor: if self.stats_reduce_dim is None: # k is 1-indexed, so round away from zero k = int(math.ceil(.01 * self.q * x.numel())) - result = x.view(-1).kthvalue(k).values + result = kthvalue(x.view(-1), k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." 
@@ -105,7 +107,7 @@ def forward(self, x: Tensor) -> Tensor: dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1) # k is 1-indexed, so round away from zero k = int(math.ceil(.01 * self.q * dim_slice.numel())) - result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + result = kthvalue(x, k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] result = torch.clamp(result, max=self.zero()) return result @@ -134,8 +136,8 @@ def forward(self, x: Tensor) -> Tensor: low_k = int(math.ceil(.01 * self.low_q * x.numel())) # k is 1-indexed, so round away from zero high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5)) - low_result = x.view(-1).kthvalue(low_k).values - high_result = x.view(-1).kthvalue(high_k).values + low_result = kthvalue(x.view(-1), low_k)[0] + high_result = kthvalue(x.view(-1), high_k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." @@ -144,8 +146,8 @@ def forward(self, x: Tensor) -> Tensor: low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel())) # k is 1-indexed, so round away from zero high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5)) - low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values - high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + low_result = kthvalue(x, low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] + high_result = kthvalue(x, high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] # We need to make sure the lower bound is not positive to align with zero-point statistics low_result = torch.clamp(low_result, max=self.zero()) interval = high_result - low_result diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index d8bb253bb..bd1da8edd 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -15,6 +15,7 @@ from .torch_handler import 
QUANT_TENSOR_FN_HANDLER IS_VALID_ATOL = 2e-1 +BFLOAT16_IS_VALID_ATOL = 0.5 class QuantTensorBase(NamedTuple): @@ -104,8 +105,15 @@ def is_not_none(self): @property def _pre_round_int_value(self): - int_value = self.value / self.scale - int_value = int_value + self.zero_point + value = self.value + scale = self.scale + zero_point = self.zero_point + if self.scale.dtype == torch.bfloat16: + value = self.value.type(torch.float32) + scale = self.scale.type(torch.float32) + zero_point = self.zero_point.type(torch.float32) + int_value = value / scale + int_value = int_value + zero_point return int_value @property @@ -114,8 +122,9 @@ def is_valid(self): with torch.no_grad(): pre_round_int_value = self._pre_round_int_value rounded_int_value = torch.round(pre_round_int_value) - is_int = torch.isclose( - pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all() + max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value)) + atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL + is_int = max_abs_diff < atol if self.bit_width >= 2: if self.signed: is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all() @@ -176,7 +185,12 @@ def int(self, float_datatype=False): if self.is_valid: int_value = round_ste(self._pre_round_int_value) if float_datatype: - return int_value + # Values at 8bit and lower can be represented exactly with float16 and bfloat16 + # otherwise (e.g. Int16 bias), we upscale to float32 + if self.bit_width <= 8.: + return int_value.type(self.scale.dtype) + else: + return int_value.type(torch.float32) else: if self.bit_width <= 8. 
and self.signed_t.item(): return int_value.to(torch.int8) @@ -301,6 +315,8 @@ def cat(tensors, dim, out=None): def __neg__(self): neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale + # In case the dtype of self.int is different from the one of the scale + neg_value = neg_value.type(self.scale.dtype) if self.signed: return QuantTensor( value=neg_value, @@ -432,6 +448,8 @@ def __truediv__(self, other): def __abs__(self): if self.signed: abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale + # In case the dtype of self.int is different from the one of the scale + abs_value = abs_value.type(self.scale.dtype) return QuantTensor( value=abs_value, scale=self.scale, diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 7105ea874..ec7d6fac4 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from typing import Optional, Tuple import torch from torch.nn import Sequential @@ -46,3 +47,36 @@ def torch_partial_deepcopy(model): memo[id(p)] = copy.copy(p) # Shallow copy of parameters model_copy = copy.deepcopy(model, memo) return model_copy + + +def kthvalue( + x: torch.Tensor, + k: int, + dim: Optional[int] = None, + keepdim: bool = False, + out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None +) -> Tuple[torch.Tensor, torch.LongTensor]: + # As of torch 2.1, there is no kthvalue implementation: + # - In CPU for float16 + # - In GPU for bfloat16 + # In these cases we cast to float32 and then go back to the original dtype + dtype = x.dtype + device = str(x.device) + + # We do not support out as buffer for the output, since we cannot control its dtype + if out is not None: + raise RuntimeError("out argument for kthvalue not supported") + + if (dtype == torch.float16 and 'cpu' in device) or \ + (dtype == torch.bfloat16 and 'cuda' in device): + x = x.type(torch.float32) + + # PyTorch 
specifies None as default for `dim` but it breaks if we specifically pass None + if dim is not None: + x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim) + else: + x, indices = torch.kthvalue(x, k, keepdim=keepdim) + + if x.dtype != dtype: + x = x.type(dtype) + return (x, indices) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 454bf2488..c821dad33 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -72,6 +72,8 @@ metavar='ARCH', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)') +parser.add_argument( + '--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use') parser.add_argument( '--target-backend', default='fx', @@ -215,6 +217,7 @@ default=None, type=int, help='Accumulator Bit Width for GPFA2Q (default: None)') +parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') @@ -226,11 +229,11 @@ def main(): args = parser.parse_args() + dtype = getattr(torch, args.dtype) random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) - if args.act_quant_calibration_type == 'stats': act_quant_calib_config = str(args.act_quant_percentile) + 'stats' else: @@ -312,6 +315,7 @@ def main(): # Get the model from torchvision model = get_torchvision_model(args.model_name) + model = model.to(dtype) # Preprocess the model for quantization if args.target_backend == 'flexml': @@ -319,7 +323,7 @@ def main(): img_shape = model_config['center_crop_shape'] model = preprocess_for_flexml_quantize( model, - torch.ones(1, 3, img_shape,
img_shape), + torch.ones(1, 3, img_shape, img_shape, dtype=dtype), equalize_iters=args.graph_eq_iterations, equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn) @@ -339,6 +343,7 @@ def main(): # Define the quantized model quant_model = quantize_model( model, + dtype=dtype, backend=args.target_backend, scale_factor_type=args.scale_factor_type, bias_bit_width=args.bias_bit_width, @@ -405,7 +410,7 @@ def main(): # Validate the quant_model on the validation dataloader print("Starting validation:") - validate(val_loader, quant_model) + validate(val_loader, quant_model, stable=dtype != torch.bfloat16) if args.export_onnx_qcdq or args.export_torch_qcdq: # Generate reference input tensor to drive the export process @@ -418,7 +423,7 @@ def main(): export_name = os.path.join(args.export_dir, config) if args.export_onnx_qcdq: export_name = export_name + '.onnx' - export_onnx_qcdq(model, ref_input, export_name) + export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version) if args.export_torch_qcdq: export_name = export_name + '.pt' export_torch_qcdq(model, ref_input, export_name) diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py index 033058219..f614f287c 100644 --- a/src/brevitas_examples/imagenet_classification/utils.py +++ b/src/brevitas_examples/imagenet_classification/utils.py @@ -61,7 +61,7 @@ def accuracy(output, target, topk=(1,), stable=False): return res -def validate(val_loader, model): +def validate(val_loader, model, stable=True): """ Run validation on the desired dataset """ @@ -82,7 +82,7 @@ def print_accuracy(top1, prefix=''): output = model(images) # measure accuracy - acc1, = accuracy(output, target, stable=True) + acc1, = accuracy(output, target, stable=stable) top1.update(acc1[0], images.size(0)) print_accuracy(top1, 'Total:') diff --git a/tests/brevitas/core/test_stats.py b/tests/brevitas/core/test_stats.py index 
224f323fc..4d397e457 100644 --- a/tests/brevitas/core/test_stats.py +++ b/tests/brevitas/core/test_stats.py @@ -8,6 +8,8 @@ from brevitas.core.stats import AbsPercentile from brevitas.core.stats import NegativePercentileOrZero from brevitas.core.stats import PercentileInterval +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue def test_abs_percentile_per_tensor(): @@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None): low_p, high_p = None, None if low_q is not None: k = int(math.ceil(.01 * low_q * x.numel())) - low_p = x.view(-1).kthvalue(k).values + low_p = kthvalue(x.view(-1), k=k)[0] if high_q is not None: k = int(math.floor(.01 * high_q * x.numel() + 0.5)) - high_p = x.view(-1).kthvalue(k).values + high_p = kthvalue(x.view(-1), k=k)[0] return low_p, high_p def test_negative_percentile(self): diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 58561cbd1..6580d971b 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -12,6 +12,8 @@ from brevitas.graph.calibrate import calibration_mode import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFixedPoint +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue from tests.brevitas.hyp_helper import float_tensor_random_size_st IN_CH = 8 @@ -21,7 +23,7 @@ def compute_quantile(x, q): k = int(math.floor(.01 * q * x.numel() + 0.5)) - result = x.abs().view(-1).kthvalue(k).values + result = kthvalue(x.abs().view(-1), k=k)[0] return result