AllanYangZhou/universal_neural_functional

Universal Neural Functionals

This is the code for constructing universal neural functionals (UNFs), from the paper Universal Neural Functionals. UNFs are architectures that process the weights of other neural networks while maintaining equivariance or invariance to the weight-space permutation symmetries. In contrast to NFNs, UNFs can ingest weights from any architecture.

Equivalently, we can think of UNFs as equivariant architectures for processing any collection of tensors, where the action involves a shared set of permutations permuting the axes of the tensors in a given way.

The codebase requires JAX for core functionality and Flax for the example (though other JAX NN libraries are likely compatible as well). See example.py for usage.

High level usage

The perm_spec tells the library which permutation symmetries it should be equivariant to. For example, suppose you have a collection of weight tensors corresponding to a simple MLP:

params = {
    "params": {
        "Dense_0": {
            "kernel": Array[784, 512],
            "bias": Array[512]
        },
        "Dense_1": {
            "kernel": Array[512, 10],
            "bias": Array[10]
        }
    }
}

We can describe the permutation symmetry of this network as follows (assume the input and output neurons are also permutable).

  • The weight tensors can be permuted by $\sigma=(\sigma_0, \sigma_1, \sigma_2) \in S_{784} \times S_{512} \times S_{10}$.
  • $\sigma_0$ permutes the first dimension of params["params"]["Dense_0"]["kernel"].
  • $\sigma_1$ permutes the second dimension of params["params"]["Dense_0"]["kernel"], the vector params["params"]["Dense_0"]["bias"], and the first dimension of params["params"]["Dense_1"]["kernel"].
  • $\sigma_2$ permutes the second dimension of params["params"]["Dense_1"]["kernel"] and the vector params["params"]["Dense_1"]["bias"].
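The action described in the list above can be checked numerically. The following is a minimal sketch, not part of this library: it uses NumPy instead of JAX, small stand-in layer sizes instead of 784/512/10, and a ReLU activation (an assumption, since the README does not name the activation). The function `mlp` and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small stand-in sizes; the MLP in the README is 784 -> 512 -> 10.
d_in, d_hidden, d_out = 6, 5, 3
w0 = rng.normal(size=(d_in, d_hidden))   # plays the role of Dense_0 kernel
b0 = rng.normal(size=(d_hidden,))        # plays the role of Dense_0 bias
w1 = rng.normal(size=(d_hidden, d_out))  # plays the role of Dense_1 kernel
b1 = rng.normal(size=(d_out,))           # plays the role of Dense_1 bias


def mlp(x, w0, b0, w1, b1):
    """Two-layer MLP with a ReLU hidden layer (activation is an assumption)."""
    h = np.maximum(x @ w0 + b0, 0.0)
    return h @ w1 + b1


# One permutation per neuron group: inputs, hidden units, outputs.
sigma_0 = rng.permutation(d_in)
sigma_1 = rng.permutation(d_hidden)
sigma_2 = rng.permutation(d_out)

# Apply each permutation to every axis it acts on, as in the list above.
w0_p = w0[sigma_0][:, sigma_1]
b0_p = b0[sigma_1]
w1_p = w1[sigma_1][:, sigma_2]
b1_p = b1[sigma_2]

x = rng.normal(size=(4, d_in))
y = mlp(x, w0, b0, w1, b1)
y_p = mlp(x[:, sigma_0], w0_p, b0_p, w1_p, b1_p)

# The permuted network on permuted inputs gives the original outputs,
# reordered by sigma_2.
same_up_to_relabeling = np.allclose(y_p, y[:, sigma_2])

# Permuting only the hidden units leaves the function unchanged.
y_hidden_only = mlp(x, w0[:, sigma_1], b0[sigma_1], w1[sigma_1], b1)
hidden_perm_is_invisible = np.allclose(y_hidden_only, y)
```

Both flags should evaluate to True: the permuted weights describe the same network up to a relabeling of its neurons.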

Then we number each permutation by integers: $(\sigma_0, \sigma_1, \sigma_2) \mapsto (0, 1, 2)$ and define the permutation specification:

perm_spec = {
    "params": {
        "Dense_0": {
            "kernel": (0, 1),
            "bias": (1,)
        },
        "Dense_1": {
            "kernel": (1, 2),
            "bias": (2,)
        }
    }
}
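To make the meaning of the specification concrete, here is a hedged sketch of how a perm_spec could be interpreted. The helpers `perm_sizes` and `apply_perms` are hypothetical and are not functions of this library; the sketch uses NumPy arrays in place of JAX arrays and assumes that params and perm_spec are nested dicts with the same keys.

```python
import numpy as np


def perm_sizes(params, perm_spec, sizes=None):
    """Hypothetical helper: map each permutation index to the axis size it
    acts on, raising if two axes sharing an index have different sizes."""
    sizes = {} if sizes is None else sizes
    if isinstance(params, dict):
        for key, value in params.items():
            perm_sizes(value, perm_spec[key], sizes)
        return sizes
    if len(perm_spec) != params.ndim:
        raise ValueError("spec must name one permutation per array axis")
    for dim, perm_id in zip(params.shape, perm_spec):
        if sizes.setdefault(perm_id, dim) != dim:
            raise ValueError(f"permutation {perm_id} acts on axes of unequal size")
    return sizes


def apply_perms(params, perm_spec, perms):
    """Hypothetical helper: permute each array axis by the permutation
    that the spec assigns to it."""
    if isinstance(params, dict):
        return {k: apply_perms(v, perm_spec[k], perms) for k, v in params.items()}
    out = params
    for axis, perm_id in enumerate(perm_spec):
        out = np.take(out, perms[perm_id], axis=axis)
    return out


rng = np.random.default_rng(0)
params = {
    "params": {
        "Dense_0": {"kernel": rng.normal(size=(784, 512)), "bias": rng.normal(size=(512,))},
        "Dense_1": {"kernel": rng.normal(size=(512, 10)), "bias": rng.normal(size=(10,))},
    }
}
perm_spec = {
    "params": {
        "Dense_0": {"kernel": (0, 1), "bias": (1,)},
        "Dense_1": {"kernel": (1, 2), "bias": (2,)},
    }
}

sizes = perm_sizes(params, perm_spec)  # one size per permutation index
perms = {perm_id: rng.permutation(n) for perm_id, n in sizes.items()}
permuted = apply_perms(params, perm_spec, perms)
```

The consistency check is the useful part of the sketch: an index that appears on two axes of different sizes signals a mistake in the specification.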

Notice that nothing requires the input to be a collection of weight tensors. This library processes any collection of tensors if you give it a description of the permutation symmetries.
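As an illustration of a collection that is not a set of network weights, consider a hypothetical ratings table with per-user and per-item offsets, where users and items can each be reordered. The tensor names, the spec, and the hand-written map `combine` below are all assumptions made for this sketch; `combine` is a toy equivariant function in NumPy, not a layer from this library.

```python
import numpy as np

rng = np.random.default_rng(0)
n_users, n_items = 4, 6

# Hypothetical collection of tensors sharing two permutable index sets.
tensors = {
    "ratings": rng.normal(size=(n_users, n_items)),
    "user_bias": rng.normal(size=(n_users,)),
    "item_bias": rng.normal(size=(n_items,)),
}

# Permutation 0 reorders users, permutation 1 reorders items.
perm_spec = {
    "ratings": (0, 1),
    "user_bias": (0,),
    "item_bias": (1,),
}


def combine(t):
    """Toy equivariant map: add each offset along its own axis."""
    return t["ratings"] + t["user_bias"][:, None] + t["item_bias"][None, :]


sigma_users = rng.permutation(n_users)
sigma_items = rng.permutation(n_items)
permuted = {
    "ratings": tensors["ratings"][sigma_users][:, sigma_items],
    "user_bias": tensors["user_bias"][sigma_users],
    "item_bias": tensors["item_bias"][sigma_items],
}

# Equivariance: permuting the inputs permutes the output in the same way.
is_equivariant = np.allclose(
    combine(permuted),
    combine(tensors)[sigma_users][:, sigma_items],
)
```

The final check states the property that a perm_spec asks for: applying the permutations before or after the map gives the same result.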
