Commit message:

* initial commit
* add default O1 mode, enable other modes, add README
* add carilli's review suggestions to README
Showing 2 changed files with 315 additions and 1 deletion.
`README.md` (replacing the previous placeholder "Under construction..."):
# Mixed Precision DCGAN Training in PyTorch

`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan).
It implements Automatic Mixed Precision (Amp) training of the DCGAN example for several datasets. Command-line flags forwarded to `amp.initialize` make it easy to switch between the various pure and mixed precision "optimization levels", or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).

We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:
```
# Added after the models and optimizers have been constructed
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
...
# loss.backward() changed to:
with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
    errD_real_scaled.backward()
...
with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
    errD_fake_scaled.backward()
...
with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
    errG_scaled.backward()
```

Note that we use a separate loss scaler for each computed loss.
Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss).

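As a sketch of the alternative (not what this example does), the same calls work with a single shared loss scaler by relying on the default `num_losses=1` and omitting `loss_id`:
```
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level)
...
with amp.scale_loss(errD_real, optimizerD) as errD_real_scaled:
    errD_real_scaled.backward()
```
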
To improve numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` for `nn.BCEWithLogitsLoss()`.

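For reference, the two formulations compute the same loss; `BCEWithLogitsLoss` just applies the sigmoid inside the loss via the log-sum-exp trick, which avoids taking the log of probabilities that can round to 0 or 1 in reduced precision. A minimal standalone sketch (not part of the example itself):
```
import torch
import torch.nn as nn

logits = torch.randn(8)   # raw discriminator outputs
targets = torch.ones(8)   # "real" labels

# Original pairing: explicit sigmoid followed by BCE on probabilities
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), targets)

# Swapped-in pairing: sigmoid fused into the loss
loss_bce_logits = nn.BCEWithLogitsLoss()(logits, targets)

print(loss_sigmoid_bce.item(), loss_bce_logits.item())  # matching values
```
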
With the new Amp API **you never need to explicitly convert your model, or the input data, to `half()`.**

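You can see this in the training loop of `main_amp.py` below: the input batch is moved to the GPU but never cast, and Amp inserts any needed casts internally:
```
real_cpu = data[0].to(device)   # stays FP32; no .half() call anywhere
output = netD(real_cpu)
```
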
"Pure FP32" training: | ||
``` | ||
$ python main_amp.py --opt-level O0 | ||
``` | ||
Recommended mixed precision training: | ||
``` | ||
$ python main_amp.py --opt-level O1 | ||
``` | ||
|
||
Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the remaining command-line arguments.

To enable mixed precision training, we introduce the `--opt_level` argument.
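The flag is an ordinary `argparse` option whose value is forwarded unchanged to `amp.initialize` (excerpted from `main_amp.py` below):
```
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')
...
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
```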
`main_amp.py`:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

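# Amp mixed precision needs a CUDA device; the example pins everything to GPU 0.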
device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

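# BCEWithLogitsLoss applies the sigmoid inside the loss (log-sum-exp trick), which is
# more numerically stable than a separate nn.Sigmoid() + nn.BCELoss(); see the README.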
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

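# Register both models and both optimizers with Amp in a single call. num_losses=3
# gives each of the three losses below (errD_real, errD_fake, errG) its own loss scaler.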
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # dtype=real_cpu.dtype keeps the target a float tensor, as required
        # by BCEWithLogitsLoss on newer PyTorch releases
        label = torch.full((batch_size,), real_label, dtype=real_cpu.dtype, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
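        # netD's gradients were unscaled as each scale_loss context above exited,
        # so the optimizer step below looks the same as in FP32 training.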
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))