Thanks for the great work! This implementation made working with embeddings much easier.

I have a question about implementing classification/regression on top of evotuned weights. I am evotuning my model, obtaining weights, and then want to fit a regression/classification model on my lab results. In effect I am tuning twice: once on similar families, and once on about 80% of my own data. There is a guide in the tutorials about top-model classification, but I can only use its apply_fun with my tuned weights, not its init_fun. Do I need to define two init functions, one for the evotuned part and one for the newly added layers? If so, is there a function in JAX to combine them? I really appreciate your insight. I'm also pasting the top-model example below for clarity.
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden

init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    # Add two layers: one dense layer that produces 512-dim activations
    Dense(512), Relu(),
    # and then a linear layer to produce a 1-dim activation
    Dense(1),
)
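For context, my current guess is that a single init_fun from serial might be enough: initialize the whole stack once, then overwrite the leading layers' params with my evotuned ones, since stax returns params as a per-layer sequence. Here is a toy sketch of that splice (plain Python with placeholder strings standing in for the real weight pytrees; I'm not sure this is the intended approach):

```python
# Placeholder params for the 3 evotuned layers
# (AAEmbedding, mLSTM, mLSTMAvgHidden — the avg-hidden layer has no weights):
evotuned_params = ("emb_evotuned", "lstm_evotuned", ())

# Placeholder output of the full model's init_fun (6 layers total):
fresh_params = ("emb0", "lstm0", (), "dense512", (), "dense1")

# Splice: keep the evotuned bottom, keep the freshly initialized top layers.
combined = list(fresh_params)
combined[: len(evotuned_params)] = evotuned_params
combined = tuple(combined)
# combined == ("emb_evotuned", "lstm_evotuned", (), "dense512", (), "dense1")
```

Is splicing params like this the intended way, or is there a cleaner helper for it?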