Optix

Optix: A Memory-Efficient Training Framework for Large Video Generation Models

Update:

Features:

  • Up to 4x larger batch size when training on high-resolution images
  • 1.2x training throughput on average
  • Optix is also effective for DiT model training

Results of training Stable Diffusion models:

[Figure: max batch size and acc ratio]

Baseline config: TF32, gradient checkpointing

Tested on: A100 80GB; PyTorch 2.1.2+cu118

Getting started

Install Optix

python setup.py develop

Install dependencies

Refer to the requirements file.

API Usage

import optix


# optimize the model (fused ops, DDP, etc.), set up the optimizer, and create an EMA model
model, vae, opt, ema = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5, use_ema=True)

# alternatively, skip creating the EMA model:
model, vae, opt, _ = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5)


# use `optix.sliced_vae` in place of the original `vae.encode` code:
# with torch.no_grad():
#     x = vae.encode(x)
#     if not args.use_video:
#         x = x.latent_dist.sample().mul_(vae.config.scaling_factor)
model_input = optix.sliced_vae(vae, model_input, use_autocast=True, nhwc=True)
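`optix.sliced_vae` replaces a single `vae.encode` call over the whole batch. Its internals are not described here; assuming it works by splitting the input into slices and encoding them one at a time (the actual slicing axis and slice size are not documented), the idea can be sketched in pure Python. `sliced_encode` and `toy_encode` are hypothetical names, and plain lists stand in for tensors:

```python
from typing import Callable, List, Sequence

Sample = List[float]  # plain-list stand-in for one input tensor


def sliced_encode(
    encode_fn: Callable[[Sequence[Sample]], List[Sample]],
    batch: Sequence[Sample],
    slice_size: int,
    scaling_factor: float,
) -> List[Sample]:
    """Hypothetical sketch: encode `batch` in slices of `slice_size`.

    Only one slice is passed to `encode_fn` at a time; with a real VAE this
    bounds peak activation memory by the slice size instead of the batch size.
    Each latent is multiplied by `scaling_factor`, mirroring the
    `.mul_(vae.config.scaling_factor)` step in the original code.
    """
    if slice_size < 1:
        raise ValueError("slice_size must be >= 1")
    latents: List[Sample] = []
    for start in range(0, len(batch), slice_size):
        chunk = batch[start:start + slice_size]
        for z in encode_fn(chunk):
            latents.append([v * scaling_factor for v in z])
    return latents


# Usage with a toy encoder (hypothetical; stands in for vae.encode + sampling).
chunk_sizes: List[int] = []


def toy_encode(chunk: Sequence[Sample]) -> List[Sample]:
    chunk_sizes.append(len(chunk))  # record how many samples each call sees
    return [[v / 2 for v in sample] for sample in chunk]


latents = sliced_encode(toy_encode, [[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]],
                        slice_size=2, scaling_factor=0.5)
print(latents)      # [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]
print(chunk_sizes)  # [2, 1]
```

The result is the same as encoding the whole batch at once; only the number of samples per encoder call changes.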

Keyword arguments accepted by optix.compile, with their default values:

{
    'use_ema': False,                   # create an EMA model
    'compile_vae': True,                # [PERF] compile the VAE with torch.compile; recommended for torch>2.0
    'ddp': True,                        # automatically wrap the UNet in a DDP module
    'dp_group': None,                   # DDP communication group (default: None)
    'gradient_checkpointing': True,     # [PERF] ON by default; with small batch sizes, turn it off for a speedup
    'xformer': True,                    # [PERF] xformers attention gives a small speedup
    'fusedln': True,                    # [PERF] fused LayerNorm gives a speedup
    'compile_unet': False,              # [PERF] not yet stable, so OFF by default
    'vae_channels_last': True,          # [PERF] use the channels_last memory format for the VAE
    'optim': 'adamw',                   # optimizer type
    'learning_rate': 1e-5,              # optimizer hyperparameter
    'weight_decay': 0,                  # optimizer hyperparameter
    'hybrid_zero': True,                # [PERF] hybrid ZeRO can be faster for multi-node training
}
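Any keyword not passed to optix.compile keeps the default listed above. A minimal sketch of that merge, assuming the defaults behave like an ordinary dict update; `OPTIX_COMPILE_DEFAULTS` and `resolve_compile_kwargs` are hypothetical names, not part of Optix, and rejecting unknown keywords is a choice made for this sketch only:

```python
from typing import Any, Dict

# Hypothetical copy of the documented optix.compile defaults.
OPTIX_COMPILE_DEFAULTS: Dict[str, Any] = {
    'use_ema': False,
    'compile_vae': True,
    'ddp': True,
    'dp_group': None,
    'gradient_checkpointing': True,
    'xformer': True,
    'fusedln': True,
    'compile_unet': False,
    'vae_channels_last': True,
    'optim': 'adamw',
    'learning_rate': 1e-5,
    'weight_decay': 0,
    'hybrid_zero': True,
}


def resolve_compile_kwargs(**overrides: Any) -> Dict[str, Any]:
    """Return the defaults with `overrides` applied on top (hypothetical helper)."""
    unknown = sorted(set(overrides) - set(OPTIX_COMPILE_DEFAULTS))
    if unknown:
        # This sketch rejects typos such as 'xformers'; Optix's own behavior is not documented here.
        raise TypeError(f"unknown keyword(s): {unknown}")
    # Later entries win, so user overrides replace the defaults; the defaults dict is not mutated.
    return {**OPTIX_COMPILE_DEFAULTS, **overrides}


# Turning off gradient checkpointing for a small-batch run; everything else stays at its default.
cfg = resolve_compile_kwargs(gradient_checkpointing=False)
print(cfg['gradient_checkpointing'], cfg['optim'], cfg['weight_decay'])  # False adamw 0
```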

These keyword arguments can be passed directly to optix.compile, for example:

model, vae, opt, _ = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5,
                                   use_ema=False, compile_vae=False, optim='sgd',
                                   xformer=False)
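The call above sets `use_ema=False`; with `use_ema=True`, optix.compile also returns an EMA model. The README does not state the decay or update schedule Optix uses, so the sketch below shows only the standard exponential-moving-average update, on a dict of plain floats standing in for model parameters. `ema_update` is a hypothetical name and `decay=0.9999` is an assumed default:

```python
from typing import Dict


def ema_update(ema: Dict[str, float], params: Dict[str, float], decay: float = 0.9999) -> None:
    """Hypothetical sketch: in place, ema[k] <- decay * ema[k] + (1 - decay) * params[k]."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


# Usage: a small decay makes the averaging visible after two steps.
ema = {"w": 0.0}
params = {"w": 1.0}
ema_update(ema, params, decay=0.5)  # 0.5 * 0.0 + 0.5 * 1.0 = 0.5
ema_update(ema, params, decay=0.5)  # 0.5 * 0.5 + 0.5 * 1.0 = 0.75
print(ema)  # {'w': 0.75}
```

A decay close to 1 makes the EMA weights change slowly, which smooths out step-to-step noise in the trained weights.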

Examples

Stable Diffusion: train_sd_unet.py

DiT: train_dit.py