diff --git a/examples/3.x_api/pytorch/cv/static_quant/main.py b/examples/3.x_api/pytorch/cv/static_quant/main.py
index 3ab2d6bd6ad..3d7af7827e3 100644
--- a/examples/3.x_api/pytorch/cv/static_quant/main.py
+++ b/examples/3.x_api/pytorch/cv/static_quant/main.py
@@ -4,30 +4,25 @@
 import shutil
 import time
 import warnings
-from enum import Enum
+import sys
 
 import torch
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.parallel
+import torch.distributed as dist
 import torch.optim
+import torch.multiprocessing as mp
 import torch.utils.data
 import torch.utils.data.distributed
+import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
-import torchvision.transforms as transforms
-from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import Subset
 
-model_names = sorted(name for name in models.__dict__
-    if name.islower() and not name.startswith("__")
-    and callable(models.__dict__[name]))
+model_names = models.list_models(module=models)
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
-                    help='path to dataset (default: imagenet)')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
@@ -57,6 +52,8 @@
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                     help='evaluate model on validation set')
+parser.add_argument('-t', '--tune', dest='tune', action='store_true',
+                    help='tune best int8 model on calibration dataset')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
 parser.add_argument('--world-size', default=-1, type=int,
@@ -71,79 +68,45 @@
                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
+parser.add_argument('--ppn', default=1, type=int,
+                    help='number of processes on each node of distributed training')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
-parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
-parser.add_argument('-q', '--quantize', dest='quantize', action='store_true',
-                    help='quantize model')
-parser.add_argument("--calib_iters", default=2, type=int,
+parser.add_argument('-i', "--iter", default=0, type=int,
+                    help='For accuracy measurement only.')
+parser.add_argument('-w', "--warmup_iter", default=5, type=int,
+                    help='For benchmark measurement only.')
+parser.add_argument('--performance', dest='performance', action='store_true',
+                    help='run benchmark')
+parser.add_argument('-r', "--accuracy", dest='accuracy', action='store_true',
+                    help='For accuracy measurement only.')
+parser.add_argument("--tuned_checkpoint", default='./saved_results', type=str, metavar='PATH',
+                    help='path to checkpoint tuned by Neural Compressor (default: ./saved_results)')
+parser.add_argument('--int8', dest='int8', action='store_true',
+                    help='Load int8 model.')
+parser.add_argument("--calib_iters", default=128, type=int,
                     help="For calibration only.")
+parser.add_argument("--iters", default=100, type=int,
+                    help="For benchmark only.")
 
 best_acc1 = 0
 
 
 def main():
     args = parser.parse_args()
+
+    if 'mobilenet' in args.arch:
+        import torchvision.models.quantization as models
+    else:
+        import torchvision.models as models
 
     if args.seed is not None:
         random.seed(args.seed)
         torch.manual_seed(args.seed)
-        cudnn.deterministic = True
-        cudnn.benchmark = False
-        warnings.warn('You have chosen to seed training. '
-                      'This will turn on the CUDNN deterministic setting, '
-                      'which can slow down your training considerably! '
-                      'You may see unexpected behavior when restarting '
-                      'from checkpoints.')
-
-    if args.gpu is not None:
-        warnings.warn('You have chosen a specific GPU. This will completely '
-                      'disable data parallelism.')
-
-    if args.dist_url == "env://" and args.world_size == -1:
-        args.world_size = int(os.environ["WORLD_SIZE"])
-
-    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
-
-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
-        if ngpus_per_node == 1 and args.dist_backend == "nccl":
-            warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
-    else:
-        ngpus_per_node = 1
-
-    if args.multiprocessing_distributed:
-        # Since we have ngpus_per_node processes per node, the total world_size
-        # needs to be adjusted accordingly
-        args.world_size = ngpus_per_node * args.world_size
-        # Use torch.multiprocessing.spawn to launch distributed processes: the
-        # main_worker process function
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
-    else:
-        # Simply call main_worker function
-        main_worker(args.gpu, ngpus_per_node, args)
-
-
-def main_worker(gpu, ngpus_per_node, args):
-    global best_acc1
-    args.gpu = gpu
-
-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
-
-    if args.distributed:
-        if args.dist_url == "env://" and args.rank == -1:
-            args.rank = int(os.environ["RANK"])
-        if args.multiprocessing_distributed:
-            # For multiprocessing distributed training, rank needs to be the
-            # global rank among all the processes
-            args.rank = args.rank * ngpus_per_node + gpu
-        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
-                                world_size=args.world_size, rank=args.rank)
-    # create model
+
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
         model = models.__dict__[args.arch](pretrained=True)
@@ -151,70 +114,18 @@ def main_worker(gpu, ngpus_per_node, args):
         print("=> creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
-        print('using CPU, this will be slow')
-    elif args.distributed:
-        # For multiprocessing distributed, DistributedDataParallel constructor
-        # should always set the single device scope, otherwise,
-        # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
-            if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
-                # When using a single GPU per process and per
-                # DistributedDataParallel, we need to divide the batch size
-                # ourselves based on the total number of GPUs of the current node.
-                args.batch_size = int(args.batch_size / ngpus_per_node)
-                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
-        # DataParallel will divide and allocate batch_size to all available GPUs
-        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
-            model.features = torch.nn.DataParallel(model.features)
-            model.cuda()
-        else:
-            model = torch.nn.DataParallel(model).cuda()
-
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-    else:
-        device = torch.device("cpu")
 
-    # define loss function (criterion), optimizer, and learning rate scheduler
-    criterion = nn.CrossEntropyLoss().to(device)
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss()
 
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
-
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
-
+
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
-            if args.gpu is None:
-                checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
-                # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
-                checkpoint = torch.load(args.resume, map_location=loc)
+            checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
             if args.gpu is not None:
@@ -222,95 +133,77 @@ def main_worker(gpu, ngpus_per_node, args):
                 best_acc1 = best_acc1.to(args.gpu)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
-            scheduler.load_state_dict(checkpoint['scheduler'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
             print("=> no checkpoint found at '{}'".format(args.resume))
 
-    # Data loading code
-    if args.dummy:
-        print("=> Dummy data is used!")
-        train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
-        val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
-    else:
-        traindir = os.path.join(args.data, 'train')
-        valdir = os.path.join(args.data, 'val')
-        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
 
-        train_dataset = datasets.ImageFolder(
-            traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-
-        val_dataset = datasets.ImageFolder(
-            valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-
-    if args.distributed:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
-    else:
-        train_sampler = None
-        val_sampler = None
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
 
     train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+        train_dataset, batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True, sampler=None)
+
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        normalize,
+    ]))
 
     val_loader = torch.utils.data.DataLoader(
-        val_dataset, batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
+        val_dataset,
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    if args.evaluate:
+        validate(val_loader, model, criterion, args)
+        return
+
+    def eval_func(model):
+        accu = validate(val_loader, model, criterion, args)
+        return float(accu)
 
-    if args.quantize:
+    if args.tune:
         from neural_compressor.torch.export import export
         from neural_compressor.torch.quantization import prepare, convert, get_default_static_config
 
         # Prepare the float model and example inputs for exporting model
         x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
         example_inputs = (x,)
+
+        # Specify that the first dimension of each input is the batch size
+        from torch.export import Dim
+        batch = Dim("batch", min=16)
+        dynamic_shapes = {"x": {0: batch}}
 
         # Export eager model into FX graph model
-        exported_model = export(model=model, example_inputs=example_inputs)
+        exported_model = export(model=model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)
 
         # Quantize the model
         quant_config = get_default_static_config()
         prepared_model = prepare(exported_model, quant_config=quant_config)
 
         # Calibrate
-        for i in range(args.calib_iters):
-            prepared_model(*example_inputs)
-        q_model = convert(prepared_model)
-        # Compile the quantized model and replace the Q/DQ pattern with Q-operator
-        from torch._inductor import config
-
-        config.freezing = True
-        opt_model = torch.compile(q_model)
-        model = opt_model
-
-
-    if args.evaluate:
-        validate(val_loader, model, criterion, args)
-        return
-
-
-def validate(val_loader, model, criterion, args):
-
-    def run_validate(loader, base_progress=0):
         with torch.no_grad():
-            end = time.time()
-            for i, (images, target) in enumerate(loader):
-                i = base_progress + i
+            for i, (images, target) in enumerate(val_loader):
+                if i == args.calib_iters:
+                    break
                 if args.gpu is not None and torch.cuda.is_available():
                     images = images.cuda(args.gpu, non_blocking=True)
                 if torch.backends.mps.is_available():
@@ -318,52 +211,155 @@ def run_validate(loader, base_progress=0):
                     target = target.to('mps')
                 if torch.cuda.is_available():
                     target = target.cuda(args.gpu, non_blocking=True)
-                # compute output
-                output = model(images)
-                loss = criterion(output, target)
-
-                # measure accuracy and record loss
-                acc1, acc5 = accuracy(output, target, topk=(1, 5))
-                losses.update(loss.item(), images.size(0))
-                top1.update(acc1[0], images.size(0))
-                top5.update(acc5[0], images.size(0))
-
-                # measure elapsed time
-                batch_time.update(time.time() - end)
-                end = time.time()
-
-                if i % args.print_freq == 0:
-                    progress.display(i + 1)
-
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
-    progress = ProgressMeter(
-        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
-        [batch_time, losses, top1, top5],
-        prefix='Test: ')
-
-    # switch to evaluate mode, pt2e no eval() or train()
+                prepared_model(images)
+
+        q_model = convert(prepared_model)
+
+        if args.tuned_checkpoint:
+            q_model.save(example_inputs=example_inputs, output_dir=args.tuned_checkpoint)
+        return
+
+    if args.performance or args.accuracy:
+        if args.int8:
+            from neural_compressor.torch.quantization import load
+            q_model = load(args.tuned_checkpoint)
+
+            # Compile the quantized model and replace the Q/DQ pattern with Q-operator
+            from torch._inductor import config
+
+            config.freezing = True
+            opt_model = torch.compile(q_model)
+            new_model = opt_model
+        else:
+            new_model = model
+        new_model.eval()
+        if args.performance:
+            benchmark(val_loader, new_model, args)
+            return
+        if args.accuracy:
+            validate(val_loader, new_model, criterion, args)
+            return
+
+
+def benchmark(val_loader, model, args):
+
+    total_iters = args.iters
+    warmup_iters = args.warmup_iter
+    for i, (images, target) in enumerate(val_loader):
+        if args.gpu is not None and torch.cuda.is_available():
+            images = images.cuda(args.gpu, non_blocking=True)
+        if torch.backends.mps.is_available():
+            images = images.to('mps')
+        break
+
+    with torch.no_grad():
+        for i in range(total_iters):
+            if i == warmup_iters:
+                start = time.time()
+            # model inference
+            model(images)
+
+            if i % args.print_freq == 0:
+                print(f"benchmarking... {i+1}/{total_iters}")
+
+    end = time.time()
+    latency = (end - start) / ((total_iters - warmup_iters) * args.batch_size)
+    throughput = ((total_iters - warmup_iters) * args.batch_size) / (end - start)
+    print("Latency: {:.3f} ms".format(latency * 10**3))
+    print("Throughput: {:.3f} samples/sec".format(throughput))
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
+                             top5, prefix="Epoch: [{}]".format(epoch))
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        if args.gpu is not None:
+            input = input.cuda(args.gpu, non_blocking=True)
+        target = target.cuda(args.gpu, non_blocking=True)
+
+        # compute output
+        output = model(input)
+        loss = criterion(output, target)
+
+        # measure accuracy and record loss
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), input.size(0))
+        top1.update(acc1[0], input.size(0))
+        top5.update(acc5[0], input.size(0))
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.print(i)
+
+
+def validate(val_loader, model, criterion, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
+                             prefix='Test: ')
+
+    # switch to evaluate mode
     # model.eval()
-    run_validate(val_loader)
-    if args.distributed:
-        top1.all_reduce()
-        top5.all_reduce()
+    with torch.no_grad():
+        for i, (input, target) in enumerate(val_loader):
+            if i >= args.warmup_iter:
+                start = time.time()
+            if args.gpu is not None:
+                input = input.cuda(args.gpu, non_blocking=True)
+                target = target.cuda(args.gpu, non_blocking=True)
+
+            # compute output
+            output = model(input)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), input.size(0))
+            top1.update(acc1[0], input.size(0))
+            top5.update(acc5[0], input.size(0))
+
+            # measure elapsed time
+            if i >= args.warmup_iter:
+                batch_time.update(time.time() - start)
+
+            if i % args.print_freq == 0:
+                progress.print(i)
 
-    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
-        aux_val_dataset = Subset(val_loader.dataset,
-                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
-        aux_val_loader = torch.utils.data.DataLoader(
-            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
-            num_workers=args.workers, pin_memory=True)
-        run_validate(aux_val_loader, len(val_loader))
+            if args.iter > 0 and i >= (args.warmup_iter + args.iter - 1):
+                break
 
-    progress.display_summary()
+    print('Batch size = %d' % args.batch_size)
+    print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}'
+          .format(top1=(top1.avg / 100), top5=(top5.avg / 100)))
 
-    return top1.avg
+    return top1.avg / 100
 
 
 def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
@@ -371,18 +367,11 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
     if is_best:
         shutil.copyfile(filename, 'model_best.pth.tar')
 
-class Summary(Enum):
-    NONE = 0
-    AVERAGE = 1
-    SUM = 2
-    COUNT = 3
-
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, fmt=':f'):
         self.name = name
         self.fmt = fmt
-        self.summary_type = summary_type
         self.reset()
 
     def reset(self):
@@ -397,59 +386,35 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
-    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
-        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
-        self.sum, self.count = total.tolist()
-        self.avg = self.sum / self.count
-
     def __str__(self):
         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
         return fmtstr.format(**self.__dict__)
-
-    def summary(self):
-        fmtstr = ''
-        if self.summary_type is Summary.NONE:
-            fmtstr = ''
-        elif self.summary_type is Summary.AVERAGE:
-            fmtstr = '{name} {avg:.3f}'
-        elif self.summary_type is Summary.SUM:
-            fmtstr = '{name} {sum:.3f}'
-        elif self.summary_type is Summary.COUNT:
-            fmtstr = '{name} {count:.3f}'
-        else:
-            raise ValueError('invalid summary type %r' % self.summary_type)
-
-        return fmtstr.format(**self.__dict__)
 
 
 class ProgressMeter(object):
-    def __init__(self, num_batches, meters, prefix=""):
+    def __init__(self, num_batches, *meters, prefix=""):
         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
         self.meters = meters
         self.prefix = prefix
 
-    def display(self, batch):
+    def print(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
         entries += [str(meter) for meter in self.meters]
         print('\t'.join(entries))
-
-    def display_summary(self):
-        entries = [" *"]
-        entries += [meter.summary() for meter in self.meters]
-        print(' '.join(entries))
 
     def _get_batch_fmtstr(self, num_batches):
         num_digits = len(str(num_batches // 1))
         fmt = '{:' + str(num_digits) + 'd}'
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // 30))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
@@ -468,4 +433,5 @@ def accuracy(output, target, topk=(1,)):
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
+
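The tune path above is easier to see outside the diff, so here is a minimal standalone sketch of the same PT2E static quantization flow, using only the API calls visible in main.py (export, prepare, convert, save, and torch.export.Dim). The random calibration tensors and the output directory are placeholder assumptions; real code iterates the validation loader as the script does.

```python
# Minimal sketch of the PT2E static quantization flow implemented in main.py.
# Assumes the neural_compressor 3.x torch API; calibration data is faked here.
import torch
import torchvision.models as models
from torch.export import Dim

from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import convert, get_default_static_config, prepare

model = models.resnet18(pretrained=True)
model.eval()

# Example input with a dynamic batch dimension, mirroring the diff.
x = torch.randn(16, 3, 224, 224).contiguous(memory_format=torch.channels_last)
batch = Dim("batch", min=16)
dynamic_shapes = {"x": {0: batch}}

# Export the eager model to an FX graph, then prepare, calibrate, and convert.
exported_model = export(model=model, example_inputs=(x,), dynamic_shapes=dynamic_shapes)
prepared_model = prepare(exported_model, quant_config=get_default_static_config())
for _ in range(8):  # calibration; the script feeds val_loader batches instead
    prepared_model(torch.randn(16, 3, 224, 224))
q_model = convert(prepared_model)

# Persist the quantized model so it can later be reloaded with --int8.
q_model.save(example_inputs=(x,), output_dir="./saved_results")
```

Declaring the batch dimension dynamic at export time is what lets the benchmark and accuracy runs use a batch size different from the one used during calibration.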
diff --git a/examples/3.x_api/pytorch/cv/static_quant/run_benchmark.sh b/examples/3.x_api/pytorch/cv/static_quant/run_benchmark.sh
new file mode 100644
index 00000000000..6f6b69c35df
--- /dev/null
+++ b/examples/3.x_api/pytorch/cv/static_quant/run_benchmark.sh
@@ -0,0 +1,103 @@
+#!/bin/bash
+set -x
+
+function main {
+
+  init_params "$@"
+  run_benchmark
+
+}
+
+# init params
+function init_params {
+  iters=100
+  batch_size=16
+  tuned_checkpoint=saved_results
+  echo ${max_eval_samples}
+  for var in "$@"
+  do
+    case $var in
+      --topology=*)
+          topology=$(echo $var |cut -f2 -d=)
+      ;;
+      --dataset_location=*)
+          dataset_location=$(echo $var |cut -f2 -d=)
+      ;;
+      --input_model=*)
+          input_model=$(echo $var |cut -f2 -d=)
+      ;;
+      --mode=*)
+          mode=$(echo $var |cut -f2 -d=)
+      ;;
+      --batch_size=*)
+          batch_size=$(echo $var |cut -f2 -d=)
+      ;;
+      --iters=*)
+          iters=$(echo ${var} |cut -f2 -d=)
+      ;;
+      --int8=*)
+          int8=$(echo ${var} |cut -f2 -d=)
+      ;;
+      --config=*)
+          tuned_checkpoint=$(echo $var |cut -f2 -d=)
+      ;;
+      *)
+          echo "Error: No such parameter: ${var}"
+          exit 1
+      ;;
+    esac
+  done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+    extra_cmd=''
+
+    if [[ ${mode} == "accuracy" ]]; then
+        mode_cmd=" --accuracy "
+    elif [[ ${mode} == "performance" ]]; then
+        mode_cmd=" --performance --iters "${iters}
+    else
+        echo "Error: No such mode: ${mode}"
+        exit 1
+    fi
+    if [[ ${int8} == "true" ]]; then
+        extra_cmd=$extra_cmd" --int8"
+    fi
+    echo $extra_cmd
+
+    if [ "${topology}" = "resnet18_pt2e_static" ]; then
+        model_name_or_path="resnet18"
+    fi
+
+    if [[ ${mode} == "accuracy" ]]; then
+        python main.py \
+               --pretrained \
+               -a resnet18 \
+               -b 30 \
+               --tuned_checkpoint ${tuned_checkpoint} \
+               ${dataset_location} \
+               ${extra_cmd} \
+               ${mode_cmd}
+    elif [[ ${mode} == "performance" ]]; then
+        incbench --num_cores_per_instance 4 \
+                 main.py \
+                 --pretrained \
+                 -a resnet18 \
+                 -b 30 \
+                 --tuned_checkpoint ${tuned_checkpoint} \
+                 ${dataset_location} \
+                 ${extra_cmd} \
+                 ${mode_cmd}
+    else
+        echo "Error: No such mode: ${mode}"
+        exit 1
+    fi
+}
+
+main "$@"
diff --git a/examples/3.x_api/pytorch/cv/static_quant/run_quant.sh b/examples/3.x_api/pytorch/cv/static_quant/run_quant.sh
index ac4a5a2b668..1f4588e933c 100644
--- a/examples/3.x_api/pytorch/cv/static_quant/run_quant.sh
+++ b/examples/3.x_api/pytorch/cv/static_quant/run_quant.sh
@@ -10,6 +10,7 @@ function main {
 
 # init params
 function init_params {
+  tuned_checkpoint="saved_results"
   for var in "$@"
   do
     case $var in
@@ -39,7 +40,13 @@ function run_tuning {
     if [ "${topology}" = "resnet18_pt2e_static" ]; then
         model_name_or_path="resnet18"
     fi
-    python main.py -a ${model_name_or_path} ${dataset_location} -q -e
+    python main.py \
+           --pretrained \
+           -t \
+           -a resnet18 \
+           -b 30 \
+           --tuned_checkpoint ${tuned_checkpoint} \
+           ${dataset_location}
 }
 
 main "$@"
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_benchmark.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_benchmark.sh
new file mode 100644
index 00000000000..169142cddb8
--- /dev/null
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_benchmark.sh
@@ -0,0 +1,99 @@
+#!/bin/bash
+set -x
+
+function main {
+
+  init_params "$@"
+  run_benchmark
+
+}
+
+# init params
+function init_params {
+  iters=100
+  batch_size=16
+  tuned_checkpoint=saved_results
+  task=lambada_openai
+  echo ${max_eval_samples}
+  for var in "$@"
+  do
+    case $var in
+      --topology=*)
+          topology=$(echo $var |cut -f2 -d=)
+      ;;
+      --dataset_location=*)
+          dataset_location=$(echo $var |cut -f2 -d=)
+      ;;
+      --input_model=*)
+          input_model=$(echo $var |cut -f2 -d=)
+      ;;
+      --mode=*)
+          mode=$(echo $var |cut -f2 -d=)
+      ;;
+      --batch_size=*)
+          batch_size=$(echo $var |cut -f2 -d=)
+      ;;
+      --iters=*)
+          iters=$(echo ${var} |cut -f2 -d=)
+      ;;
+      --int8=*)
+          int8=$(echo ${var} |cut -f2 -d=)
+      ;;
+      --config=*)
+          tuned_checkpoint=$(echo $var |cut -f2 -d=)
+      ;;
+      *)
+          echo "Error: No such parameter: ${var}"
+          exit 1
+      ;;
+    esac
+  done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+    extra_cmd=''
+
+    if [[ ${mode} == "accuracy" ]]; then
+        mode_cmd=" --accuracy "
+        extra_cmd=$extra_cmd
+    elif [[ ${mode} == "performance" ]]; then
+        mode_cmd=" --performance --iters "${iters}
+        extra_cmd=$extra_cmd
+    else
+        echo "Error: No such mode: ${mode}"
+        exit 1
+    fi
+
+    if [[ ${int8} == "true" ]]; then
+        extra_cmd=$extra_cmd" --int8"
+    fi
+    echo $extra_cmd
+
+    if [ "${topology}" = "opt_125m_pt2e_static" ]; then
+        model_name_or_path="facebook/opt-125m"
+    fi
+    if [[ ${mode} == "accuracy" ]]; then
+        python -u run_clm_no_trainer.py \
+            --model ${model_name_or_path} \
+            --output_dir ${tuned_checkpoint} \
+            --task ${task} \
+            --batch_size ${batch_size} \
+            ${extra_cmd} ${mode_cmd}
+    elif [[ ${mode} == "performance" ]]; then
+        incbench --num_cores_per_instance 4 run_clm_no_trainer.py \
+            --model ${model_name_or_path} \
+            --batch_size ${batch_size} \
+            --output_dir ${tuned_checkpoint} \
+            ${extra_cmd} ${mode_cmd}
+    else
+        echo "Error: No such mode: ${mode}"
+        exit 1
+    fi
+}
+
+main "$@"
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_clm_no_trainer.py
index 98d3f11a1dd..395bc6f9b57 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_clm_no_trainer.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_clm_no_trainer.py
@@ -14,7 +14,7 @@
     "--revision", default=None,
     help="Transformers parameter: set the model hub commit number")
 parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k")
-parser.add_argument("--output_dir", nargs="?", default="./saved_results")
+parser.add_argument("--output_dir", nargs="?", default="")
 parser.add_argument("--quantize", action="store_true")
 parser.add_argument("--approach", type=str, default='static',
                     help="Select from ['dynamic', 'static', 'weight-only']")
@@ -80,7 +80,7 @@ def get_example_inputs(tokenizer):
     dynamic_shapes = {"input_ids": (batch, seq_len)}
     example_inputs = get_example_inputs(tokenizer)
     exported_model = export(user_model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)
-    
+
     quant_config = get_default_static_config()
     # prepare
     prepare_model = prepare(exported_model, quant_config)
@@ -90,17 +90,32 @@ def get_example_inputs(tokenizer):
         prepare_model(*example_inputs)
     # convert
     converted_model = convert(prepare_model)
-    # inference
-    from torch._inductor import config
+
+    # save
+    if args.output_dir:
+        converted_model.save(example_inputs=example_inputs, output_dir=args.output_dir)
+
+
+if args.int8:
+    if args.output_dir:
+        print("Load int8 model.")
+        from neural_compressor.torch.quantization import load
+        model = load(args.output_dir)
 
-    config.freezing = True
-    opt_model = torch.compile(converted_model)
+        model.config = user_model.config  # for lm eval
+
+        # Compile the quantized model and replace the Q/DQ pattern with Q-operator
+        from torch._inductor import config
 
-    opt_model.config = user_model.config # for lm eval
-    user_model = opt_model
+        config.freezing = True
+        opt_model = torch.compile(model)
+        opt_model.config = user_model.config  # for lm eval
+        user_model = opt_model
 
 if args.accuracy:
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
     eval_args = LMEvalParser(
         model="hf",
@@ -120,29 +135,21 @@ def get_example_inputs(tokenizer):
     print('Batch size = %d' % args.batch_size)
 
 if args.performance:
-    # user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    batch_size, input_leng = args.batch_size, 512
+    example_inputs = torch.ones((batch_size, input_leng), dtype=torch.long)
+    print("Batch size = {:d}".format(batch_size))
+    print("The length of input tokens = {:d}".format(input_leng))
     import time
-    samples = args.iters * args.batch_size
-    eval_args = LMEvalParser(
-        model="hf",
-        user_model=user_model,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        limit=samples,
-        device="cpu",
-    )
-    start = time.time()
-    results = evaluate(eval_args)
-    end = time.time()
-    for task_name in args.tasks.split(","):
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity,none"]
-        else:
-            acc = results["results"][task_name]["acc,none"]
-    print("Accuracy: %.5f" % acc)
-    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+
+    total_iters = args.iters
+    warmup_iters = 5
+    with torch.no_grad():
+        for i in range(total_iters):
+            if i == warmup_iters:
+                start = time.time()
+            user_model(example_inputs)
+    end = time.time()
+    latency = (end - start) / ((total_iters - warmup_iters) * args.batch_size)
+    throughput = ((total_iters - warmup_iters) * args.batch_size) / (end - start)
+    print("Latency: {:.3f} ms".format(latency * 10**3))
+    print("Throughput: {:.3f} samples/sec".format(throughput))
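For the LLM example, the new --int8 path reloads the saved model and compiles it before measuring. A condensed sketch of that reload path, using only the calls shown in the diff (load, config.freezing, torch.compile); the checkpoint directory, batch size, and iteration counts mirror the script defaults and are otherwise arbitrary.

```python
# Sketch of the int8 reload-and-benchmark path added to run_clm_no_trainer.py.
# "saved_results" is the output_dir the tuning run writes to.
import time
import torch
from neural_compressor.torch.quantization import load

q_model = load("saved_results")  # restore the quantized exported model

# Compile the quantized model and replace the Q/DQ pattern with fused ops.
from torch._inductor import config
config.freezing = True
opt_model = torch.compile(q_model)

example_inputs = torch.ones((16, 512), dtype=torch.long)  # (batch, seq_len)
with torch.no_grad():
    for i in range(105):
        if i == 5:  # 5 warm-up iterations, 100 measured
            start = time.time()
        opt_model(example_inputs)
elapsed = time.time() - start
print("Throughput: {:.3f} samples/sec".format(100 * 16 / elapsed))
```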
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_quant.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_quant.sh
index 6bd599483ff..9e995ec8869 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_quant.sh
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e/run_quant.sh
@@ -39,8 +39,9 @@ function run_tuning {
 
     if [ "${topology}" = "opt_125m_pt2e_static" ]; then
         model_name_or_path="facebook/opt-125m"
+        output_dir="saved_results"
     fi
-    python run_clm_no_trainer.py --model ${model_name_or_path} --quantize --accuracy --tasks "lambada_openai"
+    python run_clm_no_trainer.py --model ${model_name_or_path} --quantize --output_dir ${output_dir} --tasks "lambada_openai"
 }
 
 main "$@"
diff --git a/neural_compressor/common/__init__.py b/neural_compressor/common/__init__.py
index e38627d5c7c..cbda53e57b3 100644
--- a/neural_compressor/common/__init__.py
+++ b/neural_compressor/common/__init__.py
@@ -15,6 +15,7 @@
 from neural_compressor.common.utils import (
     level,
+    level_name,
     logger,
     Logger,
     TuningLogger,
@@ -31,6 +32,7 @@
 __all__ = [
     "options",
     "level",
+    "level_name",
     "logger",
     "Logger",
     "TuningLogger",
diff --git a/neural_compressor/common/utils/logger.py b/neural_compressor/common/utils/logger.py
index 4c933368fdd..a7f0b06009f 100644
--- a/neural_compressor/common/utils/logger.py
+++ b/neural_compressor/common/utils/logger.py
@@ -24,6 +24,7 @@
 __all__ = [
     "level",
+    "level_name",
     "Logger",  # TODO: not expose it
     "logger",
     "TuningLogger",
@@ -138,6 +139,7 @@ def warning(msg, *args, **kwargs):
 
 level = Logger().get_logger().level
+level_name = logging.getLevelName(level)
 logger = Logger
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index bd1865e674c..759752b7c80 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -149,7 +149,8 @@ def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], tar
     for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
         apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
     utils.logger.info("Half precision conversion is done:")
-    gm.print_readable(True)
+    if utils.level_name == "DEBUG":  # pragma: no cover
+        gm.print_readable(True)
 
 
 # =============================================================================
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/save_load.py b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py
index 606c31f41c2..fb3473d17a8 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/save_load.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py
@@ -25,7 +25,8 @@ def save(model, example_inputs, output_dir="./saved_results"):
     os.makedirs(output_dir, exist_ok=True)
     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
-    quantized_ep = torch.export.export(model, example_inputs)
+    dynamic_shapes = model.dynamic_shapes
+    quantized_ep = torch.export.export(model, example_inputs, dynamic_shapes=dynamic_shapes)
     torch.export.save(quantized_ep, qmodel_file_path)
     for key, op_config in model.qconfig.items():
         model.qconfig[key] = op_config.to_dict()
diff --git a/neural_compressor/torch/export/pt2e_export.py b/neural_compressor/torch/export/pt2e_export.py
index 579e816894f..d187f9b5289 100644
--- a/neural_compressor/torch/export/pt2e_export.py
+++ b/neural_compressor/torch/export/pt2e_export.py
@@ -67,7 +67,9 @@ def export(
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
 ) -> Optional[GraphModule]:
     if not is_ipex_imported():
-        return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
+        model = export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes)
+        model.dynamic_shapes = dynamic_shapes
+        return model
     else:  # TODO, add `export` for ipex
         pass
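The save_load.py and pt2e_export.py changes work as a pair: export() now records the dynamic_shapes it was given on the returned module, and save() replays them into torch.export.export so the re-exported program keeps its dynamic dimensions. A toy round trip with plain torch.export (the Toy module and file name are made up for illustration) shows why this matters; without dynamic_shapes the saved program would be specialized to the batch size seen at calibration time.

```python
# Toy round trip: a dynamic dim recorded at export time survives save/load.
import torch
from torch.export import Dim, export, load, save

class Toy(torch.nn.Module):  # hypothetical stand-in for the quantized model
    def forward(self, x):
        return x.relu()

example_inputs = (torch.randn(16, 8),)
dynamic_shapes = {"x": {0: Dim("batch", min=2)}}  # batch dim stays symbolic

ep = export(Toy(), example_inputs, dynamic_shapes=dynamic_shapes)
save(ep, "toy.pt2")

# After reloading, other batch sizes still work because the dimension was
# recorded as dynamic in the saved ExportedProgram.
reloaded = load("toy.pt2")
print(reloaded.module()(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
```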
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 2a3eada9bf5..9281dd305e2 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -215,6 +215,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
     inplace = kwargs.get("inplace", True)
+    dynamic_shapes = model.dynamic_shapes
     W8A8PT2EQuantizer.is_dynamic = True
     for _, quant_config in configs_mapping.items():
         if quant_config.name == PT2E_DYNAMIC_QUANT:
@@ -222,6 +223,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
             model = w8a8_quantizer.execute(
                 model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
             )
+            model.dynamic_shapes = dynamic_shapes
             model.qconfig = configs_mapping
             model.save = MethodType(save, model)
     return model
@@ -238,12 +240,14 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
     run_fn = kwargs.get("run_fn", None)
     example_inputs = kwargs.get("example_inputs", None)
     inplace = kwargs.get("inplace", True)
+    dynamic_shapes = model.dynamic_shapes
     for _, quant_config in configs_mapping.items():
         if quant_config.name == STATIC_QUANT:
             w8a8_quantizer = W8A8PT2EQuantizer(quant_config=quant_config)
             model = w8a8_quantizer.execute(
                 model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
             )
+            model.dynamic_shapes = dynamic_shapes
             model.qconfig = configs_mapping
             model.save = MethodType(save, model)
     return model
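Finally, both algorithm_entry functions capture model.dynamic_shapes before w8a8_quantizer.execute runs and re-attach it afterwards: the quantizer returns a fresh graph module, which would otherwise silently drop the attribute that save() later reads. A toy illustration of that attribute-preservation pattern, where symbolic_trace stands in for the quantizer and the metadata value is made up:

```python
# Graph transformations return a new module that does not inherit ad-hoc
# attributes, so metadata must be captured before and re-attached after.
import torch

def transform(model: torch.nn.Module) -> torch.nn.Module:
    return torch.fx.symbolic_trace(model)  # stand-in for w8a8_quantizer.execute

model = torch.nn.Linear(4, 4)
model.dynamic_shapes = {"x": {0: "batch"}}  # hypothetical metadata

dynamic_shapes = model.dynamic_shapes       # capture before the transformation
new_model = transform(model)
assert not hasattr(new_model, "dynamic_shapes")
new_model.dynamic_shapes = dynamic_shapes   # re-attach, as the entries now do
```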