DRAFT
I am moving the discussion initiated in #16 here. Please comment if you think there is an issue with the design and/or have ideas.
The idea is to subclass trax’s constructs and allow the use of distributions for weights and transformations of weights. Ideally, we should be able to take any model that can be expressed with `trax` and make it Bayesian by adding prior distributions on the weights.
Layers are distributions over functions; let us see how that would translate in a naive MNIST example:
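Something along these lines, where the `ml.Serial` / `ml.Dense` wrappers, the way a prior is attached to each layer, and the `<~` notation are placeholders for illustration rather than a settled API:

```python
import mcx
from mcx.distributions import Categorical, Normal

import mcx.ml as ml  # hypothetical module of mcx wrappers around trax layers


@mcx.model
def mnist(image):
    # A Normal(0, 1) prior on every weight makes `nn` a distribution
    # over functions rather than a single deterministic network.
    nn <~ ml.Serial(
        ml.Dense(400, Normal(0, 1)),
        ml.Dense(400, Normal(0, 1)),
        ml.Dense(10, Normal(0, 1)),
        ml.Softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
```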
For the previous example to work, we need to specify broadcasting rules for the layers' prior distributions:
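To make the intended rule concrete, here is, in plain JAX, what a scalar `Normal(0, 1)` prior on the first 784 → 400 dense layer would amount to once broadcast over the weight tensor (an assumption about the semantics, not the final implementation):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Assumed broadcasting rule: a scalar prior applies elementwise to the whole
# weight tensor, so a forward sample has the layer's shape and the log-prior
# sums over every entry.
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, shape=(784, 400))              # draw from the broadcast Normal(0, 1)
log_prior = jnp.sum(norm.logpdf(w, loc=0.0, scale=1.0))   # summed over all weights
```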
With this API we can easily define hierarchical models:
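For instance, the scale of the weight priors could itself be given a prior and shared across layers (same placeholder names as above):

```python
from mcx.distributions import Exponential


@mcx.model
def hierarchical_mnist(image):
    # The prior scale is a random variable shared by every layer, which
    # ties the layers' weights together hierarchically.
    sigma <~ Exponential(1.0)
    nn <~ ml.Serial(
        ml.Dense(400, Normal(0.0, sigma)),
        ml.Dense(400, Normal(0.0, sigma)),
        ml.Dense(10, Normal(0.0, sigma)),
        ml.Softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
```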
Forward sampling
Let us now look at the design of the forward sampler. We need to return forward samples of the layers' weights as well as the other random variables.
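Sketching the call site (the `mcx.sample_forward` entry point, its signature, and the shape of what it returns are assumptions for the purpose of illustration):

```python
import jax
import jax.numpy as jnp

rng_key = jax.random.PRNGKey(0)
images = jnp.zeros((1, 28 * 28))  # placeholder batch of flattened MNIST images

# Assumed return structure: a dictionary mapping each random variable's name
# to its forward (prior) samples, including the sampled network itself.
prior_samples = mcx.sample_forward(rng_key, mnist, image=images, num_samples=100)
nn = prior_samples["nn"]
```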
where `nn` is a `trax.layers.Serial` object, which is consistent with the above assertion that Bayesian neural networks should be distributions over functions. It is possible to extract the layers' weights for further analysis by calling `nn._weights`. It may also be possible to JIT-compile `nn`.
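Concretely, something like the following, still using the names above; whether `jax.jit` can be applied directly to the sampled network is an open question:

```python
# Per-layer weights of the sampled network, e.g. to compare them with the
# prior they were drawn from.
weights = nn._weights

# trax layers are callable, so the sampled network could in principle be
# wrapped in jax.jit like any other function of its inputs.
predict = jax.jit(lambda batch: nn(batch))
```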
My intuition is that this could be handled at compile time, when evaluators are free to modify the graph to their liking, e.g. with HMC, which applies constrained-unconstrained transformations. We could add an evaluator that tries to descend the gradient of the neural network's weights; this evaluator would change the status of the NN from a Bayesian NN to a NN with trainable weights.

Log-likelihood
The API is not 100% there yet: