Better Bfloat16 support #777

Merged: 3 commits, Dec 22, 2023

18 changes: 10 additions & 8 deletions src/brevitas/core/stats/stats_op.py
@@ -12,6 +12,8 @@
 from brevitas import config
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_int
+# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
+from brevitas.utils.torch_utils import kthvalue

 from .stats_wrapper import SCALAR_SHAPE

@@ -64,15 +66,15 @@ def forward(self, x: Tensor):
         if self.stats_reduce_dim is None:
             # k is 1-indexed, so round away from zero
             k = int(math.floor(.01 * self.q * x.numel() + 0.5))
-            result = x.abs().view(-1).kthvalue(k).values
+            result = kthvalue(x.abs().view(-1), k)[0]
         else:
             # assuming x is two dimensional, get the other dimension
             assert len(x.size()) == 2, "Only 2-dim input is supported."
             other_dim = abs(self.stats_reduce_dim - 1)
             dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
             # k is 1-indexed, so round away from zero
             k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
-            result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+            result = kthvalue(x.abs(), k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
         return result


@@ -97,15 +99,15 @@ def forward(self, x: Tensor) -> Tensor:
         if self.stats_reduce_dim is None:
             # k is 1-indexed, so round away from zero
             k = int(math.ceil(.01 * self.q * x.numel()))
-            result = x.view(-1).kthvalue(k).values
+            result = kthvalue(x.view(-1), k)[0]
         else:
             # assuming x is two dimensional, get the other dimension
             assert len(x.size()) == 2, "Only 2-dim input is supported."
             other_dim = abs(self.stats_reduce_dim - 1)
             dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
             # k is 1-indexed, so round away from zero
             k = int(math.ceil(.01 * self.q * dim_slice.numel()))
-            result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+            result = kthvalue(x, k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
         result = torch.clamp(result, max=self.zero())
         return result

@@ -134,8 +136,8 @@ def forward(self, x: Tensor) -> Tensor:
             low_k = int(math.ceil(.01 * self.low_q * x.numel()))
             # k is 1-indexed, so round away from zero
             high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5))
-            low_result = x.view(-1).kthvalue(low_k).values
-            high_result = x.view(-1).kthvalue(high_k).values
+            low_result = kthvalue(x.view(-1), low_k)[0]
+            high_result = kthvalue(x.view(-1), high_k)[0]
         else:
             # assuming x is two dimensional, get the other dimension
             assert len(x.size()) == 2, "Only 2-dim input is supported."
@@ -144,8 +146,8 @@
             low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel()))
             # k is 1-indexed, so round away from zero
             high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
-            low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
-            high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+            low_result = kthvalue(x, low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
+            high_result = kthvalue(x, high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
         # We need to make sure the lower bound is not positive to align with zero-point statistics
         low_result = torch.clamp(low_result, max=self.zero())
         interval = high_result - low_result
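For readers skimming the diff: every call site above follows the same pattern of deriving a 1-indexed k from the requested percentile and taking the k-th smallest value. A minimal standalone sketch of that pattern with the new wrapper (hypothetical input and percentile, not code from this PR):

import math

import torch

from brevitas.utils.torch_utils import kthvalue

# Hypothetical input: 99.9th percentile of the absolute values of a bfloat16 tensor
x = torch.randn(4096, dtype=torch.bfloat16)
q = 99.9
# k is 1-indexed, so round away from zero (same formula as AbsPercentile above)
k = int(math.floor(.01 * q * x.numel() + 0.5))
# The wrapper only falls back to float32 where torch.kthvalue lacks a kernel, then casts back
result = kthvalue(x.abs().view(-1), k)[0]
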
28 changes: 23 additions & 5 deletions src/brevitas/quant_tensor/__init__.py
@@ -15,6 +15,7 @@
 from .torch_handler import QUANT_TENSOR_FN_HANDLER

 IS_VALID_ATOL = 2e-1
+BFLOAT16_IS_VALID_ATOL = 0.5


 class QuantTensorBase(NamedTuple):
@@ -104,8 +105,15 @@ def is_not_none(self):

     @property
     def _pre_round_int_value(self):
-        int_value = self.value / self.scale
-        int_value = int_value + self.zero_point
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
+        if self.scale.dtype == torch.bfloat16:
+            value = self.value.type(torch.float32)
+            scale = self.scale.type(torch.float32)
+            zero_point = self.zero_point.type(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
         return int_value

     @property
@@ -114,8 +122,9 @@ def is_valid(self):
         with torch.no_grad():
             pre_round_int_value = self._pre_round_int_value
             rounded_int_value = torch.round(pre_round_int_value)
-            is_int = torch.isclose(
-                pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all()
+            max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
+            atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
+            is_int = max_abs_diff < atol
             if self.bit_width >= 2:
                 if self.signed:
                     is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
@@ -176,7 +185,12 @@ def int(self, float_datatype=False):
         if self.is_valid:
             int_value = round_ste(self._pre_round_int_value)
             if float_datatype:
-                return int_value
+                # Values at 8bit and lower can be represented exactly with float16 and bfloat16
+                # otherwise (e.g. Int16 bias), we upscale to float32
+                if self.bit_width <= 8.:
+                    return int_value.type(self.scale.dtype)
+                else:
+                    return int_value.type(torch.float32)
             else:
                 if self.bit_width <= 8. and self.signed_t.item():
                     return int_value.to(torch.int8)
@@ -301,6 +315,8 @@ def cat(tensors, dim, out=None):

     def __neg__(self):
         neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
+        # In case the dtype of self.int is different from the one of the scale
+        neg_value = neg_value.type(self.scale.dtype)
         if self.signed:
             return QuantTensor(
                 value=neg_value,
@@ -432,6 +448,8 @@ def __truediv__(self, other):
     def __abs__(self):
         if self.signed:
             abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
+            # In case the dtype of self.int is different from the one of the scale
+            abs_value = abs_value.type(self.scale.dtype)
             return QuantTensor(
                 value=abs_value,
                 scale=self.scale,
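A rough illustration of why the bfloat16 path above upcasts to float32 and uses the wider 0.5 tolerance (a standalone sketch with made-up numbers, not code from this PR): bfloat16 carries only about three significant decimal digits, so value / scale can land several tenths away from the integer it is meant to reconstruct.

import torch

scale = torch.tensor(0.0137)
int_value = torch.tensor(93.)
value = int_value * scale

# float32 round trip: essentially exact
print((value / scale) - int_value)

# bfloat16 round trip: the recovered integer can be off by a few tenths, which is why
# _pre_round_int_value upcasts to float32 and is_valid uses BFLOAT16_IS_VALID_ATOL = 0.5
value_bf16, scale_bf16 = value.bfloat16(), scale.bfloat16()
print((value_bf16 / scale_bf16).float() - int_value)
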
34 changes: 34 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause

 import copy
+from typing import Optional, Tuple

 import torch
 from torch.nn import Sequential
@@ -46,3 +47,36 @@ def torch_partial_deepcopy(model):
         memo[id(p)] = copy.copy(p)  # Shallow copy of parameters
     model_copy = copy.deepcopy(model, memo)
     return model_copy
+
+
+def kthvalue(
+        x: torch.Tensor,
+        k: int,
+        dim: Optional[int] = None,
+        keepdim: bool = False,
+        out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None
+) -> Tuple[torch.Tensor, torch.LongTensor]:
+    # As of torch 2.1, there is no kthvalue implementation:
+    # - In CPU for float16
+    # - In GPU for bfloat16
+    # In these cases we cast to float32 and then go back to the original dtype
+    dtype = x.dtype
+    device = str(x.device)
+
+    # We do not support out as buffer for the output, since we cannot control its dtype
+    if out is not None:
+        raise RuntimeError("out argument for kthvalue not supported")
+
+    if (dtype == torch.float16 and 'cpu' in device) or \
+            (dtype == torch.bfloat16 and 'cuda' in device):
+        x = x.type(torch.float32)
+
+    # PyTorch specify None as default for `dim` but it breaks if we specifically pass None
+    if dim is not None:
+        x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
+    else:
+        x, indices = torch.kthvalue(x, k, keepdim=keepdim)
+
+    if x.dtype != dtype:
+        x = x.type(dtype)
+    return (x, indices)
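A quick usage sketch for the wrapper above (hypothetical, not part of the PR): it behaves like torch.kthvalue where a native kernel exists, and should round-trip through float32 where one does not, returning values in the original dtype.

import torch

from brevitas.utils.torch_utils import kthvalue

# float16 on CPU has no native kthvalue kernel, so the wrapper casts internally
x = torch.randn(10, dtype=torch.float16)
values, indices = kthvalue(x, 3)
assert values.dtype == torch.float16  # cast back to the input dtype

# Same idea for bfloat16 on CUDA, here with an explicit reduction dim
if torch.cuda.is_available():
    x_gpu = torch.randn(4, 8, device='cuda', dtype=torch.bfloat16)
    values, indices = kthvalue(x_gpu, 3, dim=1, keepdim=True)
    assert values.dtype == torch.bfloat16
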
@@ -72,6 +72,8 @@
     metavar='ARCH',
     choices=model_names,
     help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
+parser.add_argument(
+    '--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
 parser.add_argument(
     '--target-backend',
     default='fx',
@@ -215,6 +217,7 @@
     default=None,
     type=int,
     help='Accumulator Bit Width for GPFA2Q (default: None)')
+parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -226,11 +229,11 @@

 def main():
     args = parser.parse_args()
+    dtype = getattr(torch, args.dtype)

     random.seed(SEED)
     np.random.seed(SEED)
     torch.manual_seed(SEED)
-
     if args.act_quant_calibration_type == 'stats':
         act_quant_calib_config = str(args.act_quant_percentile) + 'stats'
     else:
@@ -312,14 +315,15 @@ def main():

     # Get the model from torchvision
     model = get_torchvision_model(args.model_name)
+    model = model.to(dtype)

     # Preprocess the model for quantization
     if args.target_backend == 'flexml':
         # flexml requires static shapes, pass a representative input in
         img_shape = model_config['center_crop_shape']
         model = preprocess_for_flexml_quantize(
             model,
-            torch.ones(1, 3, img_shape, img_shape),
+            torch.ones(1, 3, img_shape, img_shape, dtype=dtype),
             equalize_iters=args.graph_eq_iterations,
             equalize_merge_bias=args.graph_eq_merge_bias,
             merge_bn=not args.calibrate_bn)
@@ -339,6 +343,7 @@ def main():
     # Define the quantized model
     quant_model = quantize_model(
         model,
+        dtype=dtype,
         backend=args.target_backend,
         scale_factor_type=args.scale_factor_type,
         bias_bit_width=args.bias_bit_width,
@@ -405,7 +410,7 @@ def main():

     # Validate the quant_model on the validation dataloader
     print("Starting validation:")
-    validate(val_loader, quant_model)
+    validate(val_loader, quant_model, stable=dtype != torch.bfloat16)

     if args.export_onnx_qcdq or args.export_torch_qcdq:
         # Generate reference input tensor to drive the export process
@@ -418,7 +423,7 @@ def main():
         export_name = os.path.join(args.export_dir, config)
         if args.export_onnx_qcdq:
             export_name = export_name + '.onnx'
-            export_onnx_qcdq(model, ref_input, export_name)
+            export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version)
         if args.export_torch_qcdq:
             export_name = export_name + '.pt'
             export_torch_qcdq(model, ref_input, export_name)
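For reference, the new --dtype flag feeds straight into getattr(torch, args.dtype), so the model, the reference input and the export path all run in the chosen dtype. A minimal standalone sketch of that pattern (hypothetical toy model, not the example script itself):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
args = parser.parse_args(['--dtype', 'bfloat16'])

dtype = getattr(torch, args.dtype)  # torch.float or torch.bfloat16
model = torch.nn.Linear(4, 4).to(dtype)
ref_input = torch.ones(1, 4, dtype=dtype)
print(model(ref_input).dtype)  # torch.bfloat16
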
4 changes: 2 additions & 2 deletions src/brevitas_examples/imagenet_classification/utils.py
@@ -61,7 +61,7 @@ def accuracy(output, target, topk=(1,), stable=False):
     return res


-def validate(val_loader, model):
+def validate(val_loader, model, stable=True):
     """
     Run validation on the desired dataset
     """
@@ -82,7 +82,7 @@ def print_accuracy(top1, prefix=''):

             output = model(images)
             # measure accuracy
-            acc1, = accuracy(output, target, stable=True)
+            acc1, = accuracy(output, target, stable=stable)
             top1.update(acc1[0], images.size(0))

     print_accuracy(top1, 'Total:')
6 changes: 4 additions & 2 deletions tests/brevitas/core/test_stats.py
@@ -8,6 +8,8 @@
 from brevitas.core.stats import AbsPercentile
 from brevitas.core.stats import NegativePercentileOrZero
 from brevitas.core.stats import PercentileInterval
+# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
+from brevitas.utils.torch_utils import kthvalue


 def test_abs_percentile_per_tensor():
@@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None):
         low_p, high_p = None, None
         if low_q is not None:
             k = int(math.ceil(.01 * low_q * x.numel()))
-            low_p = x.view(-1).kthvalue(k).values
+            low_p = kthvalue(x.view(-1), k=k)[0]
         if high_q is not None:
             k = int(math.floor(.01 * high_q * x.numel() + 0.5))
-            high_p = x.view(-1).kthvalue(k).values
+            high_p = kthvalue(x.view(-1), k=k)[0]
         return low_p, high_p

     def test_negative_percentile(self):
4 changes: 3 additions & 1 deletion tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,8 @@
 from brevitas.graph.calibrate import calibration_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
+from brevitas.utils.torch_utils import kthvalue
 from tests.brevitas.hyp_helper import float_tensor_random_size_st

 IN_CH = 8
@@ -21,7 +23,7 @@

 def compute_quantile(x, q):
     k = int(math.floor(.01 * q * x.numel() + 0.5))
-    result = x.abs().view(-1).kthvalue(k).values
+    result = kthvalue(x.abs().view(-1), k=k)[0]
     return result