CGDs

Overview

CGDs is a PyTorch package implementing optimization algorithms for competitive optimization, including three variants of CGD (competitive gradient descent), built on Hessian-vector products and the conjugate gradient method.
CGDs targets competitive optimization problems, such as those arising in generative adversarial networks (GANs), of the form

$$\min_{\mathbf{x}} f(\mathbf{x}, \mathbf{y}), \qquad \min_{\mathbf{y}} g(\mathbf{x}, \mathbf{y})$$
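In the zero-sum case $g = -f$, the two coupled minimizations reduce to the saddle-point problem commonly used to describe GAN training:

$$\min_{\mathbf{x}} \max_{\mathbf{y}} f(\mathbf{x}, \mathbf{y})$$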

Installation

pip3 install CGDs

You can also directly download the CGDs directory and copy it to your project.

Package description

The CGDs package implements CGD-style optimization algorithms in PyTorch; the examples below use ACGD and BCGD.

How to use

Quickstart notebook: examples of using ACGD.

Similar to the PyTorch package torch.optim, using optimizers in CGDs involves two main steps: construction and the update step.

Construction

To construct an optimizer, you have to give it two iterables containing the two players' parameters (all should be tensors with requires_grad=True). Then specify the device and the learning rates.

Example:

import CGDs
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# ACGD: parameters of the max player come first, then those of the min player.
optimizer = CGDs.ACGD(max_params=model_G.parameters(), min_params=model_D.parameters(),
                      lr_max=1e-3, lr_min=1e-3, device=device)
# BCGD: the parameter iterables can also be plain lists of tensors.
optimizer = CGDs.BCGD(max_params=[var1, var2], min_params=[var3, var4, var5],
                      lr_max=0.01, lr_min=0.01, device=device)
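Note that, unlike a torch.optim optimizer, a CGD optimizer takes both players' parameter sets at construction: each update couples the two players, and the package evaluates this coupling through Hessian-vector products solved with conjugate gradient, as noted in the overview.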

Update step

Both optimizers provide a step() method that updates the parameters according to their respective update rules. It can be called once the computation graph has been created. Unlike with torch.optim, you pass the loss to step() directly and do not call backward() to compute gradients beforehand.

Example:

for data in dataset:
    optimizer.zero_grad()
    real_pred = model_D(data)
    latent = torch.randn((batch_size, latent_dim), device=device)
    fake_pred = model_D(model_G(latent))
    loss = loss_fn(real_pred, fake_pred)
    optimizer.step(loss=loss)
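As a self-contained sanity check, here is a minimal zero-sum sketch using the BCGD constructor and step() shown above. The bilinear game, learning rates, and step count are illustrative choices, not part of the package; the game's unique equilibrium is at the origin.

import torch
import CGDs

device = torch.device('cpu')
x = torch.tensor([1.0], requires_grad=True)   # max player
y = torch.tensor([1.0], requires_grad=True)   # min player
optimizer = CGDs.BCGD(max_params=[x], min_params=[y],
                      lr_max=0.1, lr_min=0.1, device=device)

for _ in range(200):
    optimizer.zero_grad()
    loss = (x * y).sum()   # bilinear game: x ascends the loss, y descends it
    optimizer.step(loss=loss)

print(x.item(), y.item())   # both should approach the equilibrium (0, 0)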

For general (non-zero-sum) competitive optimization, define the two losses and pass both to optimizer.step:

loss_x = loss_f(x, y)   # loss for the x player
loss_y = loss_g(x, y)   # loss for the y player
optimizer.step(loss_x, loss_y)

Use with PyTorch DistributedDataParallel

For example,

from torch.nn.parallel import DistributedDataParallel as DDP
from CGDs import ACGD

G = DDP(G, device_ids=[rank], broadcast_buffers=False)
D = DDP(D, device_ids=[rank], broadcast_buffers=False)
# Pass each DDP model's gradient reducer to the optimizer so it can
# synchronize gradients across processes.
g_reducer = G.reducer
d_reducer = D.reducer

optimizer = ACGD(max_params=G.parameters(), min_params=D.parameters(),
                 max_reducer=g_reducer, min_reducer=d_reducer,
                 lr_max=1e-3, lr_min=1e-3,
                 tol=1e-4, atol=1e-8)
for data in dataloader:
    real_pred = D(data)
    latent = torch.randn((batch_size, latent_dim), device=rank)
    fake_img = G(latent)
    fake_pred = D(fake_img)
    # `trigger` touches the outputs of both models so that DDP's gradient
    # communication hooks fire for both G and D.
    trigger = real_pred[0, 0] + fake_img[0, 0, 0, 0]
    loss = loss_fn(real_pred, fake_pred)
    optimizer.step(loss, trigger=trigger.mean())
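For reference, a minimal launch skeleton for the snippet above might look like the following. This is a sketch using standard PyTorch distributed APIs; the worker body, master address/port, and NCCL backend are assumptions, and the CGDs-specific pieces are exactly those shown above.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Standard PyTorch process-group setup; NCCL backend for GPU training.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build G and D, wrap them in DDP, construct ACGD, and run the loop above ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)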

Citation

Please cite this repository if you find the code useful.

@misc{cgds-package,
  author = {Hongkai Zheng},
  title = {CGDs},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/devzhk/cgds-package}},
}