Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest: Add Bayesian optimization support for ratio search #104

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ We provide several sample script to run AWQ (please refer to `./scripts`). We us

1. Perform AWQ search and save search results (we already did it for you):
```bash
python -m awq.entry --model_path /PATH/TO/OPT/opt-6.7b \
python -m awq.entry --model_path /mnt/workspace/zhangdi/vicuna-7b-v1.3 \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/opt-6.7b-w4-g128.pt
```
Expand Down
Empty file added awq/optimizers/SPSA_Adam.py
Empty file.
87 changes: 87 additions & 0 deletions awq/optimizers/bayesian_warp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
verbose = False
# Number of pairwise comparisons performed before checking posterior mean
every_n_comps = 3
# Total number of checking the maximum posterior mean
n_check_post_mean = 5
n_outcome_model_initialization_points = 8
n_reps = 1
within_session_results = []
exp_candidate_results = []

for i in range(n_reps):
print(f"Run {i}")
# Experimentation stage: initial exploration batch
torch.manual_seed(i)
np.random.seed(i)
X, Y = generate_random_exp_data(problem, n_outcome_model_initialization_points)
outcome_model = fit_outcome_model(X, Y, problem.bounds)

# Preference exploration stage: initialize the preference model with comparsions
# between pairs of outcomes estimated using random design points
init_train_Y, init_train_comps = generate_random_pref_data(outcome_model, n=1)

# Perform preference exploration using either Random-f or EUBO-zeta
for pe_strategy in ["EUBO-zeta", "Random-f"]:
train_Y, train_comps = init_train_Y.clone(), init_train_comps.clone()
within_result = find_max_posterior_mean(outcome_model, train_Y, train_comps)
within_result.update({"run_id": i, "pe_strategy": pe_strategy})
within_session_results.append(within_result)

for j in range(n_check_post_mean):
train_Y, train_comps = run_pref_learn(
outcome_model,
train_Y,
train_comps,
n_comps=every_n_comps,
pe_strategy=pe_strategy,
verbose=verbose,
)
if verbose:
print(
f"Checking posterior mean after {(j+1) * every_n_comps} comps using PE strategy {pe_strategy}"
)
within_result = find_max_posterior_mean(
outcome_model, train_Y, train_comps, verbose=verbose
)
within_result.update({"run_id": i, "pe_strategy": pe_strategy})
within_session_results.append(within_result)

# Going back to the experimentation stage: generate an additional batch of experimental evaluations
# with the learned preference model and qNEIUU
pref_model = fit_pref_model(train_Y, train_comps)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([NUM_PREF_SAMPLES]))
pref_obj = LearnedObjective(pref_model=pref_model, sampler=sampler)
exp_cand_X = gen_exp_cand(outcome_model, pref_obj, q=1, acqf_name="qNEI")
qneiuu_util = util_func(problem(exp_cand_X)).item()
print(f"{pe_strategy} qNEIUU candidate utility: {qneiuu_util:.3f}")
exp_result = {
"util": qneiuu_util,
"strategy": pe_strategy,
"run_id": i,
}
exp_candidate_results.append(exp_result)

# Generate a batch of experimental evaluations using oracle and random baselines
# True utility
true_obj = GenericMCObjective(util_func)
true_obj_cand_X = gen_exp_cand(outcome_model, true_obj, q=1, acqf_name="qNEI")
true_obj_util = util_func(problem(true_obj_cand_X)).item()
print(f"True objective utility: {true_obj_util:.3f}")
exp_result = {
"util": true_obj_util,
"strategy": "True Utility",
"run_id": i,
}
exp_candidate_results.append(exp_result)

# Random experiment
_, random_Y = generate_random_exp_data(problem, 1)
random_util = util_func(random_Y).item()
print(f"Random experiment utility: {random_util:.3f}")
exp_result = {
"util": random_util,
"strategy": "Random Experiment",
"run_id": i,
}
exp_candidate_results.append(exp_result)

60 changes: 51 additions & 9 deletions awq/quantize/auto_scale.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gc
import torch
import torch.nn as nn
from mango import scheduler, Tuner
from scipy.stats import uniform

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
Expand Down Expand Up @@ -105,7 +107,7 @@ def w_quantize_func(p): return p
module_kwargs.pop("use_cache")

# find the best scale ratio
def _search_module_scale(block, linears2scale: list, x, kwargs={}):
def _search_module_scale(block, linears2scale: list, x, kwargs={},optimization_method='BO'):
# w: co, ci
# x: n, ci
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
Expand All @@ -116,6 +118,8 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):
gc.collect()
torch.cuda.empty_cache()

optimize_func = {'grid_search':grid_search,'BO':Bayesian_optimization}[optimization_method]

x = x.to(next(block.parameters()).device)
with torch.no_grad():
org_out = block(x, **kwargs)
Expand All @@ -130,8 +134,18 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):

n_grid = 20
history = []

org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
best_ratio, best_scales = optimize_func(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd)
if best_ratio == -1:
print(history)
raise Exception
# print(best_ratio)
best_scales = best_scales.view(-1)

assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach()

def grid_search(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd):
for ratio in range(n_grid):
ratio = ratio * 1 / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)
Expand All @@ -153,14 +167,42 @@ def _search_module_scale(block, linears2scale: list, x, kwargs={}):
best_ratio = ratio
best_scales = scales
block.load_state_dict(org_sd)
if best_ratio == -1:
print(history)
raise Exception
# print(best_ratio)
best_scales = best_scales.view(-1)
print(best_error)
return best_ratio,best_scales

def Bayesian_optimization(block, linears2scale, x, kwargs, w_max, org_out, x_max, best_error, n_grid, history, org_sd):
best_ratio = -1
best_scales = None
@scheduler.serial
def get_loss(ratio):
nonlocal best_error,best_ratio,best_scales
ratio = ratio * 1 / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)
).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
fc.weight.data = w_quantize_func(
fc.weight.data) / (scales.view(1, -1))
out = block(x, **kwargs)
if isinstance(out, tuple):
out = out[0]

loss = (org_out - out).float().pow(2).mean().item() # float prevents overflow
history.append(loss)
is_best = loss < best_error
if is_best:
best_error = loss
best_ratio = ratio
best_scales = scales
block.load_state_dict(org_sd)
return loss

param_space = dict(ratio=uniform(0, 1))
tuner = Tuner(param_space, get_loss,{"num_iteration":15,'exploration':0.5,'exploration_decay':1})
result = tuner.minimize()
return best_ratio,best_scales

assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach()

def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
# module2inspect: if given, we will check the output diff of this module instead of layers
Expand Down