From 66bbc361af7b427e82b16a1cb6305ccdc971bb08 Mon Sep 17 00:00:00 2001 From: "trotsky1997@qq.com" Date: Thu, 26 Oct 2023 20:40:17 +0800 Subject: [PATCH 1/2] bayesian' --- awq/optimizers/SPSA_Adam.py | 0 awq/optimizers/bayesian_warp.py | 87 +++++++++++++++++++++++++++++++++ awq/quantize/auto_scale.py | 56 +++++++++++++++++---- 3 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 awq/optimizers/SPSA_Adam.py create mode 100644 awq/optimizers/bayesian_warp.py diff --git a/awq/optimizers/SPSA_Adam.py b/awq/optimizers/SPSA_Adam.py new file mode 100644 index 0000000..e69de29 diff --git a/awq/optimizers/bayesian_warp.py b/awq/optimizers/bayesian_warp.py new file mode 100644 index 0000000..e40efbf --- /dev/null +++ b/awq/optimizers/bayesian_warp.py @@ -0,0 +1,87 @@ +verbose = False +# Number of pairwise comparisons performed before checking posterior mean +every_n_comps = 3 +# Total number of checking the maximum posterior mean +n_check_post_mean = 5 +n_outcome_model_initialization_points = 8 +n_reps = 1 +within_session_results = [] +exp_candidate_results = [] + +for i in range(n_reps): + print(f"Run {i}") + # Experimentation stage: initial exploration batch + torch.manual_seed(i) + np.random.seed(i) + X, Y = generate_random_exp_data(problem, n_outcome_model_initialization_points) + outcome_model = fit_outcome_model(X, Y, problem.bounds) + + # Preference exploration stage: initialize the preference model with comparsions + # between pairs of outcomes estimated using random design points + init_train_Y, init_train_comps = generate_random_pref_data(outcome_model, n=1) + + # Perform preference exploration using either Random-f or EUBO-zeta + for pe_strategy in ["EUBO-zeta", "Random-f"]: + train_Y, train_comps = init_train_Y.clone(), init_train_comps.clone() + within_result = find_max_posterior_mean(outcome_model, train_Y, train_comps) + within_result.update({"run_id": i, "pe_strategy": pe_strategy}) + within_session_results.append(within_result) + + for j in range(n_check_post_mean): + train_Y, train_comps = run_pref_learn( + outcome_model, + train_Y, + train_comps, + n_comps=every_n_comps, + pe_strategy=pe_strategy, + verbose=verbose, + ) + if verbose: + print( + f"Checking posterior mean after {(j+1) * every_n_comps} comps using PE strategy {pe_strategy}" + ) + within_result = find_max_posterior_mean( + outcome_model, train_Y, train_comps, verbose=verbose + ) + within_result.update({"run_id": i, "pe_strategy": pe_strategy}) + within_session_results.append(within_result) + + # Going back to the experimentation stage: generate an additional batch of experimental evaluations + # with the learned preference model and qNEIUU + pref_model = fit_pref_model(train_Y, train_comps) + sampler = SobolQMCNormalSampler(sample_shape=torch.Size([NUM_PREF_SAMPLES])) + pref_obj = LearnedObjective(pref_model=pref_model, sampler=sampler) + exp_cand_X = gen_exp_cand(outcome_model, pref_obj, q=1, acqf_name="qNEI") + qneiuu_util = util_func(problem(exp_cand_X)).item() + print(f"{pe_strategy} qNEIUU candidate utility: {qneiuu_util:.3f}") + exp_result = { + "util": qneiuu_util, + "strategy": pe_strategy, + "run_id": i, + } + exp_candidate_results.append(exp_result) + + # Generate a batch of experimental evaluations using oracle and random baselines + # True utility + true_obj = GenericMCObjective(util_func) + true_obj_cand_X = gen_exp_cand(outcome_model, true_obj, q=1, acqf_name="qNEI") + true_obj_util = util_func(problem(true_obj_cand_X)).item() + print(f"True objective utility: {true_obj_util:.3f}") + exp_result = { + "util": true_obj_util, + "strategy": "True Utility", + "run_id": i, + } + exp_candidate_results.append(exp_result) + + # Random experiment + _, random_Y = generate_random_exp_data(problem, 1) + random_util = util_func(random_Y).item() + print(f"Random experiment utility: {random_util:.3f}") + exp_result = { + "util": random_util, + "strategy": "Random Experiment", + "run_id": i, + } + exp_candidate_results.append(exp_result) + diff --git a/awq/quantize/auto_scale.py b/awq/quantize/auto_scale.py index 5f3c787..3bd88b0 100644 --- a/awq/quantize/auto_scale.py +++ b/awq/quantize/auto_scale.py @@ -1,6 +1,8 @@ import gc import torch import torch.nn as nn +from mango import scheduler, Tuner +from scipy.stats import uniform from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu from transformers.models.opt.modeling_opt import OPTDecoderLayer @@ -105,7 +107,7 @@ def w_quantize_func(p): return p module_kwargs.pop("use_cache") # find the best scale ratio - def _search_module_scale(block, linears2scale: list, x, kwargs={}): + def _search_module_scale(block, linears2scale: list, x, kwargs={},optimization_method='BO'): # w: co, ci # x: n, ci weight = torch.cat([_m.weight for _m in linears2scale], dim=0) @@ -116,6 +118,8 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): gc.collect() torch.cuda.empty_cache() + optimize_func = {'grid_search':grid_search,'BO':Bayesian_optimization}[optimization_method] + x = x.to(next(block.parameters()).device) with torch.no_grad(): org_out = block(x, **kwargs) @@ -130,8 +134,18 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): n_grid = 20 history = [] - org_sd = {k: v.cpu() for k, v in block.state_dict().items()} + best_ratio, best_scales = optimize_func(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd) + if best_ratio == -1: + print(history) + raise Exception + # print(best_ratio) + best_scales = best_scales.view(-1) + + assert torch.isnan(best_scales).sum() == 0, best_scales + return best_scales.detach() + + def grid_search(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd): for ratio in range(n_grid): ratio = ratio * 1 / n_grid scales = (x_max.pow(ratio) / w_max.pow(1-ratio) @@ -153,14 +167,38 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}): best_ratio = ratio best_scales = scales block.load_state_dict(org_sd) - if best_ratio == -1: - print(history) - raise Exception - # print(best_ratio) - best_scales = best_scales.view(-1) + return best_ratio,best_scales + + def Bayesian_optimization(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd): + @scheduler.serial + def get_loss(ratio): + ratio = ratio * 1 / n_grid + scales = (x_max.pow(ratio) / w_max.pow(1-ratio) + ).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + for fc in linears2scale: + fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) + fc.weight.data = w_quantize_func( + fc.weight.data) / (scales.view(1, -1)) + out = block(x, **kwargs) + if isinstance(out, tuple): + out = out[0] + + loss = (org_out - out).float().pow(2).mean().item() # float prevents overflow + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_ratio = ratio + best_scales = scales + block.load_state_dict(org_sd) + return loss + + param_space = dict(ratio=uniform(0, 1)) + tuner = Tuner(param_space, objective) + result = tuner.minimize() + return best_ratio,best_scales - assert torch.isnan(best_scales).sum() == 0, best_scales - return best_scales.detach() def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}): # module2inspect: if given, we will check the output diff of this module instead of layers From 838085e3a5d76f0037672cb4f5c1a4f86c89ccef Mon Sep 17 00:00:00 2001 From: "trotsky1997@qq.com" Date: Thu, 26 Oct 2023 23:52:58 +0800 Subject: [PATCH 2/2] fix --- README.md | 2 +- awq/quantize/auto_scale.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7a742fe..b949623 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ We provide several sample script to run AWQ (please refer to `./scripts`). We us 1. Perform AWQ search and save search results (we already did it for you): ```bash -python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \ +python -m awq.entry --model_path /mnt/workspace/zhangdi/vicuna-7b-v1.3 \ --w_bit 4 --q_group_size 128 \ --run_awq --dump_awq awq_cache/opt-6.7b-w4-g128.pt ``` diff --git a/awq/quantize/auto_scale.py b/awq/quantize/auto_scale.py index 3bd88b0..c11f1d8 100644 --- a/awq/quantize/auto_scale.py +++ b/awq/quantize/auto_scale.py @@ -167,11 +167,15 @@ def grid_search(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_err best_ratio = ratio best_scales = scales block.load_state_dict(org_sd) + print(best_error) return best_ratio,best_scales def Bayesian_optimization(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd): + best_ratio = -1 + best_scales = None @scheduler.serial def get_loss(ratio): + nonlocal best_error,best_ratio,best_scales ratio = ratio * 1 / n_grid scales = (x_max.pow(ratio) / w_max.pow(1-ratio) ).clamp(min=1e-4).view(-1) @@ -195,7 +199,7 @@ def get_loss(ratio): return loss param_space = dict(ratio=uniform(0, 1)) - tuner = Tuner(param_space, objective) + tuner = Tuner(param_space, get_loss,{"num_iteration":15,'exploration':0.5,'exploration_decay':1}) result = tuner.minimize() return best_ratio,best_scales