Clear Map for on top classification with evotuned weights #113

Open
mkd1996 opened this issue Apr 25, 2022 · 0 comments


mkd1996 commented Apr 25, 2022

Hi,

Thanks for the great work! This implementation made working with embeddings so much easier.

I have a question about running classification/regression on top of evotuned weights. I evotune my model, obtain the weights, and then want to fit a regression/classification model to my lab results. In effect I am tuning twice: once on sequences from similar families (evotuning), and once on about 80% of my labelled data. The tutorials include a guide to on-top classification, but with my tuned weights I can only reuse the `apply_fun` part of it, not the `init_fun` part. Do I need to define two init functions, one for the evotuned part and one for the added final layers? If so, is there a function in JAX that combines them? I really appreciate your insight. I'm also pasting the on-top model example for clarity.

```python
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial

from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden

init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    # Add two layers, one dense layer that results in 512-dim activations
    # (stax.Relu is already an (init_fun, apply_fun) pair, so it is not called)
    Dense(512), Relu,
    # And then a linear layer to produce a 1-dim activation
    Dense(1),
)
```
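One way the two-init-function question could be approached is sketched below, assuming the evotuned weights have exactly the same nested structure that the backbone's own `init_fun` returns. `stax.serial` accepts any `(init_fun, apply_fun)` pair as a layer, so the backbone and the new head can be nested into a single model, initialized with one call, and the backbone's slot in the resulting params list overwritten with the evotuned weights. To keep the sketch runnable without the UniRep weights, small `Dense` layers stand in for `AAEmbedding`/`mLSTM`/`mLSTMAvgHidden`. The names `backbone_init`, `head_init`, `evotuned_params` and `mse_loss` are hypothetical. It imports from `jax.example_libraries.stax`, which is where newer JAX releases moved `stax` from `jax.experimental`.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries.stax import Dense, Relu, serial

# Hypothetical stand-in for the evotuned backbone; with jax_unirep this would be
# serial(AAEmbedding(10), mLSTM(1900), mLSTMAvgHidden()).
backbone_init, backbone_apply = serial(Dense(16), Relu)

# New regression head placed on top of the backbone.
head_init, head_apply = serial(Dense(8), Relu, Dense(1))

# serial() takes (init_fun, apply_fun) pairs, so the two sub-models nest directly.
# The combined params are a list: [backbone_params, head_params].
model_init, model_apply = serial(
    (backbone_init, backbone_apply),
    (head_init, head_apply),
)

rng = jax.random.PRNGKey(0)
k_backbone, k_model = jax.random.split(rng)

# Placeholder for the weights returned by evotuning (assumption: same structure
# as what backbone_init produces).
_, evotuned_params = backbone_init(k_backbone, (-1, 4))

# Initialize the whole model once, then swap in the evotuned backbone weights
# while keeping the freshly initialized head.
out_shape, params = model_init(k_model, (-1, 4))
params = [evotuned_params, params[1]]

x = jnp.ones((3, 4))
y = model_apply(params, x)
print(out_shape, y.shape)


# Supervised stage: gradients flow through both backbone and head params.
def mse_loss(params, x, y_true):
    y_pred = model_apply(params, x)
    return jnp.mean((y_pred - y_true) ** 2)


grads = jax.grad(mse_loss)(params, x, jnp.zeros((3, 1)))
# One plain SGD step over the whole params pytree.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

Swapping the backbone's slot after a single `model_init` call avoids maintaining two separate init functions by hand. If only the head should be trained, the backbone's gradients could be dropped (or the backbone output wrapped in `jax.lax.stop_gradient`) before the update.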
