2:4 sparsity, fused sequence parallel, torch compile & more

@danthe3rd danthe3rd released this 31 Jan 08:42
· 223 commits to main since this release

Pre-built binary wheels require PyTorch 2.2.0

Added

  • Added components for model/sequence parallelism, as near-drop-in replacements for FairScale/Megatron Column&RowParallelLinear modules. They support fusing communication and computation for sequence parallelism, thus making the communication effectively free.
  • Added kernels for training models with 2:4 sparsity. We introduced a very fast kernel for converting a matrix A into 2:4-sparse format, which can be used during training to dynamically sparsify weights, activations, etc. xFormers also provides an API that is compatible with torch.compile; see xformers.ops.sparsify24.
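
For readers unfamiliar with the 2:4 pattern: in every group of 4 consecutive values, at most 2 are non-zero. The sketch below illustrates that pattern in pure Python. It is not the xFormers kernel and does not call xformers.ops.sparsify24; the helper name `sparsify_2_4` is hypothetical, and keeping the two largest-magnitude values per group is an assumed selection rule chosen for illustration.

```python
from typing import List


def sparsify_2_4(row: List[float]) -> List[float]:
    """Zero all but 2 values in every group of 4 consecutive entries.

    Illustrative only: the selection rule (largest magnitude, ties broken
    by lowest index) is an assumption, not the xFormers implementation.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out: List[float] = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two entries with the largest absolute value.
        keep = sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2]
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out


dense = [0.1, -2.0, 0.3, 1.5, 4.0, 0.0, -0.5, 0.2]
sparse = sparsify_2_4(dense)
print(sparse)  # [0.0, -2.0, 0.0, 1.5, 4.0, 0.0, -0.5, 0.0]
```

A real 2:4 kernel would additionally pack the kept values and their positions into a compressed layout; this sketch only shows which entries survive.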

Improved

  • Made selective activation checkpointing compatible with torch.compile.

Removed

  • Triton kernels now require a GPU with compute capability of at least 8.0 (A100 or newer), because newer versions of Triton do not support older GPUs correctly.
  • Removed support for PyTorch versions older than 2.1.0.
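
To check whether a machine meets the compute capability requirement for the Triton kernels, something like the following sketch may help. The helper names `supports_triton_kernels` and `detect_capability` are hypothetical and not part of xFormers; the threshold (8, 0) is taken from the note above, and the detection relies on PyTorch's `torch.cuda.get_device_capability()`.

```python
from typing import Optional, Tuple

# Minimum compute capability stated in the release note (A100 or newer).
MIN_TRITON_CAPABILITY: Tuple[int, int] = (8, 0)


def supports_triton_kernels(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) capability meets the minimum."""
    return tuple(capability) >= MIN_TRITON_CAPABILITY


def detect_capability() -> Optional[Tuple[int, int]]:
    """Return the current GPU's capability, or None if it cannot be read."""
    try:
        import torch  # imported lazily so the sketch runs without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


cap = detect_capability()
if cap is None:
    print("No CUDA device detected; cannot check Triton kernel support.")
else:
    print(f"Compute capability {cap}: supported = {supports_triton_kernels(cap)}")
```

Tuple comparison handles both components at once, so (7, 5) is rejected while (8, 0), (8, 6), and (9, 0) pass.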