initial commit
modelfusion committed Feb 19, 2020
0 parents commit 8144c3b
Showing 22 changed files with 4,116 additions and 0 deletions.
47 changes: 47 additions & 0 deletions README.md
@@ -0,0 +1,47 @@
### Requirements

Install the Python Optimal Transport (POT) library:

```
pip install POT
```

In addition, the code requires PyTorch (v1.0 or higher), NumPy, and Python 3.6+.

Before running, unzip the relevant pretrained-model zip file as well as the cifar zip file.
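
To quickly verify that POT is installed correctly, a minimal sanity check along the following lines can be run (this snippet is only an illustration and is not part of the repository):

```
import numpy as np
import ot

# two small random point clouds with uniform marginal weights
xs, xt = np.random.randn(5, 3), np.random.randn(5, 3)
a, b = ot.unif(5), ot.unif(5)
M = ot.dist(xs, xt)   # pairwise squared-Euclidean cost matrix
T = ot.emd(a, b, M)   # exact optimal transport plan between the two clouds
print(T.sum(axis=1))  # each row sums to the corresponding source weight (0.2)
```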

### Sample commands for one-shot model fusion

#### For MNIST + MLPNet

```
python main.py --gpu-id 1 --model-name mlpnet --n-epochs 10 --save-result-file sample.csv \
--sweep-name exp_sample --exact --correction --ground-metric euclidean --weight-stats \
--activation-histograms --activation-mode raw --geom-ensemble-type acts --sweep-id 21 \
--act-num-samples 200 --ground-metric-normalize none --activation-seed 21 \
--prelu-acts --recheck-acc --load-models ./mnist_models --ckpt-type final \
--past-correction --not-squared --dist-normalize --print-distances --to-download
```

#### For CIFAR10 + VGG11

```
python main.py --gpu-id 1 --model-name vgg11_nobias --n-epochs 300 --save-result-file sample.csv \
--sweep-name exp_sample --correction --ground-metric euclidean --weight-stats \
--geom-ensemble-type wts --ground-metric-normalize none --sweep-id 90 --load-models ./cifar_models/ \
--ckpt-type best --dataset Cifar10 --ground-metric-eff --recheck-cifar --activation-seed 21 \
--prelu-acts --past-correction --not-squared --normalize-wts --exact
```

#### For CIFAR10 + ResNet18

```
python main.py --gpu-id 1 --model-name resnet18_nobias_nobn --n-epochs 300 --save-result-file sample.csv \
--sweep-name exp_sample --exact --correction --ground-metric euclidean --weight-stats \
--activation-histograms --activation-mode raw --geom-ensemble-type acts --sweep-id 21 \
--act-num-samples 200 --ground-metric-normalize none --activation-seed 21 --prelu-acts --recheck-acc \
--load-models ./resnet_models/ --ckpt-type best --past-correction --not-squared --dataset Cifar10 \
--handle-skips
```

The code and pretrained models correspond to the ICML 2020 submission `Model Fusion via Optimal Transport`. If you use any of the code or pretrained models in your research, please consider citing the paper.
101 changes: 101 additions & 0 deletions baseline.py
@@ -0,0 +1,101 @@
import torch
import torch.nn.functional as F
from model import get_model_from_name
import routines

def get_avg_parameters(networks, weights=None):
    avg_pars = []
    for par_group in zip(*[net.parameters() for net in networks]):
        print([par.shape for par in par_group])
        if weights is not None:
            weighted_par_group = [par * weights[i] for i, par in enumerate(par_group)]
            avg_par = torch.sum(torch.stack(weighted_par_group), dim=0)
        else:
            # print("shape of stacked params is ", torch.stack(par_group).shape) # (2, 400, 784)
            avg_par = torch.mean(torch.stack(par_group), dim=0)
        print(avg_par.shape)
        avg_pars.append(avg_par)
    return avg_pars

def naive_ensembling(args, networks, test_loader):
    # simply average the weights in networks
    if args.width_ratio != 1:
        print("Unfortunately naive ensembling can't work if models are not of same shape!")
        return -1, None
    weights = [(1 - args.ensemble_step), args.ensemble_step]
    avg_pars = get_avg_parameters(networks, weights)
    ensemble_network = get_model_from_name(args)
    # put on GPU
    if args.gpu_id != -1:
        ensemble_network = ensemble_network.cuda(args.gpu_id)

    # check the test performance of the method before
    log_dict = {}
    log_dict['test_losses'] = []
    # log_dict['test_counter'] = [i * len(train_loader.dataset) for i in range(args.n_epochs + 1)]
    routines.test(args, ensemble_network, test_loader, log_dict)

    # set the weights of the ensembled network
    for idx, (name, param) in enumerate(ensemble_network.state_dict().items()):
        ensemble_network.state_dict()[name].copy_(avg_pars[idx].data)

    # check the test performance of the method after ensembling
    log_dict = {}
    log_dict['test_losses'] = []
    # log_dict['test_counter'] = [i * len(train_loader.dataset) for i in range(args.n_epochs + 1)]
    return routines.test(args, ensemble_network, test_loader, log_dict), ensemble_network


def prediction_ensembling(args, networks, test_loader):
    log_dict = {}
    log_dict['test_losses'] = []
    # test counter is not even used!
    # log_dict['test_counter'] = [i * len(train_loader.dataset) for i in range(args.n_epochs + 1)]

    if args.dataset.lower() == 'cifar10':
        cifar_criterion = torch.nn.CrossEntropyLoss()

    # set all the networks in eval mode
    for net in networks:
        net.eval()
    test_loss = 0
    correct = 0

    # with torch.no_grad():
    for data, target in test_loader:
        if args.gpu_id != -1:
            data = data.cuda(args.gpu_id)
            target = target.cuda(args.gpu_id)
        outputs = []
        # average the outputs of all nets
        assert len(networks) == 2
        if args.prediction_wts:
            wts = [(1 - args.ensemble_step), args.ensemble_step]
        else:
            wts = [0.5, 0.5]
        for idx, net in enumerate(networks):
            outputs.append(wts[idx] * net(data))
        # print("number of outputs {} and each is of shape {}".format(len(outputs), outputs[-1].shape))
        # number of outputs 2 and each is of shape torch.Size([1000, 10])
        output = torch.sum(torch.stack(outputs), dim=0)  # sum because multiplied by wts above

        # check loss of this ensembled prediction
        if args.dataset.lower() == 'cifar10':
            # mnist models return log_softmax outputs, while cifar ones return raw values!
            test_loss += cifar_criterion(output, target).item()
        elif args.dataset.lower() == 'mnist':
            test_loss += F.nll_loss(output, target, size_average=False).item()

        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()

    test_loss /= len(test_loader.dataset)
    log_dict['test_losses'].append(test_loss)

    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return (float(correct) * 100.0) / len(test_loader.dataset)
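
For a quick intuition of what the naive baseline above does, here is a self-contained toy sketch of weighted parameter averaging between two models of identical architecture (the `average_parameters` helper and the small linear models below are illustrative only and are not part of this file):

```
import torch
import torch.nn as nn

def average_parameters(nets, weights):
    # weighted average of corresponding parameter tensors across models
    avg = []
    for par_group in zip(*[net.parameters() for net in nets]):
        avg.append(torch.sum(torch.stack([w * p for w, p in zip(weights, par_group)]), dim=0))
    return avg

net_a, net_b = nn.Linear(4, 2), nn.Linear(4, 2)
avg_pars = average_parameters([net_a, net_b], weights=[0.5, 0.5])

# copy the averaged parameters into a fresh model of the same architecture
fused = nn.Linear(4, 2)
with torch.no_grad():
    for p, avg_p in zip(fused.parameters(), avg_pars):
        p.copy_(avg_p)
print([tuple(p.shape) for p in fused.parameters()])
```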
28 changes: 28 additions & 0 deletions check_accuracy.py
@@ -0,0 +1,28 @@
import parameters
from data import get_dataloader
import routines
import baseline
import wasserstein_ensemble
import os
import utils
import numpy as np
import sys
import hyperparameters.vgg11_cifar10_baseline as vgg_hyperparams
PATH_TO_CIFAR = "./cifar/"
sys.path.append(PATH_TO_CIFAR)
import train as cifar_train


exp_path = sys.argv[1]
gpu_id = int(sys.argv[2])
print("gpu_id is ", gpu_id)
print("exp_path is ", exp_path)

config = vgg_hyperparams.config

model_types = ['model_0', 'model_1', 'geometric', 'naive_averaging']
for model in model_types:
    for ckpt in ['best', 'final']:
        if os.path.exists(os.path.join(exp_path, model)):
            cifar_train.get_pretrained_model(config,
                os.path.join(exp_path, model, ckpt + '.checkpoint'), device_id=gpu_id)
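
The script expects the experiment directory and the GPU id as positional arguments, e.g. `python check_accuracy.py <exp_path> <gpu_id>`, and loads the `best` and `final` checkpoints of each saved model via `cifar_train.get_pretrained_model`.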
Binary file added cifar.zip
Binary file added cifar_models.zip