-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7c86191
commit 0518b39
Showing
68 changed files
with
6,335 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Datasets | ||
ResNet*/data/ | ||
Transformer*/data/* | ||
dataset/* | ||
|
||
# jupyter checkpoints | ||
**/.ipynb_checkpoints | ||
|
||
# Compiled python modules. | ||
*.pyc | ||
|
||
# Byte-compiled | ||
_pycache__/ | ||
.cache/ | ||
|
||
# Python egg metadata, regenerated from source files by setuptools. | ||
*.egg-info | ||
.eggs/ | ||
|
||
# PyPI distribution artifacts. | ||
build/ | ||
dist/ | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# PyCharm/vscode | ||
.idea | ||
.vscode | ||
|
||
# Vim | ||
.*.swp | ||
|
||
# Other | ||
*.DS_Store |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,13 @@ | ||
# TODO: The maintainer of this repo has not yet edited this file | ||
|
||
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? | ||
|
||
- **No CSS support:** Fill out this template with information about how to file issues and get help. | ||
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). | ||
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. | ||
|
||
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.* | ||
|
||
# Support | ||
|
||
## How to file issues and get help | ||
|
||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing | ||
issues before filing new issues to avoid duplicates. For new issues, file your bug or | ||
feature request as a new Issue. | ||
|
||
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE | ||
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER | ||
CHANNEL. WHERE WILL YOU HELP PEOPLE?**. | ||
|
||
## Microsoft Support Policy | ||
|
||
Support for this **PROJECT or PRODUCT** is limited to the resources listed above. | ||
# Support | ||
|
||
## How to file issues and get help | ||
|
||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing | ||
issues before filing new issues to avoid duplicates. For new issues, file your bug or | ||
feature request as a new Issue. | ||
|
||
For help and questions about using this project, please use Github Discussions in this repo. | ||
|
||
## Microsoft Support Policy | ||
|
||
Support for this project is limited to the resources listed above. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
dataset/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# μP MLP | ||
This folder contains the source code for our experiment on MLP, which also serves as an example usage of `mup`. | ||
The script trains a series of MLPs with increasing hidden sizes from 64 to 8192. | ||
|
||
## Save Model Base Shapes | ||
To train a μP model, one needs to first specify the base shapes. To save base shapes info of the narrowest model, run, | ||
``` | ||
python main.py --save_base_shapes width64.bsh | ||
``` | ||
|
||
## Verify Implementation with Coordinate Check | ||
Before we scale up and start training, it is recommended to check the size of activation coordinates as model width increases. We have integrated such a test in this example using the helper functions in `mup`; you can simply run: | ||
|
||
```bash | ||
python main.py --load_base_shapes width64.bsh --coord_check | ||
``` | ||
You should find the generated plots under `./coord_checks`, which show stable coordinate sizes under μP, e.g., | ||
|
||
![](coord_checks/μp_mlp_sgd_coord.png) | ||
|
||
and growing sizes under SP, e.g., | ||
|
||
![](coord_checks/sp_mlp_sgd_coord.png) | ||
|
||
|
||
## Start Training | ||
Having verified our implementation of μP, we can scale up our model and train using the same hyperparameters used for the small model and expect that the wider model performs better on the training data and that the optimal hyperparameters transfer. | ||
``` | ||
python main.py --load_base_shapes width64.bsh | ||
``` | ||
|
||
Note that if you do not specify `--load_base_shapes`, the script will default to training a SP model. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,265 @@ | ||
import time | ||
import pandas as pd | ||
import numpy as np | ||
import torch.nn.functional as F | ||
from torchvision import datasets, transforms | ||
import torch | ||
from torch import nn | ||
import torch.optim as optim | ||
import argparse | ||
import math | ||
|
||
from mup.coord_check import get_coord_data, plot_coord_data | ||
from mup import MuSGD, get_shapes, set_base_shapes, make_base_shapes, MuReadout | ||
|
||
def coord_check(mup, lr, train_loader, nsteps, nseeds, args, plotdir='', legend=False): | ||
|
||
def gen(w, standparam=False): | ||
def f(): | ||
model = MLP(width=w, nonlin=torch.tanh, output_mult=args.output_mult, input_mult=args.input_mult).to(device) | ||
if standparam: | ||
set_base_shapes(model, None) | ||
else: | ||
assert args.load_base_shapes, 'load_base_shapes needs to be nonempty' | ||
set_base_shapes(model, args.load_base_shapes) | ||
return model | ||
return f | ||
|
||
widths = 2**np.arange(7, 14) | ||
models = {w: gen(w, standparam=not mup) for w in widths} | ||
|
||
df = get_coord_data(models, train_loader, mup=mup, lr=lr, optimizer='sgd', flatten_input=True, nseeds=nseeds, nsteps=nsteps, lossfn='nll') | ||
|
||
prm = 'μP' if mup else 'SP' | ||
return plot_coord_data(df, legend=legend, | ||
save_to=os.path.join(plotdir, f'{prm.lower()}_mlp_sgd_coord.png'), | ||
suptitle=f'{prm} MLP SGD lr={lr} nseeds={nseeds}', | ||
face_color='xkcd:light grey' if not mup else None) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description=''' | ||
PyTorch MLP on CIFAR-10, with μP. | ||
This is the script we use in the MLP experiment in our paper. | ||
To train a μP model, one needs to first specify the base shapes. To save base shapes info, run, for example, | ||
python main.py --save_base_shapes width64.bsh | ||
To train using MuSGD, run | ||
python main.py --load_base_shapes width64.bsh | ||
To perform coord check, run | ||
python main.py --load_base_shapes width64.bsh --coord_check | ||
If you don't specify a base shape file, then you are using standard parametrization | ||
python main.py | ||
We provide below some optimal hyperparameters for different activation/loss function combos: | ||
if nonlin == torch.relu and criterion == F.cross_entropy: | ||
args.input_mult = 0.00390625 | ||
args.output_mult = 32 | ||
elif nonlin == torch.tanh and criterion == F.cross_entropy: | ||
args.input_mult = 0.125 | ||
args.output_mult = 32 | ||
elif nonlin == torch.relu and criterion == MSE_label: | ||
args.input_mult = 0.03125 | ||
args.output_mult = 32 | ||
elif nonlin == torch.tanh and criterion == MSE_label: | ||
args.input_mult = 8 | ||
args.output_mult = 0.125 | ||
''', formatter_class=argparse.RawTextHelpFormatter) | ||
parser.add_argument('--save_base_shapes', type=str, default='', | ||
help='file location to save base shapes at') | ||
parser.add_argument('--load_base_shapes', type=str, default='', | ||
help='file location to load base shapes from') | ||
parser.add_argument('--seed', type=int, default=1) | ||
parser.add_argument('--batch_size', type=int, default=64) | ||
parser.add_argument('--epochs', type=int, default=20) | ||
parser.add_argument('--momentum', type=float, default=0.9) | ||
parser.add_argument('--lr', type=float, default=0.1) | ||
parser.add_argument('--output_mult', type=float, default=1.0) | ||
parser.add_argument('--input_mult', type=float, default=1.0) | ||
parser.add_argument('--init_std', type=float, default=1.0) | ||
parser.add_argument('--no_shuffle', action='store_true') | ||
parser.add_argument('--log_interval', type=int, default=300) | ||
parser.add_argument('--log_dir', type=str, default='.') | ||
parser.add_argument('--data_dir', type=str, default='/tmp') | ||
parser.add_argument('--coord_check', action='store_true', | ||
help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.') | ||
parser.add_argument('--coord_check_nsteps', type=int, default=3, | ||
help='Do coord check with this many steps.') | ||
parser.add_argument('--coord_check_nseeds', type=int, default=5, | ||
help='number of seeds for testing correctness of μ parametrization') | ||
|
||
args = parser.parse_args() | ||
|
||
torch.manual_seed(args.seed) | ||
|
||
device = torch.device("cuda") | ||
|
||
kwargs = {'num_workers': 1, 'pin_memory': True} | ||
|
||
transform = transforms.Compose( | ||
[transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
|
||
trainset = datasets.CIFAR10(root=args.data_dir, train=True, | ||
download=True, transform=transform) | ||
train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, | ||
shuffle=not args.no_shuffle, num_workers=2) | ||
|
||
testset = datasets.CIFAR10(root=args.data_dir, train=False, | ||
download=True, transform=transform) | ||
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, | ||
shuffle=False, num_workers=2) | ||
|
||
classes = ('plane', 'car', 'bird', 'cat', | ||
'deer', 'dog', 'frog', 'horse', 'ship', 'truck') | ||
|
||
|
||
class MLP(nn.Module): | ||
def __init__(self, width=128, num_classes=10, nonlin=F.relu, output_mult=1.0, input_mult=1.0): | ||
super(MLP, self).__init__() | ||
self.nonlin = nonlin | ||
self.input_mult = input_mult | ||
self.output_mult = output_mult | ||
self.fc_1 = nn.Linear(3072, width, bias=False) | ||
self.fc_2 = nn.Linear(width, width, bias=False) | ||
self.fc_3 = MuReadout(width, num_classes, bias=False, output_mult=args.output_mult) | ||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
nn.init.kaiming_normal_(self.fc_1.weight, a=1, mode='fan_in') | ||
self.fc_1.weight.data /= self.input_mult**0.5 | ||
self.fc_1.weight.data *= args.init_std | ||
nn.init.kaiming_normal_(self.fc_2.weight, a=1, mode='fan_in') | ||
self.fc_2.weight.data *= args.init_std | ||
nn.init.zeros_(self.fc_3.weight) | ||
|
||
def forward(self, x): | ||
out = self.nonlin(self.fc_1(x) * self.input_mult**0.5) | ||
out = self.nonlin(self.fc_2(out)) | ||
return self.fc_3(out) | ||
|
||
|
||
def train(args, model, device, train_loader, optimizer, epoch, | ||
scheduler=None, criterion=F.cross_entropy): | ||
model.train() | ||
train_loss = 0 | ||
correct = 0 | ||
start_time = time.time() | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
output = model(data.view(data.size(0), -1)) | ||
|
||
loss = criterion(output, target) | ||
loss.backward() | ||
train_loss += loss.item() * data.shape[0] # sum up batch loss | ||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability | ||
correct += pred.eq(target.view_as(pred)).sum().item() | ||
optimizer.step() | ||
if batch_idx % args.log_interval == 0: | ||
elapsed = time.time() - start_time | ||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} | ms/batch {:5.2f}'.format( | ||
epoch, batch_idx * len(data), len(train_loader.dataset), | ||
100. * batch_idx / len(train_loader), loss.item(), | ||
elapsed * 1000 / args.log_interval)) | ||
start_time = time.time() | ||
if scheduler is not None: | ||
scheduler.step() | ||
train_loss /= len(train_loader.dataset) | ||
train_acc = correct / len(train_loader.dataset) | ||
print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | ||
train_loss, correct, len(train_loader.dataset), | ||
100. * correct / len(train_loader.dataset))) | ||
return train_loss, train_acc | ||
|
||
def test(args, model, device, test_loader, | ||
evalmode=True, criterion=F.cross_entropy): | ||
if evalmode: | ||
model.eval() | ||
test_loss = 0 | ||
correct = 0 | ||
with torch.no_grad(): | ||
for data, target in test_loader: | ||
data, target = data.to(device), target.to(device) | ||
output = model(data.view(data.size(0), -1)) | ||
test_loss += criterion(output, target, reduction='sum').item() # sum up batch loss | ||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability | ||
correct += pred.eq(target.view_as(pred)).sum().item() | ||
|
||
test_loss /= len(test_loader.dataset) | ||
|
||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | ||
test_loss, correct, len(test_loader.dataset), | ||
100. * correct / len(test_loader.dataset))) | ||
return test_loss, correct / len(test_loader.dataset) | ||
|
||
|
||
def MSE_label(output, target): | ||
y_onehot = output.new_zeros(output.size(0), 10) | ||
y_onehot.scatter_(1, target.unsqueeze(-1), 1) | ||
y_onehot -= 1/10 | ||
return F.mse_loss(output, y_onehot) | ||
|
||
if args.coord_check: | ||
print('testing parametrization') | ||
import os | ||
os.makedirs('coord_checks', exist_ok=True) | ||
plotdir = 'coord_checks' | ||
coord_check(mup=True, lr=args.lr, train_loader=train_loader, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args, plotdir=plotdir, legend=False) | ||
coord_check(mup=False, lr=args.lr, train_loader=train_loader, nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds, args=args, plotdir=plotdir, legend=False) | ||
import sys; sys.exit() | ||
|
||
logs = [] | ||
for nonlin in [torch.relu, torch.tanh]: | ||
for criterion in [F.cross_entropy, MSE_label]: | ||
|
||
for width in [64, 128, 256, 512, 1024, 2048, 4096, 8192]: | ||
# print(f'{nonlin.__name__}_{criterion.__name__}_{str(width)}') | ||
mynet = MLP(width=width, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult).to(device) | ||
if args.save_base_shapes: | ||
print(f'saving base shapes at {args.save_base_shapes}') | ||
base_shapes = get_shapes(mynet) | ||
delta_shapes = get_shapes( | ||
# just need to change whatever dimension(s) we are scaling | ||
MLP(width=width+1, nonlin=nonlin, output_mult=args.output_mult, input_mult=args.input_mult) | ||
) | ||
make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes) | ||
print('done and exit') | ||
import sys; sys.exit() | ||
if args.load_base_shapes: | ||
print(f'loading base shapes from {args.load_base_shapes}') | ||
set_base_shapes(mynet, args.load_base_shapes) | ||
print('done') | ||
else: | ||
print(f'using own shapes') | ||
set_base_shapes(mynet, None) | ||
print('done') | ||
optimizer = MuSGD(mynet.parameters(), lr=args.lr, momentum=args.momentum) | ||
for epoch in range(1, args.epochs+1): | ||
train_loss, train_acc, = train(args, mynet, device, train_loader, optimizer, epoch, criterion=criterion) | ||
test_loss, test_acc = test(args, mynet, device, test_loader) | ||
logs.append(dict( | ||
epoch=epoch, | ||
train_loss=train_loss, | ||
train_acc=train_acc, | ||
test_loss=test_loss, | ||
test_acc=test_acc, | ||
width=width, | ||
nonlin=nonlin.__name__, | ||
criterion='xent' if criterion.__name__=='cross_entropy' else 'mse', | ||
)) | ||
if math.isnan(train_loss): | ||
break | ||
|
||
with open(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv'), 'w') as f: | ||
logdf = pd.DataFrame(logs) | ||
print(os.path.join(os.path.expanduser(args.log_dir), 'logs.tsv')) | ||
f.write(logdf.to_csv(sep='\t', float_format='%.4f')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# This is a base shape file encoded in yaml | ||
# - `null` indicates a dimension is "finite", i.e. a non-"width" dimension | ||
# - a number indicates the base dimension of an "infinite" dimension, i.e. some notion of "width" | ||
fc_1.weight: | ||
- 64 | ||
- null | ||
fc_2.weight: | ||
- 64 | ||
- 64 | ||
fc_3.weight: | ||
- null | ||
- 64 |
Oops, something went wrong.