-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 8144c3b
Showing
22 changed files
with
4,116 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
### Requirements | ||
|
||
Install the Python Optimal Transport Library | ||
|
||
``` | ||
pip install POT | ||
``` | ||
|
||
In addition, PyTorch v1 or higher, NumPy, and Python 3.6+ are required. | ||
|
||
Before running, unzip the zip file of the pretrained models you want to use, as well as the cifar zip file. | ||
|
||
### Sample commands of one-shot model fusion | ||
|
||
#### For MNIST + MLPNet | ||
|
||
``` | ||
python main.py --gpu-id 1 --model-name mlpnet --n-epochs 10 --save-result-file sample.csv \ | ||
--sweep-name exp_sample --exact --correction --ground-metric euclidean --weight-stats \ | ||
--activation-histograms --activation-mode raw --geom-ensemble-type acts --sweep-id 21 \ | ||
--act-num-samples 200 --ground-metric-normalize none --activation-seed 21 \ | ||
--prelu-acts --recheck-acc --load-models ./mnist_models --ckpt-type final \ | ||
--past-correction --not-squared --dist-normalize --print-distances --to-download | ||
``` | ||
|
||
#### For CIFAR10 + VGG11 | ||
``` | ||
python main.py --gpu-id 1 --model-name vgg11_nobias --n-epochs 300 --save-result-file sample.csv \ | ||
--sweep-name exp_sample --correction --ground-metric euclidean --weight-stats \ | ||
--geom-ensemble-type wts --ground-metric-normalize none --sweep-id 90 --load-models ./cifar_models/ \ | ||
--ckpt-type best --dataset Cifar10 --ground-metric-eff --recheck-cifar --activation-seed 21 \ | ||
--prelu-acts --past-correction --not-squared --normalize-wts --exact | ||
``` | ||
|
||
#### For CIFAR10 + ResNet18 | ||
|
||
``` | ||
python main.py --gpu-id 1 --model-name resnet18_nobias_nobn --n-epochs 300 --save-result-file sample.csv \ | ||
--sweep-name exp_sample --exact --correction --ground-metric euclidean --weight-stats \ | ||
--activation-histograms --activation-mode raw --geom-ensemble-type acts --sweep-id 21 \ | ||
--act-num-samples 200 --ground-metric-normalize none --activation-seed 21 --prelu-acts --recheck-acc \ | ||
--load-models ./resnet_models/ --ckpt-type best --past-correction --not-squared --dataset Cifar10 \ | ||
--handle-skips | ||
``` | ||
|
||
The code and pretrained models correspond to the ICML 2020 submission: | ||
`Model Fusion via Optimal Transport`. If you use any of the code or pretrained models for your research, please consider citing the paper. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from model import get_model_from_name | ||
import routines | ||
|
||
def get_avg_parameters(networks, weights=None):
    """Combine the parameters of several networks, position by position.

    The i-th parameter tensor of every network (in ``net.parameters()``
    order) is stacked along a new leading dimension and reduced.

    Args:
        networks: iterable of objects exposing ``parameters()``.
        weights: optional per-network scalars, indexed like ``networks``.
            When given, each parameter is multiplied by its network's weight
            and the results are summed; when ``None``, the plain mean over
            the networks is taken.

    Returns:
        A list with one combined tensor per parameter position.
    """
    per_net_params = [net.parameters() for net in networks]
    combined = []
    for group in zip(*per_net_params):
        # Shape trace of the tensors being merged (kept as printed output).
        print([par.shape for par in group])
        if weights is None:
            merged = torch.mean(torch.stack(group), dim=0)
        else:
            scaled = [par * weights[i] for i, par in enumerate(group)]
            # Sum rather than mean: the weights already carry the scaling.
            merged = torch.sum(torch.stack(scaled), dim=0)
        print(merged.shape)
        combined.append(merged)
    return combined
|
||
def naive_ensembling(args, networks, test_loader):
    """Build a network whose parameters are a weighted average of two models.

    The two input networks are combined with weights
    ``(1 - args.ensemble_step, args.ensemble_step)``, the result is loaded
    into a fresh model from ``get_model_from_name(args)``, and that model is
    evaluated with ``routines.test`` before and after the averaged parameters
    are loaded.

    Args:
        args: run configuration; reads ``width_ratio``, ``ensemble_step`` and
            ``gpu_id`` (``-1`` keeps the model off the GPU).
        networks: the models to average (two weights are supplied, so a pair
            is expected).
        test_loader: data loader handed to ``routines.test``.

    Returns:
        ``(-1, None)`` when ``args.width_ratio != 1``; otherwise a tuple of
        whatever ``routines.test`` returns for the averaged network and the
        averaged network itself.

    Raises:
        ValueError: if the fresh model does not have as many parameter
            tensors as were averaged.
    """
    # Parameter-wise averaging needs identically shaped models.
    if args.width_ratio != 1:
        print("Unfortunately naive ensembling can't work if models are not of same shape!")
        return -1, None

    weights = [(1 - args.ensemble_step), args.ensemble_step]
    avg_pars = get_avg_parameters(networks, weights)

    ensemble_network = get_model_from_name(args)
    if args.gpu_id != -1:
        ensemble_network = ensemble_network.cuda(args.gpu_id)

    # Evaluate the freshly created network before its weights are replaced.
    log_dict = {'test_losses': []}
    routines.test(args, ensemble_network, test_loader, log_dict)

    # avg_pars follows parameters() order, so copy in that same order.
    # Indexing state_dict() entries instead would also count buffers and
    # misalign the copy for any model that has them.
    target_params = list(ensemble_network.parameters())
    if len(target_params) != len(avg_pars):
        raise ValueError(
            "ensemble network has {} parameter tensors but {} were averaged".format(
                len(target_params), len(avg_pars)))
    with torch.no_grad():
        for param, avg_par in zip(target_params, avg_pars):
            param.copy_(avg_par)

    # Evaluate again now that the averaged weights are in place.
    log_dict = {'test_losses': []}
    return routines.test(args, ensemble_network, test_loader, log_dict), ensemble_network
|
||
|
||
def prediction_ensembling(args, networks, test_loader):
    """Evaluate the weighted combination of two networks' outputs.

    For every batch the outputs of both networks are scaled by their weights
    and summed; loss and accuracy of that combined output are accumulated
    over ``test_loader`` and printed.

    Args:
        args: run configuration; reads ``dataset`` ('cifar10' or 'mnist',
            case-insensitive), ``gpu_id`` (``-1`` keeps tensors where they
            are), ``prediction_wts`` and ``ensemble_step``.
        networks: exactly two callable models; each is put in eval mode.
        test_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` for the sample count.

    Returns:
        Accuracy of the combined prediction, as a percentage (float).

    Raises:
        AssertionError: if ``networks`` does not contain exactly two models.
    """
    # The weighting scheme below is only defined for a pair of models.
    assert len(networks) == 2

    dataset_name = args.dataset.lower()

    # The weights do not depend on the batch, so pick them once up front.
    if args.prediction_wts:
        wts = [(1 - args.ensemble_step), args.ensemble_step]
    else:
        wts = [0.5, 0.5]

    for net in networks:
        net.eval()

    test_loss = 0
    correct = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for data, target in test_loader:
            if args.gpu_id != -1:
                data = data.cuda(args.gpu_id)
                target = target.cuda(args.gpu_id)

            # Sum (not mean) because each output is already scaled by its weight.
            outputs = [wt * net(data) for wt, net in zip(wts, networks)]
            output = torch.sum(torch.stack(outputs), dim=0)

            # Accumulate per-sample loss sums so that the division by the
            # dataset size below yields a true per-sample average.
            if dataset_name == 'cifar10':
                # mnist models return log_softmax outputs, while cifar ones return raw values!
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
            elif dataset_name == 'mnist':
                test_loss += F.nll_loss(output, target, reduction='sum').item()

            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples

    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

    return (float(correct) * 100.0) / num_samples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import parameters | ||
from data import get_dataloader | ||
import routines | ||
import baseline | ||
import wasserstein_ensemble | ||
import os | ||
import utils | ||
import numpy as np | ||
import sys | ||
import hyperparameters.vgg11_cifar10_baseline as vgg_hyperparams | ||
PATH_TO_CIFAR = "./cifar/" | ||
sys.path.append(PATH_TO_CIFAR) | ||
import train as cifar_train | ||
|
||
|
||
# Command line: <exp_path> <gpu_id>
#   exp_path: experiment directory expected to hold one sub-directory per model type
#   gpu_id:   integer device id forwarded to the checkpoint loader
exp_path = sys.argv[1]
gpu_id = int(sys.argv[2])
print("gpu_id is ", gpu_id)
print("exp_path is ", exp_path)

config = vgg_hyperparams.config

# For every model type with a directory under exp_path, load both its 'best'
# and 'final' checkpoints; the value returned by the loader is discarded.
model_types = ['model_0', 'model_1', 'geometric', 'naive_averaging']
for model in model_types:
    for ckpt in ('best', 'final'):
        # Model types without a directory in this experiment are skipped.
        if not os.path.exists(os.path.join(exp_path, model)):
            continue
        checkpoint_path = os.path.join(exp_path, model, ckpt + '.checkpoint')
        cifar_train.get_pretrained_model(config, checkpoint_path, device_id=gpu_id)
Binary file not shown.
Binary file not shown.
Oops, something went wrong.