This repository contains the official implementation of the Geometric Algebra Transformer by Johann Brehmer, Pim de Haan, Sönke Behrends, and Taco Cohen, published at NeurIPS 2023.
Problems involving geometric data arise in physics, chemistry, robotics, computer vision, and many other fields. Such data can take numerous forms, such as points, direction vectors, translations, or rotations, but to date there is no single architecture that can be applied to such a wide variety of geometric types while respecting their symmetries. In this paper we introduce the Geometric Algebra Transformer (GATr), a general-purpose architecture for geometric data. GATr represents inputs, outputs, and hidden states in the projective geometric (or Clifford) algebra, which offers an efficient 16-dimensional vector-space representation of common geometric objects as well as operators acting on them. GATr is equivariant with respect to E(3), the symmetry group of 3D Euclidean space. As a Transformer, GATr is versatile, efficient, and scalable. We demonstrate GATr in problems from n-body modeling to wall-shear-stress estimation on large arterial meshes to robotic motion planning. GATr consistently outperforms both non-geometric and equivariant baselines in terms of error, data efficiency, and scalability.
Clone the repository.
git clone https://github.com/Qualcomm-AI-research/geometric-algebra-transformer
Build the Docker image from inside the checked-out repository. On Linux, this can be done using
cd geometric-algebra-transformer
docker build -f docker/Dockerfile --tag gatr:latest .
The commands for Windows are similar.
Once the image has built successfully, we can run a container that is based on it.
On Linux, the command that does this and also
- mounts the current working directory into the container
- changes to the current working directory
- exposes all of the host's GPUs in the container
is
docker run --rm -it -v $PWD:$PWD -w $PWD --gpus=all gatr:latest /bin/bash
In this repository, we include two experiments: one on a synthetic n-body modeling problem and one on predicting the wall shear stress on artery walls.
Running experiments will require the specification of directories for datasets, checkpoints, and logs. We will define a common base directory (feel free to adapt):
BASEDIR=/tmp/gatr-experiments
First, we need to generate training and evaluation datasets:
python scripts/generate_nbody_dataset.py base_dir="${BASEDIR}" seed=42
Let's train a GATr model, using just 1000 (or 1%) training trajectories and training for 5000 steps:
python scripts/nbody_experiment.py base_dir="${BASEDIR}" seed=42 model=gatr_nbody data.subsample=0.01 training.steps=5000 run_name=gatr
At the end of the training, evaluation metrics are printed to stdout and written to $BASEDIR/experiments/nbody/gatr/metrics. Training and evaluation are also tracked with MLflow to an SQLite database at $BASEDIR/tracking/mlflow.db. MLflow is configured in the mlflow section of config/nbody.yaml.
We can repeat the same for different baselines, for instance:
for METHOD in mlp transformer gcagnn se3transformer segnn
do
python scripts/nbody_experiment.py base_dir="${BASEDIR}" seed=42 model=${METHOD}_nbody data.subsample=0.01 training.steps=5000 run_name=$METHOD
done
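The same sweep can be scripted from Python. The sketch below only assembles the command lines shown above; the helper name `nbody_command` is hypothetical and not part of this repository, and the base directory is the example value used earlier.

```python
import shlex

BASELINES = ["mlp", "transformer", "gcagnn", "se3transformer", "segnn"]


def nbody_command(method, base_dir, seed=42, subsample=0.01, steps=5000):
    """Assemble the argument list for one n-body training run (hypothetical helper)."""
    return [
        "python",
        "scripts/nbody_experiment.py",
        f"base_dir={base_dir}",
        f"seed={seed}",
        f"model={method}_nbody",
        f"data.subsample={subsample}",
        f"training.steps={steps}",
        f"run_name={method}",
    ]


commands = [nbody_command(m, "/tmp/gatr-experiments") for m in BASELINES]
for cmd in commands:
    # Print the shell-quoted command; to launch it, use subprocess.run(cmd, check=True).
    print(shlex.join(cmd))
```

Keeping the commands as argument lists avoids shell-quoting problems if the base directory contains spaces.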
On the eval set (without domain shift), you should get results like these:
Method | MSE
----------------+--------
GATr | 0.0005
----------------+--------
MLP | 0.0776
Transformer | 0.0050
GCA-GNN | 0.0526
SE3-Transformer | 0.0025
SEGNN | 0.0006
Note that the code is not fully deterministic, so you will not get these exact numbers. (After averaging over multiple seeds, such fluctuations should be negligible.)
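To compare methods on an equal footing, you can aggregate the final MSE over several seeds. This is a minimal sketch using only the standard library; the `summarize` helper is hypothetical, and the numbers in the usage lines are placeholders, not measured results.

```python
from statistics import mean, stdev


def summarize(mse_per_seed):
    """Return (mean, standard error) of MSE values from runs with different seeds."""
    n = len(mse_per_seed)
    if n < 2:
        raise ValueError("need at least two seeds to estimate the spread")
    return mean(mse_per_seed), stdev(mse_per_seed) / n ** 0.5


# Placeholder values for illustration only; substitute the metrics from your own runs.
avg, sem = summarize([0.0010, 0.0012, 0.0008])
print(f"MSE = {avg:.4f} +/- {sem:.4f}")
```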
You can also vary the number of training samples seen (data.subsample), the number of optimizer steps (training.steps), the random seed (seed), and many other settings. Have a look at the n-body config and the GATr model config.
If you train for long enough and average results from multiple random seeds, you should get numbers that are comparable to Fig. 2 in our paper. However, do not expect exact reproducibility: the paper experiments were performed with an earlier version of GATr than the one published in this repository. In addition, there is stochasticity both in the dataset generation and in the initialization and training of the networks.
The arterial wall-shear-stress experiment uses a dataset created by Julian Suk et al. First, you need to download that dataset following the instructions in their repo. We only need the single (not bifurcating) dataset. Store this dataset at some location $ARTERYDIR/raw/database.hdf5 and set the environment variable $ARTERYDIR correspondingly (as we will use it in the commands below).
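Before starting a long training run, it can help to confirm that the dataset sits where the commands expect it. The following sketch only checks the path layout described above; the helper name `artery_database_path` is hypothetical and not part of this repository.

```python
import os
from pathlib import Path


def artery_database_path(artery_dir):
    """Return the expected location of the raw dataset below the given directory."""
    return Path(artery_dir) / "raw" / "database.hdf5"


artery_dir = os.environ.get("ARTERYDIR")
if artery_dir is None:
    print("ARTERYDIR is not set")
else:
    path = artery_database_path(artery_dir)
    print(f"{path}: {'found' if path.is_file() else 'missing'}")
```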
To train a GATr model on a non-canonicalized version of this dataset for 50k steps, run:
python scripts/artery_experiment.py base_dir="${BASEDIR}" data.data_dir="${ARTERYDIR}" seed=42 training.lr_decay=0.1 training.steps=50000 model=gatr_arteries run_name=gatr
To train a Transformer baseline, run:
python scripts/artery_experiment.py base_dir="${BASEDIR}" data.data_dir="${ARTERYDIR}" seed=42 training.lr_decay=0.1 training.steps=50000 model=transformer_arteries run_name=transformer_arteries
Eventually you should find results like these:
Method | Approximation error
-------------+---------------------
GATr | 0.089
Transformer | 0.105
If you run for long enough, you should get results comparable to Fig. 3 and Tbl. 2 in our paper. As in the n-body experiment, results are also tracked with MLflow. Again, there is no guarantee of exact reproducibility due to code changes and stochasticity.
To use GATr on your own problem, you will need at least two components from this repository: GATr networks, which act on multivector data, and interface functions that embed various geometric objects into this multivector representation.
Here is an example code snippet that illustrates the recipe:
from gatr import GATr, SelfAttentionConfig, MLPConfig
from gatr.interface import embed_point, extract_scalar
import torch

class ExampleWrapper(torch.nn.Module):
    """Example wrapper around a GATr model.

    Expects input data that consists of a point cloud: one 3D point for each item in the data.
    Returns outputs that consist of one scalar number for the whole dataset.

    Parameters
    ----------
    blocks : int
        Number of transformer blocks
    hidden_mv_channels : int
        Number of hidden multivector channels
    hidden_s_channels : int
        Number of hidden scalar channels
    """

    def __init__(self, blocks=20, hidden_mv_channels=16, hidden_s_channels=32):
        super().__init__()
        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=None,
            out_s_channels=None,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),  # Use default parameters for attention
            mlp=MLPConfig(),  # Use default parameters for MLP
        )

    def forward(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : torch.Tensor with shape (*batch_dimensions, num_points, 3)
            Point cloud input data

        Returns
        -------
        outputs : torch.Tensor with shape (*batch_dimensions, 1)
            Model prediction: a single scalar for the whole point cloud.
        """
        # Embed point cloud in PGA
        embedded_inputs = embed_point(inputs).unsqueeze(-2)  # (..., num_points, 1, 16)

        # Pass data through GATr
        embedded_outputs, _ = self.gatr(embedded_inputs, scalars=None)  # (..., num_points, 1, 16)

        # Extract scalar and aggregate outputs from point cloud
        nodewise_outputs = extract_scalar(embedded_outputs)  # (..., num_points, 1, 1)
        outputs = torch.mean(nodewise_outputs, dim=(-3, -2))  # (..., 1)

        return outputs
In the following, we will go into more detail on the conventions used in this code base and the structure of the repository.
Representations: GATr operates with two kinds of representations: geometric algebra multivectors and auxiliary scalar representations. Both are simply represented as torch.Tensor instances.
The multivectors are based on the projective geometric algebra G(3, 0, 1). They are tensors of the shape (..., 16), for instance (batchsize, items, channels, 16). The sixteen multivector components are sorted as in the clifford library, as follows: [x_scalars, x_0, x_1, x_2, x_3, x_01, x_02, x_03, x_12, x_13, x_23, x_012, x_013, x_023, x_123, x_0123].
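The component list above is ordered by grade first and by basis index second, so it can be regenerated programmatically. The sketch below uses only the standard library and does not touch the repository code; the names `blades`, `names`, and `grade_slices` are chosen here for illustration.

```python
from itertools import combinations

# Basis blades of G(3, 0, 1), ordered by grade and then lexicographically by index.
blades = [combo for grade in range(5) for combo in combinations(range(4), grade)]
names = ["x_scalars" if not b else "x_" + "".join(map(str, b)) for b in blades]

# Half-open index ranges occupied by each grade within the last tensor dimension.
grade_slices = {}
for index, blade in enumerate(blades):
    start, _ = grade_slices.get(len(blade), (index, index))
    grade_slices[len(blade)] = (start, index + 1)

print(names)
print(grade_slices)
```

With this layout, the grade-3 (trivector) components occupy indices 11 to 14 of the last dimension, and the pseudoscalar sits at index 15.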
Scalar representations have free shapes, but should match the multivector representations they accompany in batchsize and number of items. The number of channels may be different.
Interface to the real world: To map the multivector representations to physical objects, we use the plane-based convention presented in Roelfs and De Keninck, "Graded symmetry groups: Plane and simple". 3D points are thus represented as trivectors, planes as vectors, and so on. We provide these interface functions in the gatr.interface submodule.
Functions: We distinguish between primitives (functions) and layers (often stateful torch.nn.Module instances). Almost all primitives and layers are Pin(3, 0, 1)-equivariant, see docstrings for exceptions.
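Equivariance means that transforming the inputs and then applying a function gives the same result as applying the function and then transforming the output. The toy check below illustrates the idea for a single rotation about the z-axis and the centroid of a point cloud. It is a standalone illustration in plain Python, not the test helper from this repository, and all function names are hypothetical.

```python
import math


def rotate_z(point, angle):
    """Rotate a 3D point about the z-axis by `angle` radians."""
    x, y, z = point
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y, z)


def centroid(points):
    """Mean position of a point cloud; a rotation-equivariant function of the points."""
    n = len(points)
    return tuple(sum(p[i] for p in points) / n for i in range(3))


def is_equivariant(fn, points, angle, tol=1e-9):
    """Check fn(g . x) == g . fn(x) for one rotation g about the z-axis."""
    transformed_then_fn = fn([rotate_z(p, angle) for p in points])
    fn_then_transformed = rotate_z(fn(points), angle)
    return all(abs(a - b) < tol for a, b in zip(transformed_then_fn, fn_then_transformed))


cloud = [(1.0, 0.0, 0.0), (0.0, 2.0, 1.0), (-1.0, 1.0, 3.0)]
print(is_equivariant(centroid, cloud, angle=0.7))
```

A function that adds a fixed offset to the centroid would fail this check, because the offset does not rotate with the data.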
geometric-algebra-transformer
|
└───config: configuration YAML files for the experiments
| └───model: model configurations
| | arteries.yaml: default configuration for the artery experiment
| | hydra.yaml: hydra configuration (imported by the experiment configs)
| | nbody.yaml: default configuration for the n-body experiment
|
└───docker: Docker environment
| └───ext_packages: patches for external dependencies
| | Dockerfile: Docker container specification
| | requirements.txt: external dependencies installed in Dockerfile
|
└───gatr: core library
| └───baselines: baseline layers and architectures
| | | gcan.py: GCA-MLP and GCA-GNN baselines
| | | mlp.py: MLP baseline
| | | segnn.py: SEGNN baseline
| | | transformer.py: Transformer baseline
| |
| └───experiments: experiment managers
| | └───arteries: experiment manager, wrappers, etc for the artery experiment
| | └───nbody: experiment manager, wrappers, etc for the n-body experiment
| | | base_experiment.py: base class for experiment managers
| | | base_wrapper.py: base class for experiment-specific GATr wrappers
| |
| └───interface: embedding of geometric quantities into projective geometric algebra
| | | object.py: 3D objects consisting of positions and orientations
| | | plane.py: planes
| | | point.py: points
| | | pseudoscalar.py: pseudoscalars
| | | reflection.py: planar reflections
| | | rotation.py: rotations
| | | scalar.py: scalars
| | | translation.py: translations
| |
| └───layers: network layers
| | └───attention: self-attention layer, its components, and the corresponding configuration
| | └───mlp: geometric MLP, its components, and the corresponding configuration
| | | dropout.py: multivector dropout
| | | gatr_block.py: GATr transformer block, the main layer that GATr networks consist of
| | | layer_norm.py: geometric LayerNorm
| | | linear.py: equivariant linear layer between multivectors
| |
| └───nets: complete network architectures
| | | axial_gatr.py: axial-attention version of GATr for two token dimensions
| | | gatr.py: GATr architecture for a single token dimension
| |
| └───primitives: core functional building blocks of GATr
| | └───data: pre-computed basis stored in `.pt` files
| | | attention.py: geometric attention mechanism
| | | bilinear.py: bilinear equivariant functions like the geometric product
| | | dropout.py: geometric dropout
| | | dual.py: dual-related functions like the equivariant join
| | | invariants.py: invariant functions of multivectors like the norm
| | | linear.py: equivariant linear maps between multivectors
| | | nonlinearities.py: gated nonlinearities
| | | normalization.py: geometric normalization functions
| |
| └───utils: utility functions
| | clifford.py: non-differentiable GA functions based on the clifford library
| | dual_basis.py: helper functions to construct the equivariant join efficiently
| | einsum.py: optimized einsum function
| | logger.py: logger
| | misc.py: various utility functions
| | mlflow.py: logging with MLflow
| | plotting.py: plotting defaults
| | quaternions.py: quaternion operations
| | tensors.py: various tensor operations
|
└───scripts: entrypoints to run experiments
| | artery_experiment.py: artery experiment script
| | generate_nbody_dataset.py: data generation script for the n-body experiment
| | nbody_experiment.py: n-body experiment script
│
└───tests: unit tests (e.g. for self-consistency and Pin equivariance)
| └───gatr
| | └───baselines: unit tests for gatr.baselines
| | └───interface: unit tests for gatr.interface
| | └───layers: unit tests for gatr.layers
| | └───nets: unit tests for gatr.nets
| | └───primitives: unit tests for gatr.primitives
| | └───utils: unit tests for gatr.utils
| |
| └───helpers: utility functions for unit tests
| | constants.py: test settings (like tolerances)
| | equivariance.py: helper functions to test Pin equivariance
| | geometric_algebra.py: helper functions to test GA functionality
│
└───tests_regression: regression tests (training a GATr to predict the distance between points)
│
| .coveragerc: coverage configuration
| .dockerignore: Docker configuration
| .gitignore: git configuration
│ .pylintrc: pylint configuration
│ LICENSE: license under which this code may be used
│ pyproject.toml: project settings
│ README.md: this README file
│ setup.py: installation script
GATr v1.3 now supports compilation using torch.compile. However, be aware that compilation is experimental and may yield errors or warnings based on your hardware, model, data, and torch version, and we will not be able to offer support for compilation errors. If you choose to use compilation, start with torch.compile(model, dynamic=True, fullgraph=False). Compilation fails with our caching mechanism for torch.einsum (failure observed for torch==2.2.1), so the caching mechanism should be disabled by calling gatr.utils.einsum.enable_cached_einsum(False).
The GATr paper and this repository use the projective geometric algebra. In the work Euclidean, Projective, Conformal: Choosing a Geometric Algebra for Equivariant Transformers, we explored using the Euclidean and conformal algebras instead. We found a nuanced trade-off between the different algebras for their use in GATr. At the moment, the code for the different algebras has not been released.
If you find our code useful, please cite:
@inproceedings{brehmer2023geometric,
title = {Geometric Algebra Transformer},
author = {Brehmer, Johann and de Haan, Pim and Behrends, S{\"o}nke and Cohen, Taco},
booktitle = {Advances in Neural Information Processing Systems},
year = {2023},
volume = {36},
eprint = {2305.18415},
url = {https://arxiv.org/abs/2305.18415},
}
@inproceedings{dehaan2023euclidean,
title = {Euclidean, Projective, Conformal: Choosing a Geometric Algebra for Equivariant Transformers},
author = {Pim de Haan and Taco Cohen and Johann Brehmer},
booktitle = {Proceedings of the 27th International Conference on Artificial Intelligence and Statistics},
year = {2024},
volume = {27},
eprint = {2311.04744},
url = {https://arxiv.org/abs/2311.04744},
}