Skip to content

Commit

Permalink
Feat (ptq/evaluate): support for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 22, 2023
1 parent af95707 commit ade1036
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
metavar='ARCH',
choices=model_names,
help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument(
'--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
parser.add_argument(
'--target-backend',
default='fx',
Expand Down Expand Up @@ -215,6 +217,7 @@
default=None,
type=int,
help='Accumulator Bit Width for GPFA2Q (default: None)')
parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
Expand All @@ -226,11 +229,11 @@

def main():
args = parser.parse_args()
dtype = getattr(torch, args.dtype)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if args.act_quant_calibration_type == 'stats':
act_quant_calib_config = str(args.act_quant_percentile) + 'stats'
else:
Expand Down Expand Up @@ -312,14 +315,15 @@ def main():

# Get the model from torchvision
model = get_torchvision_model(args.model_name)
model = model.to(dtype)

# Preprocess the model for quantization
if args.target_backend == 'flexml':
# flexml requires static shapes, pass a representative input in
img_shape = model_config['center_crop_shape']
model = preprocess_for_flexml_quantize(
model,
torch.ones(1, 3, img_shape, img_shape),
torch.ones(1, 3, img_shape, img_shape, dtype=dtype),
equalize_iters=args.graph_eq_iterations,
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn)
Expand All @@ -339,6 +343,7 @@ def main():
# Define the quantized model
quant_model = quantize_model(
model,
dtype=dtype,
backend=args.target_backend,
scale_factor_type=args.scale_factor_type,
bias_bit_width=args.bias_bit_width,
Expand Down Expand Up @@ -405,7 +410,7 @@ def main():

# Validate the quant_model on the validation dataloader
print("Starting validation:")
validate(val_loader, quant_model)
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)

if args.export_onnx_qcdq or args.export_torch_qcdq:
# Generate reference input tensor to drive the export process
Expand All @@ -418,7 +423,7 @@ def main():
export_name = os.path.join(args.export_dir, config)
if args.export_onnx_qcdq:
export_name = export_name + '.onnx'
export_onnx_qcdq(model, ref_input, export_name)
export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version)
if args.export_torch_qcdq:
export_name = export_name + '.pt'
export_torch_qcdq(model, ref_input, export_name)
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas_examples/imagenet_classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def accuracy(output, target, topk=(1,), stable=False):
return res


def validate(val_loader, model):
def validate(val_loader, model, stable=True):
"""
Run validation on the desired dataset
"""
Expand All @@ -82,7 +82,7 @@ def print_accuracy(top1, prefix=''):

output = model(images)
# measure accuracy
acc1, = accuracy(output, target, stable=True)
acc1, = accuracy(output, target, stable=stable)
top1.update(acc1[0], images.size(0))

print_accuracy(top1, 'Total:')
Expand Down

0 comments on commit ade1036

Please sign in to comment.