This is a library of PyTorch layers for building permutation equivariant neural functional networks (NFNs). NFNs are equivariant deep learning architectures for processing weight space features, such as the weights or gradients of another neural network. We refer to the layers of an NFN as NF-Layers.
Simple installation from PyPI:
pip install nfn
If you want to edit source or run examples, clone the repository locally. Then run the following commands:
git clone https://github.com/AllanYangZhou/nfn
cd nfn
pip install -e . # installs in editable mode
We provide a very simple example of using this library to process the weights of a small CNN using an NFN at examples/basic_cnn. The example only shows how to construct NFNs and feed data in, but does not train anything. To run:
cd examples/basic_cnn
python check_nfn_inv.py
This will check the permutation invariance of the NFN.
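For intuition about the symmetry being checked: reordering the hidden units of an MLP (the rows of one weight matrix and its bias, together with the matching columns of the next weight matrix) leaves the network's function unchanged, so an invariant NFN should give the same output for both weight settings. The snippet below is a minimal standalone sketch of that symmetry in plain PyTorch. It does not use this library and is not the contents of check_nfn_inv.py; the layer sizes are arbitrary assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A small MLP with one hidden layer of 8 units (sizes are arbitrary).
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# Build a second MLP whose hidden units are a random reordering of the first's.
perm = torch.randperm(8)
permuted = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
with torch.no_grad():
    permuted[0].weight.copy_(mlp[0].weight[perm])     # permute rows of layer 1
    permuted[0].bias.copy_(mlp[0].bias[perm])         # permute bias of layer 1
    permuted[2].weight.copy_(mlp[2].weight[:, perm])  # permute columns of layer 2
    permuted[2].bias.copy_(mlp[2].bias)               # output bias is unchanged

# The two weight settings differ, yet they compute the same function.
x = torch.randn(16, 4)
with torch.no_grad():
    same_function = torch.allclose(mlp(x), permuted(x), atol=1e-5)
print(same_function)
```

An NFN with a final invariant pooling layer is meant to map `mlp` and `permuted` to the same prediction, which is the property the example script tests.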
NF-Layers operate on WeightSpaceFeatures. The current NF-Layers are compatible with the weight spaces of simple feedforward MLPs and 2D (image) CNNs. For weight spaces of CNN classifiers we assume there is some global pooling layer (e.g., nn.AdaptiveAvgPool2d(1)) between the convolution and FC layers. Supporting 1D or 3D CNNs should be possible but is not currently implemented.
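As an illustration of the CNN layout described above, here is a sketch of a classifier with a global pooling layer between the convolution and FC layers. The channel counts, kernel sizes, and input size are hypothetical choices for the example, not requirements of the library.

```python
import torch
from torch import nn

# Hypothetical small CNN classifier: 1-channel images, 10 classes.
cnn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # global pooling: (B, 32, H, W) -> (B, 32, 1, 1)
    nn.Flatten(),             # (B, 32, 1, 1) -> (B, 32)
    nn.Linear(32, 10),        # FC input size equals the last conv's channel count
)

logits = cnn(torch.randn(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 10])
```

With global pooling, each output channel of the last convolution feeds exactly one input feature of the first FC layer, so a permutation of conv channels corresponds to the same permutation of FC input columns.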
To construct WeightSpaceFeatures from the weights of a PyTorch model, we provide the helper function state_dict_to_tensors:
from torch.utils.data.dataloader import default_collate
from nfn.common import state_dict_to_tensors, WeightSpaceFeatures
models = [...] # batch of PyTorch models
state_dicts = [m.state_dict() for m in models]
wts_and_bs = [state_dict_to_tensors(sd) for sd in state_dicts]
# Collate batch. Can be done automatically by DataLoader.
wts_and_bs = default_collate(wts_and_bs)
wsfeat = WeightSpaceFeatures(*wts_and_bs)
out = nfn(wsfeat) # nfn is an NFN built from NF-Layers (constructed below); it ingests WeightSpaceFeatures directly
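To make the collation step concrete, the sketch below shows what default_collate does to a list of per-model (weights, biases) pairs: it stacks corresponding tensors along a new leading batch dimension. The function random_weights_and_biases is a hypothetical stand-in for the per-model tensors, and its shapes are an assumption; the exact shapes returned by state_dict_to_tensors may differ.

```python
import torch
from torch.utils.data.dataloader import default_collate


def random_weights_and_biases():
    # Hypothetical stand-in for one model's (weights, biases): an MLP 4 -> 8 -> 3.
    weights = (torch.randn(8, 4), torch.randn(3, 8))
    biases = (torch.randn(8), torch.randn(3))
    return weights, biases


# One (weights, biases) pair per model, for a batch of 5 models.
samples = [random_weights_and_biases() for _ in range(5)]

# default_collate recurses into the nested tuples and stacks matching tensors.
batched_weights, batched_biases = default_collate(samples)

print([tuple(w.shape) for w in batched_weights])  # [(5, 8, 4), (5, 3, 8)]
print([tuple(b.shape) for b in batched_biases])   # [(5, 8), (5, 3)]
```

A DataLoader applies the same function by default, so a Dataset that returns one such pair per item yields batches in this form.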
For now, state_dict_to_tensors assumes that the state_dict is an ordered dictionary with keys in order [weight1, bias1, ..., weightL, biasL]. This is the default behavior if the state_dict is coming from a feedforward network that is an nn.Sequential model.
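The key ordering can be inspected directly. The sketch below prints the state_dict keys of a small nn.Sequential MLP and includes a helper, alternates_weight_bias, that checks for the alternating weight/bias pattern. The helper is hypothetical and written for this illustration; it is not part of the library.

```python
from torch import nn

mlp = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
keys = list(mlp.state_dict().keys())
print(keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']


def alternates_weight_bias(state_dict):
    """Return True if keys come in consecutive (<layer>.weight, <layer>.bias) pairs."""
    names = list(state_dict.keys())
    if len(names) % 2 != 0:
        return False
    for w_key, b_key in zip(names[0::2], names[1::2]):
        if not (w_key.endswith(".weight") and b_key.endswith(".bias")):
            return False
        # Both keys of a pair must belong to the same layer.
        if w_key.rsplit(".", 1)[0] != b_key.rsplit(".", 1)[0]:
            return False
    return True


print(alternates_weight_bias(mlp.state_dict()))
```

Models with extra entries in the state_dict (for example, normalization layers with running statistics) do not follow this pattern, so a check like this can catch incompatible inputs early.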
The NF-Layers are found in nfn.layers. The main data you need to build an NFN is a network_spec, which specifies the structure of the weight space you plan to process. If you already have a WeightSpaceFeatures object as above, you can use nfn.common.network_spec_from_wsfeat.
from torch import nn
from nfn import layers
from nfn.common import network_spec_from_wsfeat
network_spec = network_spec_from_wsfeat(wsfeat)
nfn_channels = 32
# io_embed: encode the input and output dimensions of the weight space feature
nfn = nn.Sequential(
    layers.NPLinear(network_spec, 1, nfn_channels, io_embed=True),
    layers.TupleOp(nn.ReLU()),
    layers.NPLinear(network_spec, nfn_channels, nfn_channels, io_embed=True),
    layers.TupleOp(nn.ReLU()),
    layers.HNPPool(network_spec), # pooling layer, for invariance
    nn.Flatten(start_dim=-2),
    nn.Linear(nfn_channels * layers.HNPPool.get_num_outs(network_spec), 1)
)