Neural network layers #50

Closed
rlouf opened this issue Oct 21, 2020 · 2 comments

rlouf commented Oct 21, 2020

DRAFT

I am moving the discussion initiated in #16 here. Please comment if you think there is an issue with the design and/or have ideas.

The idea is to subclass trax's constructs and allow the use of distributions for weights and transformations of weights. Ideally, we should be able to take any model that can be expressed with trax and make it Bayesian by adding prior distributions on the weights.

Layers are distributions over functions; let us see how that would translate on a naive MNIST example:

@mcx.model
def mnist(image):
    nn <~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat

For the previous example to work, we need to specify broadcasting rules for the layers' prior distributions:

[Figure: broadcasting rules for the layers' prior distributions]
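One way to read the broadcasting requirement: a prior such as Normal(0, 1) has scalar parameters, while a dense layer has a weight matrix. The following is a minimal NumPy sketch, assuming the prior's parameters are broadcast to the weight shape and each weight is drawn independently. The helper `sample_dense_weights` and the 784-pixel input size are illustrative assumptions, not part of mcx or trax.

```python
import numpy as np

def sample_dense_weights(rng, n_in, n_out, loc=0.0, scale=1.0):
    """Draw one weight matrix by broadcasting the prior's parameters.

    `loc` and `scale` may be scalars, or arrays broadcastable to
    (n_in, n_out), e.g. shape (n_out,) for one scale per output unit.
    """
    shape = (n_in, n_out)
    loc = np.broadcast_to(loc, shape)
    scale = np.broadcast_to(scale, shape)
    return rng.normal(loc, scale)

rng = np.random.default_rng(0)
# Scalar prior: the same Normal(0, 1) applies to every weight.
w_scalar = sample_dense_weights(rng, 784, 400)
# Per-unit prior: one scale per output unit, broadcast along the input axis.
w_per_unit = sample_dense_weights(rng, 784, 400, scale=np.linspace(0.1, 1.0, 400))
print(w_scalar.shape, w_per_unit.shape)
```

Both calls return a (784, 400) matrix; only the shape of the prior's parameters differs.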

With this API we can easily define hierarchical models:

@mcx.model
def mnist(image):
    sigma <~ Exponential(2)
    nn <~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat
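To make the hierarchical structure concrete, here is a rough NumPy sketch of one prior draw: sigma is sampled first, then every weight is sampled conditionally on it. It assumes Exponential(2) denotes a rate of 2 and a 784-pixel input; `sample_hierarchical_weights` is a hypothetical helper, not an mcx function.

```python
import numpy as np

def sample_hierarchical_weights(rng, layer_sizes, rate=2.0):
    """Sample sigma ~ Exponential(rate), then every weight ~ Normal(0, sigma)."""
    # NumPy parameterizes the exponential by its scale, i.e. 1 / rate.
    sigma = rng.exponential(scale=1.0 / rate)
    weights = [
        rng.normal(0.0, sigma, size=(n_in, n_out))
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]
    return sigma, weights

rng = np.random.default_rng(0)
sigma, weights = sample_hierarchical_weights(rng, [784, 400, 400, 10])
print(sigma, [w.shape for w in weights])
```

All three layers share the same sigma, which is what couples them in the hierarchical model.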

Forward sampling

Let’s look now at the design of the forward sampler. We need to return forward samples of the layer's weights as well as the other random variables.

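import jax

def sample_mnist(rng_key, image):
    # Use a separate key for each random draw; reusing a key correlates the samples
    key_nn, key_cat = jax.random.split(rng_key)
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    ).sample(key_nn)
    p = nn(image)
    cat = Categorical(p).sample(key_cat)
    return nn, cat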

where nn is a trax.layers.Serial object, which is consistent with the above assertion that Bayesian neural networks should be distributions over functions. It is possible to extract the layers' weights for further analysis by calling nn._weights. It may also be possible to JIT-compile nn.
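The "distribution over functions" reading can be sketched without mcx or trax: sampling returns a callable that closes over the sampled weights. The version below is a NumPy stand-in under stated assumptions (no biases or nonlinearities, as in the example above; a NumPy generator in place of JAX PRNG keys); `sample_network` and the returned `weights` list are hypothetical names.

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sample_network(rng, layer_sizes):
    """Draw weights from Normal(0, 1) and return (function, weights)."""
    weights = [
        rng.normal(0.0, 1.0, size=(n_in, n_out))
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]

    def nn(x):
        for w in weights:
            x = x @ w  # dense layers without bias or nonlinearity
        return softmax(x)

    return nn, weights

def sample_mnist(rng, image):
    nn, weights = sample_network(rng, [image.shape[-1], 400, 400, 10])
    p = nn(image)
    cat = rng.choice(len(p), p=p)  # one draw from Categorical(p)
    return nn, weights, cat

rng = np.random.default_rng(0)
image = rng.random(784)  # stand-in for a flattened 28x28 image
nn, weights, cat = sample_mnist(rng, image)
print(cat, nn(image).sum())
```

Returning the weights alongside the function avoids reaching into private attributes for later analysis.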

Log-likelihood

The API is not 100% there yet:

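def logpdf_mnist(nn_sample, cat, image):
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    loglikelihood = 0
    loglikelihood += nn.logpdf(nn_sample)
    p = nn_sample(image)
    loglikelihood += Categorical(p).logpdf(cat)
    return loglikelihood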
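For comparison, the same quantity can be written out by hand. This is a sketch assuming independent Normal(0, 1) priors on every weight and layers without biases; it takes the sampled weights directly rather than a network object, and all helper names are illustrative.

```python
import numpy as np

def normal_logpdf(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

def log_softmax(x):
    z = x - x.max()
    return z - np.log(np.exp(z).sum())

def logpdf_mnist(weights, cat, image):
    # Log-prior: independent Normal(0, 1) on every weight.
    logp = sum(normal_logpdf(w).sum() for w in weights)
    # Log-likelihood: Categorical(softmax(logits)) evaluated at the observed class.
    x = image
    for w in weights:
        x = x @ w
    logp += log_softmax(x)[cat]
    return logp

# Tiny check: one 3x2 layer of zeros gives uniform class probabilities.
toy_weights = [np.zeros((3, 2))]
toy_logp = logpdf_mnist(toy_weights, cat=0, image=np.ones(3))
expected = -3.0 * np.log(2.0 * np.pi) + np.log(0.5)
print(np.isclose(toy_logp, expected))
```

The first term plays the role of nn.logpdf(nn_sample) above, and the second that of Categorical(p).logpdf(cat).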

rlouf commented Oct 21, 2020

As pointed out by Torsten Scholak (https://twitter.com/tscholak/status/1318897344549736450?s=20), you may want to train a regular neural network (NN) while lifting the inputs & outputs to random variables whose posterior can be sampled. Edward did a version of that: https://github.com/blei-lab/edward/blob/master/examples/deep_exponential_family.py#L181-L184

My intuition is that this could be handled at compile time, when evaluators are free to modify the graph as they see fit; HMC, for example, applies constrained-to-unconstrained transformations. We could add an evaluator that performs gradient descent on the neural network's weights; this evaluator would change the status of the NN from a Bayesian NN to an NN with trainable weights.
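As a toy illustration of what such an evaluator would compute, the sketch below treats the weights as trainable parameters and descends the gradient of the negative log joint instead of sampling. A single linear layer with a Gaussian likelihood stands in for the neural network so that the result can be checked against a closed form; none of this is mcx's API.

```python
import numpy as np

def neg_log_joint_grad(w, X, y):
    """Gradient of -log p(y, w) for y ~ Normal(X @ w, 1), w ~ Normal(0, 1)."""
    return -X.T @ (y - X @ w) + w

def fit_point_estimate(X, y, n_steps=500):
    w = np.zeros(X.shape[1])
    # Step size 1 / L, with L the largest eigenvalue of the (constant) Hessian.
    hessian = X.T @ X + np.eye(X.shape[1])
    lr = 1.0 / np.linalg.norm(hessian, 2)
    for _ in range(n_steps):
        w = w - lr * neg_log_joint_grad(w, X, y)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=50)
w_hat = fit_point_estimate(X, y)
# For this linear-Gaussian stand-in the optimum is available in closed form.
w_closed_form = np.linalg.solve(X.T @ X + np.eye(3), X.T @ y)
print(np.allclose(w_hat, w_closed_form))
```

The model definition is unchanged; only the treatment of the weights differs, which is why handling this at the evaluator level seems plausible.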


rlouf commented Mar 8, 2021

Moved to #96

rlouf closed this as completed Mar 8, 2021