From 42ee7f5f5cd8cd4493f3e158d3cd5453fe4e35b3 Mon Sep 17 00:00:00 2001
From: Hanxian97
Date: Wed, 28 Aug 2024 12:01:33 -0700
Subject: [PATCH 1/3] add readme and code refactor

---
 .../prototype/mixed_precision/README.md       | 165 ++++++++++++++++++
 .../scripts/BO_acc_modelsize.py               | 129 +++-----------
 .../scripts/BO_acc_throughput.py              |  29 +--
 .../scripts/Llama3-8B_initial_samples.json    |  14 ++
 .../scripts/Llama3-8B_parameters.json         |  22 +++
 .../scripts/Mistral-7B_initial_samples.json   |  14 ++
 .../scripts/Mistral-7B_parameters.json        |  20 +++
 .../prototype/mixed_precision/scripts/fit.py  |  16 +-
 .../mixed_precision/scripts/hessian_grad.py   |  14 +-
 .../mixed_precision/scripts/hessian_vhp.py    |  11 +-
 .../mixed_precision/scripts/utils.py          |  52 ++++++
 11 files changed, 350 insertions(+), 136 deletions(-)
 create mode 100644 torchao/quantization/prototype/mixed_precision/README.md
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json
 create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json

diff --git a/torchao/quantization/prototype/mixed_precision/README.md b/torchao/quantization/prototype/mixed_precision/README.md
new file mode 100644
index 000000000..bb8880632
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/README.md
@@ -0,0 +1,165 @@
# Bayesian Optimization for Mixed-Precision Quantization
We provide a Bayesian Optimization (BO) tool that decides the post-training mixed-precision weight-only quantization configuration of a given pre-trained transformer model. It assigns each layer its own bitwidth and groupsize to shrink the model size or speed up inference while preserving model accuracy. It also provides a sensitivity analysis tool and the option to seed the search with initial configurations derived from the sensitivity analysis, to further improve BO.

## Usage

### Dependencies:
The tool relies on lm_eval to measure model accuracy and on ax-platform to conduct the BO search. To install:
```
pip install lm_eval
pip install ax-platform
```
### Optional Step: Usage of the sensitivity tool

We provide a sensitivity tool to calculate the [average Hessian matrix trace](https://arxiv.org/pdf/1911.03852) and the [Fisher information matrix trace (FIT)](https://arxiv.org/pdf/2210.08502). With these sensitivity scores, we can identify sensitivity-guided initial configurations to better initialize the BO search. This step is optional for using the BO tool.

#### Average Hessian trace:
The Hessian is the second-order partial derivative of the loss function, and a higher average Hessian trace indicates that a layer is more sensitive to perturbations. The tool currently calculates one layer at a time to avoid out-of-memory issues for large models, e.g., Llama3-8B, and it leverages torch's fast vhp (vector-Hessian product) function for efficiency.
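For intuition, the estimator behind scripts/hessian_vhp.py is a Hutchinson-style trace estimate, trace(H) ≈ E[vᵀ H v] with Rademacher probes v, computed through torch.autograd.functional.vhp. The sketch below only illustrates the idea; `estimate_layer_hessian_trace` and `layer_loss` (a scalar calibration loss as a function of one layer's parameters) are hypothetical names, not the script's actual API.

```
import torch
from torch.autograd.functional import vhp

def estimate_layer_hessian_trace(layer_loss, params, max_iter=100):
    # Hutchinson estimator: trace(H) ~= mean of v^T H v over Rademacher probes v.
    # layer_loss(*params) is assumed to return the scalar calibration loss with the
    # given layer parameters substituted into the model (hypothetical helper).
    params = tuple(params)
    estimates = []
    for _ in range(max_iter):
        # Rademacher (+1/-1) probe vectors with the same shapes as the parameters
        v = tuple(torch.randint_like(p, 2) * 2.0 - 1.0 for p in params)
        _, vh = vhp(layer_loss, params, v)  # v^T H, one tensor per parameter
        estimates.append(sum((vi * vhi).sum() for vi, vhi in zip(v, vh)).item())
    return sum(estimates) / len(estimates)
```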
To calculate the average Hessian matrix trace of a layer on a calibration dataset (wikitext):
```
python scripts/hessian_vhp.py --layer_id=LAYER_ID --checkpoint=/tmp/Meta-Llama-3-8B --max_seqlen=256 --max_iter=100 --nsamples=512
```
where --layer_id selects the layer whose average Hessian trace is calculated; LAYER_ID is the integer that identifies the layer in the module name.

The tool will print out the average Hessian trace of that layer based on the calibration dataset. Calculating the Hessian trace is both memory-intensive and computationally expensive; the current tool takes 4 days with 4 GPUs on a calibration dataset of 512 samples for Llama3-8B.

#### FIT:
FIT quantifies the total amount of information that the data carries about the parameters. It has been shown, theoretically and empirically, to be very close to the Hessian while being cheaper to compute ([FIT paper](https://arxiv.org/pdf/2210.08502)). The tool computes the FIT scores for all layers at once. To calculate the FIT of the whole model on a calibration dataset (wikitext):
```
python scripts/fit.py --num_layers=32 --checkpoint=/tmp/Meta-Llama-3-8B --max_seqlen=2048 --max_iter=100 --nsamples=128
```
The tool will print out the average FIT score of every layer based on the calibration dataset. The arguments checkpoint, max_seqlen, nsamples, and max_iter work the same as for the Hessian tool; the only difference is that --layer_id is replaced by --num_layers, which sets the total number of layers to calculate FIT scores for.

Calculating FIT takes 3.3h with 1 GPU on a calibration dataset of 512 samples for Llama3-8B.

### Usage of BO search

#### Step 1: Define parameter space

Given a model, to conduct a BO search we first need to define the parameter space, i.e., for each layer, set the value or the choices of bitwidth and groupsize. An example parameter space configuration is shown below and in Llama3-8B_parameters.json.

```
 {
  "name": "bitwidth",
  "name_format": "bitwidth.{i}.",
  "layers": [
    {"range": [0, 3], "type": "fixed", "value": 5}, # the first 3 layers are assigned the fixed bitwidth 5
    {"range": [3, 32], "type": "choice", "values": [2, 3, 4, 5, 6, 8]}, # the bitwidths of the remaining 29 layers are chosen from this list
  ]
 },
```
A parameter for a layer (selected by its range) is either of "fixed" type, with a single value, or of "choice" type, with a list of possible values. By default, the parameter space searches over bitwidths [2, 3, 4, 5, 6, 8] and groupsizes [32, 64, 128, 256] for each layer.

#### Step 2: Define initial samples (optional)
An optional step is to provide better initial samples based on the sensitivity scores. A layer with a higher sensitivity score (Hessian or FIT) should be assigned a higher bitwidth and a smaller groupsize to preserve model accuracy. E.g., since the FIT scores of the first 3 layers are far higher than those of the other layers, we can set <5-bit, groupsize=32> for them and <4-bit, groupsize=64> for all the other layers. By default, the initial samples are drawn randomly from the valid parameter space; we recommend providing at least 10 samples to better initialize the BO strategy. An example initial-samples file is shown below and provided in Llama3-8B_initial_samples.json; the sketch right after this paragraph illustrates one way to derive such samples from the sensitivity scores.
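As a rough illustration (not part of the tool's API), the hypothetical helper below turns per-layer sensitivity scores into one initial sample by giving the most sensitive layers a <5-bit, groupsize=32> preset and the remaining layers <4-bit, groupsize=64>; the cut-off and the two presets are assumptions chosen only for this example.

```
import json

def sensitivity_to_initial_sample(scores, num_sensitive=3):
    # scores: per-layer sensitivity values, e.g., the FIT scores printed by scripts/fit.py
    most_sensitive = set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:num_sensitive])
    sample = {}
    for i in range(len(scores)):
        bit, group = (5, 32) if i in most_sensitive else (4, 64)
        sample[f"bitwidth.{i}."] = bit
        sample[f"groupsize.{i}."] = group
    return sample

# fit_scores = [...]  # 32 per-layer scores from the FIT tool
# with open("my_initial_samples.json", "w") as f:
#     json.dump({"initial_samples": [sensitivity_to_initial_sample(fit_scores)]}, f, indent=2)
```

A resulting sample, in the JSON format the tool expects, looks like the example below.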
```
{
    "initial_samples": [
        {
            "bitwidth.0.": 8,
            "groupsize.0.": 64,
            "bitwidth.1.": 4,
            "groupsize.1.": 32,
            "bitwidth.2.": 5,
            "groupsize.2.": 128
        }
    ]
}
```

#### Step 3: Run BO experiment
To conduct a BO search that optimizes model accuracy under a given model size constraint:

```
python BO_acc_modelsize.py --checkpoint=/tmp/Meta-Llama-3-8B --num_trials=200 --model_size_constraint=6.0 --history_output=BO_acc_modelsize_output.csv --parameters_list=Llama3-8B_parameters.json --initial_samples=Llama3-8B_initial_samples.json --gpu_list=0,1,2,3
```

where
--num_trials sets the number of BO search trials
--model_size_constraint sets the maximum model size (GB) for a valid search result
--parameters_list is the path to the parameter-space JSON file
--initial_samples is the path to the initial-samples JSON file for BO
--gpu_list enables evaluating different BO trials on different GPUs in parallel; if omitted, a single GPU is used

For Llama3-8B, a search takes 1.5h on wikitext-document from lm_eval on 8 A100 GPUs with 80GB GPU memory.

Example outputs:
The tool prints the best configuration and its results (accuracy, and model size or throughput) found during the search.

```
------Best config------
{'cal_PPL': 7.4736, 'model_size': 5.9766} {'bitwidth.0.': 5, 'groupsize.0.': 32, 'bitwidth.1.': 5, 'groupsize.1.': 32,...,'bitwidth.31.': 5, 'groupsize.31.': 32}
```

The tool also writes the BO search trial history to the --history_output csv file with three columns:

| cal_PPL | model size (GB) | quant_config |
| ---------------- | ------ | ------ |
| 7.5286 | 5.8418 | {'bitwidth.0.': 4, 'groupsize.0.': 64, 'bitwidth.1.': 6, 'groupsize.1.': 32,...,'bitwidth.31.': 5, 'groupsize.31.': 32} |
| 7.4736 | 5.9766 | {'bitwidth.0.': 5, 'groupsize.0.': 32, 'bitwidth.1.': 5, 'groupsize.1.': 32,...,'bitwidth.31.': 5, 'groupsize.31.': 32} |
...

#### Run BO to optimize inference speed
We also provide a second version of the BO search that optimizes inference throughput (with torch.compile()) under a model accuracy constraint:
```
python BO_acc_throughput.py --checkpoint_path=/tmp/Meta-Llama-3-8B/model.pth --repo_id=/tmp/Meta-Llama-3-8B --num_trials=200 --ppl_constraint=7.5 --history_output=BO_acc_speed_output.csv --parameters_list=Llama3-8B_parameters.json --initial_samples=Llama3-8B_initial_samples.json
```
The arguments are similar to those for optimizing accuracy under a model size constraint, except that the model is loaded from a model.pth checkpoint via --checkpoint_path and --repo_id, and --model_size_constraint is replaced by --ppl_constraint, which sets the perplexity limit for valid search results.

Similarly, the tool outputs the best configuration for both inference throughput and model accuracy:
```
------Best config------
{'cal_throughput': 147.72, 'cal_PPL': 7.3134} {'bitwidth.0.': 5, 'groupsize.0.': 32, 'bitwidth.1.': 5, 'groupsize.1.': 32,...,'bitwidth.31.': 5, 'groupsize.31.': 32}
```
and writes out the BO search history file:
| cal_throughput | cal_PPL | quant_config |
| ---------------- | ------ | ------ |
| 135.64 | 7.5322 | {'bitwidth.0.': 6, 'groupsize.0.': 64, 'bitwidth.1.': 4, 'groupsize.1.': 128,...,'bitwidth.31.': 5, 'groupsize.31.': 64} |
| 147.72 | 7.3134 | {'bitwidth.0.': 5, 'groupsize.0.': 32, 'bitwidth.1.': 5, 'groupsize.1.': 32,...,'bitwidth.31.': 5, 'groupsize.31.': 32} |
...
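Under the hood, both BO_acc_modelsize.py and BO_acc_throughput.py drive the search with the Ax Service API (AxClient). The sketch below shows the shape of that loop for the throughput-under-ppl-constraint case; the two-layer search space and the fake evaluate() function are toy stand-ins for the real parameter space and the GPU-based ppl/throughput measurement.

```
from ax.service.ax_client import AxClient, ObjectiveProperties

# Toy search space: 2 layers, same choice sets as the real parameter-space JSON
parameters_list = []
for i in range(2):
    parameters_list.append({"name": f"bitwidth.{i}.", "type": "choice", "value_type": "int",
                            "values": [2, 3, 4, 5, 6, 8], "is_ordered": True, "sort_values": True})
    parameters_list.append({"name": f"groupsize.{i}.", "type": "choice", "value_type": "int",
                            "values": [32, 64, 128, 256], "is_ordered": True, "sort_values": True})

def evaluate(config):
    # Placeholder metrics as (mean, sem); the scripts measure real ppl and throughput here.
    avg_bits = sum(v for k, v in config.items() if k.startswith("bitwidth")) / 2
    return {"cal_throughput": (200.0 - 10.0 * avg_bits, 0.0), "cal_PPL": (6.5 + 0.3 * (8 - avg_bits), 0.0)}

ax_client = AxClient()
ax_client.create_experiment(
    name="mixed_precision_BO",
    parameters=parameters_list,
    objectives={"cal_throughput": ObjectiveProperties(minimize=False)},
    outcome_constraints=["cal_PPL <= 7.5"],  # the accuracy constraint set by --ppl_constraint
)
for _ in range(20):
    config, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(config))

print(ax_client.get_best_parameters())
```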
#### Run BO for other models
We are adding support for more models, including more transformer models and ViT models. To run all the above experiments for a new model, e.g., Mistral-7B-v0.1, specify the path to load the model with --checkpoint, the desired parameter space with --parameters_list, and, optionally, your pre-defined initial samples with --initial_samples, as in the following command (the same applies to optimizing inference speed):

```
python BO_acc_modelsize.py --checkpoint=/tmp/Mistral-7B-v0.1/ --num_trials=200 --model_size_constraint=6.0 --history_output=BO_acc_modelsize_output.csv --parameters_list=Mistral-7B_parameters.json --initial_samples=Mistral-7B_initial_samples.json --gpu_list=0,1,2,3
```

Support for ViT models is coming soon.


## Results
We evaluate the BO search for Llama3-8B and Mistral-7B-v0.1 under two settings: (1) optimizing model accuracy under a model size constraint; (2) optimizing model inference throughput under a model accuracy constraint.

### Results of BO for optimizing model accuracy under model size constraint

For Llama3-8B, the BO-searched quantization saves 60.2% of the model size with 2.89% ppl degradation compared to the bfloat16 baseline.
| Llama3-8B | ppl | model size (GB) |
| ---------------- | ------ | ------ |
| bf16 baseline | 7.260 | 15.01 |
| int8wo uniform quantization | 7.263 | 7.480 |
| int4wo uniform quantization | 7.900 | 5.411 |
| manual baseline | 7.679 | 5.545 |
| BO mixed-precision quantization | 7.470 | 5.976 |


For Mistral-7B-v0.1, the BO-searched quantization saves 59.4% of the model size with only 1.8% ppl degradation compared to the bfloat16 baseline.
| Mistral-7B-v0.1 | ppl | model size (GB) |
| ---------------- | ------ | ------ |
| bf16 baseline | 8.021 | 13.49 |
| int8wo uniform quantization | 8.028 | 7.90 |
| int4wo uniform quantization | 8.387 | 4.65 |
| BO mixed-precision quantization | 8.168 | 5.48 |


### Results of BO for optimizing model inference throughput under model accuracy constraint
For Llama3-8B, the BO-searched quantization improves throughput by 69.5% with only 3.25% ppl degradation compared to the bfloat16 baseline.
+ +| Llama3-8B |ppl | throughput| +| ---------------- | ------ | ------ | +| bf16 baseline | 7.260 | 94.97 | +| int8wo uniform quantization | 7.263 | 139.76 | +| int4wo uniform quantization | 7.900 | 179.44 | +| BO mixed-precision quantization | 7.470 | 160.96 | diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py index f58fab5cf..89aff675b 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py @@ -16,7 +16,7 @@ from ax.service.ax_client import AxClient, ObjectiveProperties import torch.multiprocessing as mp from ax.modelbridge.cross_validation import cross_validate -from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config +from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config, load_parameters_from_json, load_initial_samples # return evaluation results to complete BO trials def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config): @@ -25,85 +25,9 @@ def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config): "model_size": (cal_model_size(model, fqn_to_config), 0.0), } -# TODO: make it into a yaml or json file to enable users specify their custom model formats -def define_parameter_list(): - - # define the search space for all layers - parameters_list = [] - - for i in range(0, 3): - parameters_list.append( - { - "name": f"bitwidth.{i}.", - "type": "fixed", - "value_type": "int", - "value": 5, - "is_ordered": True, - "sort_values": True, - } - ) - - parameters_list.append( - { - "name": f"groupsize.{i}.", - "type": "fixed", - "value_type": "int", - "value": 32, - "is_ordered": True, - "sort_values": True, - } - ) - - for i in range(3, 30): - parameters_list.append( - { - "name": f"bitwidth.{i}.", - "type": "choice", - "value_type": "int", - "values": [2,3,4,5,6,8], - "is_ordered": True, - "sort_values": True, - } - ) - - parameters_list.append( - { - "name": f"groupsize.{i}.", - "type": "choice", - "value_type": "int", - "values": [32, 64, 128, 256], - "is_ordered": True, - "sort_values": True, - } - ) - - for i in range(30, 32): - parameters_list.append( - { - "name": f"bitwidth.{i}.", - "type": "fixed", - "value_type": "int", - "value": 5, - "is_ordered": True, - "sort_values": True, - } - ) - parameters_list.append( - { - "name": f"groupsize.{i}.", - "type": "fixed", - "value_type": "int", - "value": 32, - "is_ordered": True, - "sort_values": True, - } - ) - - return parameters_list - # add initial search points based on the sensitivity score -# TODO: automate the initial samples by better leverage the sensitivity scores -def get_initial_samples(num_BO_initial_samples=50): +# TODO: add random initial samples if no sensitivity prior +def get_initial_samples(num_BO_initial_samples=10): initial_points_set = [] # auto sample the bit choices with random choice probability positive correlated to FIT score @@ -142,10 +66,12 @@ def get_initial_samples(num_BO_initial_samples=50): One trial, one BO update. 
TODO: refactor the sequential BO and parallel BO into a single function ''' -def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, output_file): +def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_trials, model_size_constraint, history_output, parameters_list, initial_samples): - parameters_list = define_parameter_list() - initial_points_set = get_initial_samples(num_BO_initial_samples) + # TODO: add default parameter list if not specified + parameters_list = load_parameters_from_json(parameters_list) + initial_points_set = load_initial_samples(initial_samples) + num_BO_initial_samples =len(initial_points_set) #initialize ax_client constraint="model_size <= "+str(model_size_constraint) @@ -208,15 +134,12 @@ def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_s del m torch.cuda.empty_cache() - - print("------Finish BO------") - for h in history: - print(h) - write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) + #write BO search trial history to csv file + write_history_to_csv(history, history_output, ["cal_PPL", "model_size", "quant_config"]) print("------Best config------") best_parameters, values = ax_client.get_best_parameters() - print(best_parameters, values) + print(values, best_parameters) # Worker function to perform BO trials on a specific GPU def eval_in_parallel(gpu_id, checkpoint, num_PPL_eval_samples, config, return_dict, proc_id, trial_id): @@ -240,11 +163,13 @@ def eval_in_parallel(gpu_id, checkpoint, num_PPL_eval_samples, config, return_di Each time the BO gets multiple new trials, evaluates the trials on the GPUs and return the evaluation results to update the BO. Multiple trials, one BO update. 
''' -def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, gpu_list): +def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_trials, model_size_constraint, gpu_list, history_output, parameters_list, initial_samples): + # TODO: add default parameter list if not specified parameters_list = define_parameter_list() - initial_points_set = get_initial_samples(num_BO_initial_samples) - + initial_points_set = load_initial_samples(initial_samples) + num_BO_initial_samples =len(initial_points_set) + #initialize ax_client constraint="model_size <= "+str(model_size_constraint) ax_client = AxClient() @@ -318,14 +243,12 @@ def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_sam history.append((eval_results, config)) ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,) - print("------Finish BO------") - for h in history: - print(h) - write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) + #write BO search trial history to csv file + write_history_to_csv(history, history_output, ["cal_PPL", "model_size", "quant_config"]) print("------Best config------") best_parameters, values = ax_client.get_best_parameters() - print(best_parameters, values) + print(values, best_parameters) if __name__ == '__main__': @@ -335,14 +258,16 @@ def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_sam parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl') - parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores') - parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO') + parser.add_argument('--num_trials', type=int, default=200, help='Number of trials to run BO') parser.add_argument('--model_size_constraint', type=float, default=6.0, help='The model size (GB) constraint for BO') parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_lists=0,1,2,3") - parser.add_argument('--output_path', type=str, default="BO_acc_modelsize_output.csv", help="The file path to save the BO search trials") + parser.add_argument('--history_output', type=str, default="BO_acc_modelsize_output.csv", help="The csv file path to save the BO search trials") + parser.add_argument('--parameters_list', type=str, default="Llama3-8B_parameters.json", help="The json file path to save the parameters list for BO") + parser.add_argument('--initial_samples', type=str, default="Llama3-8B_initial_samples.json", help="The json file path to save the user-defined initial samples for BO") + args = parser.parse_args() - if args.gpu_list != "": - run_sequential_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, output_path=args.output_path) + if args.gpu_list == "": + run_sequential_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, 
history_output=args.history_output, parameters_list=args.parameters_list, initial_samples=args.initial_samples) else: - run_parallel_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, gpu_list=args.gpu_list, output_path=args.output_path) + run_parallel_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, gpu_list=args.gpu_list, history_output=args.history_output, parameters_list=args.parameters_list, initial_samples=args.initial_samples) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py index 85138403b..6e039b23c 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py @@ -45,7 +45,7 @@ _load_model, ) -from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config +from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config, load_parameters_from_json default_device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -321,8 +321,8 @@ def define_parameter_list(): return parameters_list # add initial search points based on the sensitivity score -# TODO: automate the initial samples by better leverage the sensitivity scores -def get_initial_samples(num_BO_initial_samples=50): +# TODO: add default parameter list if not specified +def get_initial_samples(num_BO_initial_samples=10): initial_points_set = [] @@ -362,7 +362,7 @@ def get_initial_samples(num_BO_initial_samples=50): Each time the BO gets one new trial, evaluates the trial on the GPU and return the evaluation results to update the BO. One trial, one BO update. 
''' -def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, num_BO_initial_samples, num_trials, ppl_constraint, args): +def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, num_trials, ppl_constraint, args): ''' currently use the loader and benchmark code from torchao/_models/llama/generate, and use lm_eval for ppl evaluation @@ -376,10 +376,12 @@ def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, nu tokenizer4ppl = AutoTokenizer.from_pretrained(repo_id) # initialize parameters - parameters_list = define_parameter_list() + # TODO: add default parameter list if not specified + parameters_list = load_parameters_from_json(args.parameters_list) # sample initial points - initial_points_set = get_initial_samples(num_BO_initial_samples) + initial_points_set = load_initial_samples(initial_samples) + num_BO_initial_samples = len(initial_points_set) # initialize BO experiment constraint="cal_PPL <= "+str(ppl_constraint) @@ -458,14 +460,12 @@ def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, nu raw_data=eval_results, ) - print("------Finish BO------") - for h in history: - print(h) - write_history_to_csv(history, args.output_file, ["cal_PPL", "cal_throughput", "quant_config"]) + #write BO search trial history to csv file + write_history_to_csv(history, args.history_output, ["cal_PPL", "cal_throughput", "quant_config"]) print("------Best config------") best_parameters, values = ax_client.get_best_parameters() - print(best_parameters, values) + print(values, best_parameters) if __name__ == '__main__': @@ -476,12 +476,13 @@ def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, nu parser.add_argument('--checkpoint_path', type=Path, default=Path("/tmp/Meta-Llama-3-8B/model.pth"), help='Model checkpoint path for model.pth.') parser.add_argument('--repo_id', type=str, default=Path("/tmp/Meta-Llama-3-8B"), help='Model repo id.') parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl') - parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores') parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO') parser.add_argument('--ppl_constraint', type=float, default=7.5, help='The ppl constraint for BO') parser.add_argument('--multi_gpus', action='store_true', help="Use multi-processing to run evaluation on multi-gpus") parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_lists=0,1,2,3") - parser.add_argument('--output_path', type=str, default="BO_acc_speed_output.csv", help="The csv file path to save the BO search trials") + parser.add_argument('--history_output', type=str, default="BO_acc_speed_output.csv", help="The csv file path to save the BO search trials") + parser.add_argument('--parameters_list', type=str, default="Llama3-8B_parameters.json", help="The json file path to save the parameters list for BO") + parser.add_argument('--initial_samples', type=str, default="Llama3-8B_initial_samples.json", help="The json file path to save the user-defined initial samples for BO") args = parser.parse_args() - run_sequential_BO(device=args.device, checkpoint_path=args.checkpoint_path, repo_id=args.repo_id, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, 
ppl_constraint=args.ppl_constraint, args=args) + run_sequential_BO(device=args.device, checkpoint_path=args.checkpoint_path, repo_id=args.repo_id, num_PPL_eval_samples=args.num_PPL_eval_samples, num_trials=args.num_trials, ppl_constraint=args.ppl_constraint, args=args) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json b/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json new file mode 100644 index 000000000..698ee2cee --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json @@ -0,0 +1,14 @@ +{ + "initial_samples": [ + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 3, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 5, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 5, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 32, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 3, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 5, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 5, "groupsize.24.": 32, "bitwidth.25.": 5, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 3, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 5, "groupsize.8.": 32, "bitwidth.9.": 4, 
"groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 32, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, 
"bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 5, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 5, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 3, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 5, "groupsize.19.": 32, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 6, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 32, 
"bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 4, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 4, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 32, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 4, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 4, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32} + ] +} diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json b/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json new file mode 100644 index 000000000..770eedf07 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json @@ -0,0 +1,22 @@ +{ + "parameters": [ + { + "name": "bitwidth", + "name_format": "bitwidth.{i}.", + "layers": [ + {"range": [0, 3], "type": "fixed", "value": 5}, + {"range": [3, 30], "type": "choice", "values": [2, 3, 4, 5, 6, 8]}, + {"range": [30, 32], "type": "fixed", "value": 5} + ] + }, + { + "name": "groupsize", + "name_format": "groupsize.{i}.", + "layers": [ + {"range": [0, 3], "type": "fixed", "value": 32}, + {"range": [3, 30], "type": "choice", "values": [32, 64, 128, 256]}, + {"range": [30, 32], "type": "fixed", "value": 32} + ] + } + ] +} diff --git 
a/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json b/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json new file mode 100644 index 000000000..a07ee2bca --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json @@ -0,0 +1,14 @@ +{ + "initial_samples": [ + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 32, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 64, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 64, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 5, "groupsize.11.": 64, "bitwidth.12.": 4, "groupsize.12.": 64, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 64, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 4, "groupsize.18.": 64, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 128, "bitwidth.21.": 5, "groupsize.21.": 64, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 5, "groupsize.23.": 128, "bitwidth.24.": 5, "groupsize.24.": 64, "bitwidth.25.": 5, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 128, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 64, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 64, "bitwidth.7.": 5, "groupsize.7.": 64, "bitwidth.8.": 4, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 64, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 64, "bitwidth.15.": 5, "groupsize.15.": 64, "bitwidth.16.": 5, "groupsize.16.": 64, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 128, "bitwidth.19.": 5, "groupsize.19.": 128, "bitwidth.20.": 4, "groupsize.20.": 64, "bitwidth.21.": 4, "groupsize.21.": 64, "bitwidth.22.": 5, "groupsize.22.": 128, "bitwidth.23.": 3, "groupsize.23.": 64, "bitwidth.24.": 3, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 128, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 3, "groupsize.28.": 64, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 64, "bitwidth.5.": 5, "groupsize.5.": 64, "bitwidth.6.": 5, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 4, "groupsize.10.": 32, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 4, "groupsize.12.": 64, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 64, "bitwidth.15.": 5, "groupsize.15.": 64, "bitwidth.16.": 5, 
"groupsize.16.": 64, "bitwidth.17.": 4, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 32, "bitwidth.19.": 5, "groupsize.19.": 64, "bitwidth.20.": 5, "groupsize.20.": 128, "bitwidth.21.": 5, "groupsize.21.": 64, "bitwidth.22.": 5, "groupsize.22.": 128, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 3, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 128, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 32, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 64, "bitwidth.11.": 5, "groupsize.11.": 64, "bitwidth.12.": 4, "groupsize.12.": 64, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 4, "groupsize.15.": 32, "bitwidth.16.": 5, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 64, "bitwidth.18.": 5, "groupsize.18.": 128, "bitwidth.19.": 4, "groupsize.19.": 32, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 64, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 5, "groupsize.23.": 64, "bitwidth.24.": 2, "groupsize.24.": 32, "bitwidth.25.": 5, "groupsize.25.": 128, "bitwidth.26.": 4, "groupsize.26.": 32, "bitwidth.27.": 4, "groupsize.27.": 64, "bitwidth.28.": 4, "groupsize.28.": 64, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 64, "bitwidth.6.": 4, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 64, "bitwidth.8.": 5, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 64, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 64, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 64, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 128, "bitwidth.19.": 4, "groupsize.19.": 64, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 64, "bitwidth.22.": 5, "groupsize.22.": 64, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 3, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 64, "bitwidth.26.": 3, "groupsize.26.": 64, "bitwidth.27.": 4, "groupsize.27.": 64, "bitwidth.28.": 5, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 4, "groupsize.4.": 32, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 64, "bitwidth.7.": 4, "groupsize.7.": 64, "bitwidth.8.": 5, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 
5, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 64, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 64, "bitwidth.16.": 4, "groupsize.16.": 32, "bitwidth.17.": 5, "groupsize.17.": 64, "bitwidth.18.": 4, "groupsize.18.": 64, "bitwidth.19.": 5, "groupsize.19.": 128, "bitwidth.20.": 4, "groupsize.20.": 32, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 5, "groupsize.22.": 32, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 3, "groupsize.25.": 64, "bitwidth.26.": 5, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 3, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 64, "bitwidth.5.": 5, "groupsize.5.": 64, "bitwidth.6.": 4, "groupsize.6.": 64, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 5, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 64, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 4, "groupsize.11.": 64, "bitwidth.12.": 5, "groupsize.12.": 64, "bitwidth.13.": 4, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 64, "bitwidth.16.": 5, "groupsize.16.": 64, "bitwidth.17.": 5, "groupsize.17.": 32, "bitwidth.18.": 5, "groupsize.18.": 128, "bitwidth.19.": 5, "groupsize.19.": 64, "bitwidth.20.": 5, "groupsize.20.": 128, "bitwidth.21.": 5, "groupsize.21.": 32, "bitwidth.22.": 3, "groupsize.22.": 64, "bitwidth.23.": 5, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 5, "groupsize.25.": 32, "bitwidth.26.": 4, "groupsize.26.": 64, "bitwidth.27.": 5, "groupsize.27.": 64, "bitwidth.28.": 5, "groupsize.28.": 128, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 64, "bitwidth.5.": 5, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 32, "bitwidth.7.": 5, "groupsize.7.": 64, "bitwidth.8.": 5, "groupsize.8.": 32, "bitwidth.9.": 5, "groupsize.9.": 32, "bitwidth.10.": 5, "groupsize.10.": 64, "bitwidth.11.": 5, "groupsize.11.": 64, "bitwidth.12.": 5, "groupsize.12.": 64, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 64, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 5, "groupsize.16.": 64, "bitwidth.17.": 5, "groupsize.17.": 64, "bitwidth.18.": 5, "groupsize.18.": 128, "bitwidth.19.": 5, "groupsize.19.": 128, "bitwidth.20.": 5, "groupsize.20.": 128, "bitwidth.21.": 4, "groupsize.21.": 64, "bitwidth.22.": 5, "groupsize.22.": 128, "bitwidth.23.": 3, "groupsize.23.": 32, "bitwidth.24.": 5, "groupsize.24.": 64, "bitwidth.25.": 4, "groupsize.25.": 32, "bitwidth.26.": 5, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 4, "groupsize.28.": 32, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 
32, "bitwidth.4.": 4, "groupsize.4.": 64, "bitwidth.5.": 5, "groupsize.5.": 64, "bitwidth.6.": 5, "groupsize.6.": 64, "bitwidth.7.": 5, "groupsize.7.": 32, "bitwidth.8.": 4, "groupsize.8.": 32, "bitwidth.9.": 5, "groupsize.9.": 64, "bitwidth.10.": 5, "groupsize.10.": 64, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 32, "bitwidth.14.": 5, "groupsize.14.": 64, "bitwidth.15.": 5, "groupsize.15.": 64, "bitwidth.16.": 5, "groupsize.16.": 64, "bitwidth.17.": 5, "groupsize.17.": 64, "bitwidth.18.": 3, "groupsize.18.": 64, "bitwidth.19.": 4, "groupsize.19.": 64, "bitwidth.20.": 5, "groupsize.20.": 32, "bitwidth.21.": 5, "groupsize.21.": 128, "bitwidth.22.": 5, "groupsize.22.": 64, "bitwidth.23.": 3, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 32, "bitwidth.25.": 4, "groupsize.25.": 64, "bitwidth.26.": 3, "groupsize.26.": 32, "bitwidth.27.": 5, "groupsize.27.": 32, "bitwidth.28.": 5, "groupsize.28.": 64, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32}, + {"bitwidth.0.": 5, "groupsize.0.": 32, "bitwidth.1.": 5, "groupsize.1.": 32, "bitwidth.2.": 5, "groupsize.2.": 32, "bitwidth.3.": 5, "groupsize.3.": 32, "bitwidth.4.": 5, "groupsize.4.": 64, "bitwidth.5.": 4, "groupsize.5.": 32, "bitwidth.6.": 5, "groupsize.6.": 32, "bitwidth.7.": 4, "groupsize.7.": 32, "bitwidth.8.": 5, "groupsize.8.": 64, "bitwidth.9.": 5, "groupsize.9.": 64, "bitwidth.10.": 5, "groupsize.10.": 32, "bitwidth.11.": 5, "groupsize.11.": 32, "bitwidth.12.": 5, "groupsize.12.": 32, "bitwidth.13.": 5, "groupsize.13.": 64, "bitwidth.14.": 5, "groupsize.14.": 32, "bitwidth.15.": 5, "groupsize.15.": 32, "bitwidth.16.": 5, "groupsize.16.": 32, "bitwidth.17.": 4, "groupsize.17.": 64, "bitwidth.18.": 4, "groupsize.18.": 32, "bitwidth.19.": 4, "groupsize.19.": 64, "bitwidth.20.": 5, "groupsize.20.": 64, "bitwidth.21.": 4, "groupsize.21.": 32, "bitwidth.22.": 4, "groupsize.22.": 32, "bitwidth.23.": 3, "groupsize.23.": 32, "bitwidth.24.": 4, "groupsize.24.": 64, "bitwidth.25.": 5, "groupsize.25.": 128, "bitwidth.26.": 3, "groupsize.26.": 64, "bitwidth.27.": 5, "groupsize.27.": 64, "bitwidth.28.": 5, "groupsize.28.": 64, "bitwidth.29.": 5, "groupsize.29.": 32, "bitwidth.30.": 5, "groupsize.30.": 32, "bitwidth.31.": 5, "groupsize.31.": 32} + ] +} diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json b/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json new file mode 100644 index 000000000..f3625db53 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json @@ -0,0 +1,20 @@ +{ + "parameters": [ + { + "name": "bitwidth", + "name_format": "bitwidth.{i}.", + "layers": [ + {"range": [0, 4], "type": "fixed", "value": 5}, + {"range": [4, 32], "type": "choice", "values": [2, 3, 4, 5, 6, 8]} + ] + }, + { + "name": "groupsize", + "name_format": "groupsize.{i}.", + "layers": [ + {"range": [0, 4], "type": "fixed", "value": 32}, + {"range": [4, 32], "type": "choice", "values": [32, 64, 128, 256]} + ] + } + ] +} diff --git a/torchao/quantization/prototype/mixed_precision/scripts/fit.py b/torchao/quantization/prototype/mixed_precision/scripts/fit.py index 78ec878d3..c19087e62 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/fit.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/fit.py @@ -25,7 +25,7 @@ def get_wikitext2(nsamples, seed, seqlen, tokenizer): 
trainloader.append((inp, tar)) return trainloader, testenc -def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_layers): +def cal_FIT(device, data, nsamples, model, max_iter, max_seqlen, criterion, num_layers): # store the history of trace for each layer estimated_history=[] @@ -35,7 +35,7 @@ def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_l trace = [0.] * num_layers - for iteration in range(maxIter): + for iteration in range(max_iter): print("iteration: ",iteration) trace_tmp = [0.] * num_layers @@ -72,7 +72,7 @@ def cal_FIT(device, data, nsamples, model, maxIter, max_seqlen, criterion, num_l F_average = np.array([np.mean(i) for i in estimated_mean]) return F_average, estimated_mean, estimated_history -def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): +def main(max_seqlen, checkpoint, nsamples, max_iter, num_layers): device = 'cuda' if torch.cuda.is_available() else 'cpu' # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M @@ -87,7 +87,7 @@ def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): seed = 0 trainloader, testloader = get_wikitext2(nsamples, seed, max_seqlen, tokenizer) - F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) + F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, max_iter=max_iter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) print("Iteration Done") print("avg_trace:", F_average) print("estimated_mean:", estimated_mean) @@ -95,11 +95,11 @@ def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Calculate layer-wised fish information matric trace.') - parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser = argparse.ArgumentParser(description='Calculate layer-wised fish information matrix trace.') + parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') - parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT') + parser.add_argument('--max_iter', type=int, default=100, help='The number of iterations to calculate FIT') parser.add_argument('--num_layers', type=int, default=32, help='The number of layers to calculate FIT.') parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') args = parser.parse_args() - main(args.max_seqlen, args.checkpoint, args.nsamples, args.maxIter, args.num_layers) + main(args.max_seqlen, args.checkpoint, args.nsamples, args.max_iter, args.num_layers) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py index 8aea88892..1a4998d78 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py @@ -74,12 +74,12 @@ def dataloader_hv_product(layerid, params, device, v, data, nsamples, model, max THv = [THv1 / float(nsamples) for 
THv1 in THv] return THv -def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqlen, criterion): +def cal_trace(layerid, params, device, data, nsamples, model, max_iter, max_seqlen, criterion): vhv_c_history = [] trace_history = [] trace = 0. - for i in range(maxIter): + for i in range(max_iter): print("iteration: ",i) # generate Rademacher random variables @@ -110,7 +110,7 @@ def cal_trace(layerid, params, device, data, nsamples, model, maxIter, max_seqle return np.mean(trace_history) -def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): +def main(layer_id, checkpoint, max_seqlen, max_iter, nsamples): device = 'cuda' if torch.cuda.is_available() else 'cpu' # to avoid aten::_scaled_dot_product_flash_attention_backward not implemented error @@ -136,16 +136,16 @@ def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): for param in layer_.mlp.parameters(): params.append(param) - trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, maxIter=maxIter, max_seqlen=max_seqlen, criterion=criterion) + trace = cal_trace(layerid=layer_id, params=params, device=device, data=trainloader, nsamples=nsamples, model=model, max_iter=max_iter, max_seqlen=max_seqlen, criterion=criterion) print("The trace of layer " + str(layer_id) + " is", trace) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Calculate layer-wised Hessian trace leveraging autograd.') parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the trace and hessian') - parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') - parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') + parser.add_argument('--max_iter', type=int, default=100, help='The number of iterations to calculate Hessian trace') parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') args = parser.parse_args() - main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) + main(args.layer_id, args.checkpoint, args.max_seqlen, args.max_iter, args.nsamples) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py index 3470031cb..58b56f3b7 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py @@ -61,7 +61,7 @@ def make_functional(mod, layer_id): -def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): +def main(layer_id, checkpoint, max_seqlen, max_iter, nsamples): # use the functional model to load the weights back def load_weights(mod, names, params, selected_params, selected_params_names): @@ -113,7 +113,7 @@ def f(*new_params): trace_history = [] vhv_c_history=[] - for iteration in range(maxIter): + for iteration in range(max_iter): print("iteration: ",iteration) @@ -155,10 +155,11 @@ def f(*new_params): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description="Calculate layer-wised Hessian trace leveraging torch's 
vhp function.") + # TODO: make it a for loop for all the layer_ids to automatically calculate the Hessian trace for all the layers of a model parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the Hessian trace') - parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') - parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate Hessian trace') + parser.add_argument('--max_iter', type=int, default=100, help='The number of iterations to calculate Hessian trace') parser.add_argument('--nsamples', type=int, default=128, help='The number of samples in calibration dataset') args = parser.parse_args() - main(args.layer_id, args.checkpoint, args.max_seqlen, args.maxIter, args.nsamples) + main(args.layer_id, args.checkpoint, args.max_seqlen, args.max_iter, args.nsamples) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/utils.py b/torchao/quantization/prototype/mixed_precision/scripts/utils.py index 4d075b469..108c2eaff 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/utils.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/utils.py @@ -14,6 +14,7 @@ from lm_eval.tasks import get_task_dict from transformers import AutoModelForCausalLM, AutoTokenizer +import json def write_history_to_csv(history, output_file, keyword): #keyword example: ['cal_PPL', 'cal_throughput', 'config'] @@ -106,3 +107,54 @@ def load_model(repo_id, device): device=device ) return model, tokenizer + + + +def load_parameters_from_json(json_path): + with open(json_path, "r") as f: + config = json.load(f) + + bitwidth_config = next(param for param in config["parameters"] if param["name"] == "bitwidth") + groupsize_config = next(param for param in config["parameters"] if param["name"] == "groupsize") + + parameters_list = [] + + # Ensure that we are interleaving bitwidth and groupsize for each layer + for bw_layer, gs_layer in zip(bitwidth_config["layers"], groupsize_config["layers"]): + start, end = bw_layer["range"] + for i in range(start, end): + # Add bitwidth parameter + bitwidth_param = { + "name": bitwidth_config["name_format"].format(i=i), + "type": bw_layer["type"], + "value_type": "int", + "is_ordered": True, + "sort_values": True, + } + if bw_layer["type"] == "fixed": + bitwidth_param["value"] = bw_layer["value"] + elif bw_layer["type"] == "choice": + bitwidth_param["values"] = bw_layer["values"] + parameters_list.append(bitwidth_param) + + # Add groupsize parameter + groupsize_param = { + "name": groupsize_config["name_format"].format(i=i), + "type": gs_layer["type"], + "value_type": "int", + "is_ordered": True, + "sort_values": True, + } + if gs_layer["type"] == "fixed": + groupsize_param["value"] = gs_layer["value"] + elif gs_layer["type"] == "choice": + groupsize_param["values"] = gs_layer["values"] + parameters_list.append(groupsize_param) + + return parameters_list + + +def load_initial_samples(json_path): + with open(json_path, "r") as f: + config = json.load(f) + return config["initial_samples"] From 5d1d4a226d5a0fa8e2ea4a8e0e31fc7a839f61be Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 28 Aug 2024 12:46:54 -0700 Subject: [PATCH 2/3] edit readme --- 
.../prototype/mixed_precision/README.md | 54 ++++++++++++------- .../prototype/mixed_precision/scripts/fit.py | 2 +- .../mixed_precision/scripts/hessian_vhp.py | 2 +- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/torchao/quantization/prototype/mixed_precision/README.md b/torchao/quantization/prototype/mixed_precision/README.md index bb8880632..114249301 100644 --- a/torchao/quantization/prototype/mixed_precision/README.md +++ b/torchao/quantization/prototype/mixed_precision/README.md @@ -14,29 +14,44 @@ pip install ax-platform We provide a sensitivity tool to calculate the [average Hessian matrix trace](https://arxiv.org/pdf/1911.03852) and the [Fisher information matrix trace (FIT)](https://arxiv.org/pdf/2210.08502). With the sensitivity scores, we are able to identify sensitivity-guided initial configurations to better initialize the BO search. This step is optional for using the BO tool. #### Average Hessian trace: -Hessian is the second order partial derivation of the loss function and a higher average Hessian trace demonstrates a higher sensitivity of a layer to perturbations. Now the tool supports calculating one layer at a time to avoid out of memory issue for large models, e.g., Llama3-8B. It leverages the fast vhp (vector-hessian product) function from torch to achieve more efficient. To calculate average Hessian matrix trace of a layer on a calibration dataset (wikitext): +The Hessian is the second-order partial derivative of the loss function, and a higher average Hessian trace indicates a higher sensitivity of a layer to perturbations. The tool currently supports calculating one layer at a time to avoid out-of-memory issues for large models, e.g., Llama3-8B. It leverages the fast vhp (vector-Hessian product) function from torch to achieve higher efficiency. To calculate the average Hessian matrix trace of a layer on a calibration dataset (wikitext-v2-document): ``` python scripts/hessian_vhp.py --layer_id=LAYER_ID --checkpoint=/tmp/Meta-Llama-3-8B --max_seqlen=256 --max_iter=100 --nsamples=512 ``` -where, ---layer_id identifies which layer to calculate the average Hessian trace, LAYER_ID is an integer number used to identify the layer in the module name - -The tool will print out the average Hessian trace based on the calibration dataset for the certain layer. Calculating Hessian trace is both memory-intensive and computationally expensive, the current tool takes 4 days with 4 GPUs on a calibration dataset of 512 samples for Llama3-8B. +where --layer_id specifies which layer to calculate the average Hessian trace for; LAYER_ID is an integer that identifies the layer in the module name. The tool will print out the average Hessian trace for that layer based on the calibration dataset. An output example: +``` +Iterations Done +Avg Hessian trace for layer 0 is: 20135.83 +``` +Calculating the Hessian trace is both memory-intensive and computationally expensive; the current tool takes 4 days with 4 A100 GPUs with 80GB GPU memory on a calibration dataset of 512 samples for Llama3-8B.
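To make the approach concrete, here is a minimal, self-contained sketch of the Hutchinson-style estimator that the script above builds on: the trace of the Hessian is estimated as the average of v^T H v over random Rademacher vectors v, with the Hessian-vector product obtained from torch.autograd.functional.vhp. The toy loss and tensor below are illustrative stand-ins, not the repository code.
```
import torch
from torch.autograd.functional import vhp

def hutchinson_trace(loss_fn, w, max_iter=100):
    # Estimate trace(H) of the Hessian of loss_fn at w via Hutchinson's method.
    estimates = []
    for _ in range(max_iter):
        # Rademacher probe vector: entries are +1 or -1 with equal probability
        v = (torch.rand_like(w) < 0.5).to(w.dtype) * 2 - 1
        # vhp returns (loss value, v^T H); dotting with v gives v^T H v
        _, hv = vhp(loss_fn, w, v)
        estimates.append(torch.dot(v.flatten(), hv.flatten()).item())
    return sum(estimates) / len(estimates)

# Toy check: for loss = sum(x^2) the Hessian is 2*I, so for a
# 16-dimensional input the estimate should come out close to 32.
w = torch.randn(16)
print(hutchinson_trace(lambda x: (x ** 2).sum(), w))
```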
#### FIT: FIT quantifies the total amount of information in the data about the parameter. It has been theoretically and empirically proven to be very close to the Hessian but with higher efficiency ([FIT paper](https://arxiv.org/pdf/2210.08502)). The tool supports calculating the FIT scores for all the layers at once. To calculate the FIT of the whole model on a calibration dataset (wikitext): ``` python scripts/fit.py --num_layers=32 --checkpoint=/tmp/Meta-Llama-3-8B --max_seqlen=2048 --max_iter=100 --nsamples=128 ``` -The tool will print out the average FIT scores based on the calibration dataset for all the layers. where the arguments checkpoint, max_seqlen, nsamples, max_iter are similar to the usage of running Hession. The only difference is that we replacing --layer_id with --num_layers to identify the total numbers of layers to calculate FIT scores for. +The tool will print out the average FIT scores based on the calibration dataset for all the layers. An output example: +``` +Iterations Done +FIT scores for 32 layers: +[237201.35, 547750.91, 87226.19, 50000.96, + 52017.47, 28319.72, 21997.11, 20681.59, + 21076.09, 21016.67, 18572.73, 19594.67, + 17585.58, 20135.83, 22986.77, 21849.15, + 21690.99, 21204.48, 19281.44, 17967.87, + 16843.32, 19385.39, 18394.11, 15991.45, + 15684.25, 15192.07, 15993.08, 16999.28, + 17418.69, 21241.36, 23579.92, 52762.86] +``` +where the arguments checkpoint, max_seqlen, nsamples, and max_iter are used in the same way as for the Hessian script. The only difference is that --layer_id is replaced with --num_layers to specify the total number of layers to calculate FIT scores for. -Calculating FIT takes 3.3h with 1 GPU on a calibration dataset of 512 samples for Llama3-8B. +Calculating FIT takes 3.3h with 1 A100 GPU with 80GB GPU memory on a calibration dataset of 512 samples for Llama3-8B. ### Usage of BO search #### Step 1: Define parameter space -Given a model, to conduct a BO search, we first need to identify the parameter space for the model, ie., for each layer, set up the value or choices of bitwidth and groupsize. An example of parameter space configuration is shown below and in Llama3-8B_parameters.json. +Given a model, to conduct a BO search, we first need to identify the parameter space for the model, i.e., for each layer, set up the value or choices of bitwidth and groupsize. A simple example of parameter space configuration is shown below and an example for Llama3-8B is in Llama3-8B_parameters.json. ``` { @@ -51,7 +66,7 @@ Given a model, to conduct a BO search, we first need to identify the parameter s A parameter for a layer (specified in the range) can be either "fixed" or "choice" type for a fixed value or a list of possible choices. A default parameter space setting will search from [2, 3, 4, 5, 6, 8] bit and [32, 64, 128, 512] groupsize for each layer.
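As a reference point, the `load_parameters_from_json` helper added in scripts/utils.py expands a configuration like the one above into one parameter dictionary per layer and per knob (presumably consumed by the ax-platform search). A short illustrative excerpt of that expansion is sketched below; the layer indices and the `bitwidth.{i}.` name format are taken from the example configuration and only two bitwidth entries are shown.
```
# Illustrative excerpt of the list produced by load_parameters_from_json
# for the example configuration above (layer 2 falls in the "fixed" range,
# layer 3 in the "choice" range); the real list also contains the
# corresponding groupsize entries for every layer.
[
    {"name": "bitwidth.2.", "type": "fixed", "value_type": "int",
     "is_ordered": True, "sort_values": True, "value": 5},
    {"name": "bitwidth.3.", "type": "choice", "value_type": "int",
     "is_ordered": True, "sort_values": True, "values": [2, 3, 4, 5, 6, 8]},
]
```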
#### Step 2: Define initial samples (optional) -Then an optional step is to obtain some better initial samples based on the sensitivity scores. A layer with a higher sensitivity score (Hessian or FIT) should be assigned with a higher bitwidth and a smaller groupsize, to preserve the model accuracy. E.g., the FIT scores for the first 3 layers are far higher than those of the other layers, thus we can set <5-bit, groupsize=32> for them and <4-bit, groupsize=64> for all the other layers. A simple example of initial samples of BO search is shown below and an example for Llama3-8B is shown in Llama3-8B_initial_samples.json. By default, the initial samples will be randomly sampled from the valid parameter space. We recommend that users add at least 10 examples to better initialize the BO strategy. ``` { @@ -78,15 +93,14 @@ python --BO_acc_modelsize.py --checkpoint=/tmp/Meta-Llama-3-8B --num_trials=200 where --num_trials identifies the number of BO search trials ---model_size_constraint identifies the max model size for valid search results +--model_size_constraint identifies the max model size for valid search results (unit: GB) --parameters_list identifies the path to load parameter space. --initial_samples identifies the path to get initial samples of BO search --gpu_lists enables evaluating different BO trials on different GPUs; otherwise only one GPU will be used For Llama3-8B, a search takes 1.5h on wikitext-document from lm_eval on 8 A100 GPUs with 80GB GPU memory. -Example outputs: -The tool will print out the best configuration and results (accuracy, model size or throughput) among the search. +The tool will print out the best configuration and results (accuracy ("cal_PPL"), model size ("model_size") or throughput ("cal_throughput")) found in the search. Example output: ``` ------Best config------ @@ -104,7 +118,7 @@ The tool will also write the BO search trial history to history_output csv file #### Run BO to optimize inference speed We also provide another version of BO search to optimize inference throughput (with torch.compile()) under a certain model accuracy constraint: ``` -python --BO_acc_throughput.py --checkpoint=/tmp/Meta-Llama-3-8B --num_BO_initial_samples=10 --num_trials=200 --ppl_constraint=7.5 --output_file=BO_acc_modelsize_output.csv --parameters_list=Llama3-8B_parameters.json --initial_samples Llama3-8B_initial_samples.json +python BO_acc_throughput.py --checkpoint=/tmp/Meta-Llama-3-8B --num_trials=200 --ppl_constraint=7.5 --output_file=BO_acc_modelsize_output.csv --parameters_list=Llama3-8B_parameters.json --initial_samples Llama3-8B_initial_samples.json ``` All the arguments are similar to those for optimizing accuracy under a model size constraint, except that --model_size_constraint is replaced with --ppl_constraint=7.5 to set the perplexity limit for valid search results. @@ -127,25 +141,27 @@ We are supporting more models, such as more transformer models and ViT models. T python BO_acc_modelsize.py --checkpoint=/tmp/Mistral-7B-v0.1/ --num_trials=200 --model_size_constraint=6.0 --output_file=BO_acc_modelsize_output.csv --parameters_list=Mistral-7B_parameters.json --initial_samples=Mistral-7B_initial_samples.json --gpu_lists=0,1,2,3 ``` -Supports for ViT models is coming soon. +Support for ViT models is coming soon.
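Before moving to the results, the sketch below recaps how the optional sensitivity step feeds the BO initialization. It is a hypothetical helper, not part of the scripts: it assumes the per-layer parameter names follow the `bitwidth.{i}.` / `groupsize.{i}.` format and that an initial sample is a flat mapping from parameter name to value, and it applies the heuristic from Step 2 (more sensitive layers get a higher bitwidth and a smaller groupsize).
```
# Hypothetical helper: turn per-layer FIT (or Hessian) scores into one
# initial BO sample, following the heuristic described in Step 2.
def sensitivity_to_initial_sample(scores, threshold):
    sample = {}
    for i, score in enumerate(scores):
        sensitive = score > threshold
        sample[f"bitwidth.{i}."] = 5 if sensitive else 4
        sample[f"groupsize.{i}."] = 32 if sensitive else 64
    return sample

# Example: flag the three most sensitive layers of a toy 32-layer model.
scores = [237201.35, 547750.91, 87226.19] + [20000.0] * 29
print(sensitivity_to_initial_sample(scores, threshold=80000.0))
```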
## Results -We evaluate BO search for Llama3-8B and Mistral-7B-v0.1 under two settings: (1) optimizing model accuracy under model size constraint; (2) optimizing model inference throughput under model accuracy constraint. +We evaluated BO search for Llama3-8B and Mistral-7B-v0.1 under two settings: (1) optimizing model accuracy under model size constraint; (2) optimizing model inference throughput under model accuracy constraint. ### Results of BO for optimizing model accuracy under model size constraint -For Llama3-8B, the BO search quantization saves 60.2% model size with 2.89% ppl degradation compared to bfloat-16 baseline. +For Llama3-8B, the BO search quantization saves 20.1% of the model size with 2.85% ppl degradation compared to the int8wo uniform quantization baseline. +The manual baseline here means using <5-bit, groupsize=32> for the first 3 and last 2 layers, which have higher sensitivity scores, and <4-bit, groupsize=64> for all the other layers. + | Llama3-8B | ppl | model size (GB) | | ---------------- | ------ | ------ | | bf16 baseline | 7.260 | 15.01 | -| int8wo uniform | 7.263 | 7.480 | +| int8wo uniform quantization | 7.263 | 7.480 | | int4wo uniform quantization | 7.900 | 5.411 | | manual baseline | 7.679 | 5.545 | | BO mixed-precision quantization | 7.470 | 5.976 | -For Mistral-7B-v0.1, BO search quantization saves 59.4% model size with only 1.8% ppl degradation compared to bfloat-16 baseline. +For Mistral-7B-v0.1, BO search quantization saves 30.6% of the model size with only 1.74% ppl degradation compared to the int8wo uniform quantization baseline. | Mistral-7B-v0.1 | ppl | model size (GB) | | ---------------- | ------ | ------ | | bf16 baseline | 8.021 | 13.49 | @@ -155,7 +171,7 @@ For Mistral-7B-v0.1, BO search quantization saves 59.4% model size with only 1.8 ### Results of BO for optimizing model inference throughput under model accuracy constraint -For Llama3-8B, the BO search quantization improving 69.5% throughput with only 3.25% ppl degradation compared to bfloat-16 baseline. +For Llama3-8B, the BO search quantization improves throughput by 15.2% with only 2.85% ppl degradation compared to the int8wo uniform quantization baseline. | Llama3-8B | ppl | throughput | | ---------------- | ------ | ------ | diff --git a/torchao/quantization/prototype/mixed_precision/scripts/fit.py b/torchao/quantization/prototype/mixed_precision/scripts/fit.py index c19087e62..db77878a9 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/fit.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/fit.py @@ -89,7 +89,7 @@ def main(max_seqlen, checkpoint, nsamples, max_iter, num_layers): F_average, estimated_mean, estimated_history = cal_FIT(device=device, data=trainloader, nsamples=nsamples, model=model, max_iter=max_iter, max_seqlen=max_seqlen, criterion=criterion, num_layers=num_layers) print("Iteration Done") - print("avg_trace:", F_average) + print("FIT scores for", num_layers, "layers:\n", F_average) print("estimated_mean:", estimated_mean) print("estimated_history:", estimated_history) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py index 58b56f3b7..ba294317c 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py @@ -148,7 +148,7 @@ def f(*new_params): trace_history.append(trace) print("Iteration Done") - print("avg: trace,", np.mean(trace_history)) + print("Avg Hessian trace for layer", layer_id, "is:", np.mean(trace_history)) print("trace_history,", trace_history) From 2e3fe41c379298c3abf41f3b6348825375f08597 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 28 Aug 2024 12:52:36 -0700 Subject: [PATCH 3/3] update README --- torchao/quantization/prototype/mixed_precision/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/prototype/mixed_precision/README.md
b/torchao/quantization/prototype/mixed_precision/README.md index 114249301..3d6e0479c 100644 --- a/torchao/quantization/prototype/mixed_precision/README.md +++ b/torchao/quantization/prototype/mixed_precision/README.md @@ -145,7 +145,7 @@ Support for ViT models is coming soon. ## Results -We evaluated BO search for Llama3-8B and Mistral-7B-v0.1 under two settings: (1) optimizing model accuracy under model size constraint; (2) optimizing model inference throughput under model accuracy constraint. +We evaluated BO search for Llama3-8B and Mistral-7B-v0.1 under two settings: (1) optimizing model accuracy under model size constraint; (2) optimizing model inference throughput under model accuracy constraint. We compared the BO results with the bfloat16 baseline, [int8 weight only](https://github.com/pytorch/ao/blob/983f5653f5516e91c9fb9df73d6f407fbd4b381f/torchao/quantization/quant_api.py#L432) uniform quantization, and [int4 weight only](https://github.com/pytorch/ao/blob/983f5653f5516e91c9fb9df73d6f407fbd4b381f/torchao/quantization/quant_api.py#L396) uniform quantization. ### Results of BO for optimizing model accuracy under model size constraint