Skip to content

Commit

Permalink
notebook cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallowa07 committed May 6, 2024
1 parent 48f4f4d commit feb386d
Show file tree
Hide file tree
Showing 5 changed files with 990 additions and 952 deletions.
28 changes: 8 additions & 20 deletions multidms/biophysical.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ def _gamma_corrected_cost_smooth(
"""
X, y = data
loss = 0
n = 0

# Sum the huber loss across all conditions
# shift_ridge_penalty = 0
Expand All @@ -424,27 +423,16 @@ def _gamma_corrected_cost_smooth(

# compute the Huber loss between observed and predicted
# functional scores
huber_loss_d = huber_loss(
loss += huber_loss(
y[condition] + d_params["gamma_d"], y_d_predicted, huber_scale
).sum()
).mean()

# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
penalty_ridge_shift = scale_coeff_ridge_shift * (d_params["s_md"] ** 2).sum()
penalty_ridge_alpha_d = (
scale_coeff_ridge_alpha_d * (d_params["alpha_d"] ** 2).sum()
)
penalty_ridge_gamma = scale_coeff_ridge_gamma * (d_params["gamma_d"] ** 2).sum()
n_d = y[condition].shape[0]
n += n_d

loss += (
huber_loss_d
+ penalty_ridge_shift
+ penalty_ridge_alpha_d
+ penalty_ridge_gamma
) / n_d

loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2) / n
# loss += scale_coeff_ridge_shift * (d_params["s_md"] ** 2).sum()
# loss += scale_coeff_ridge_alpha_d * (d_params["alpha_d"] ** 2).sum()
# loss += scale_coeff_ridge_gamma * (d_params["gamma_d"] ** 2).sum()
loss /= len(X)
loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2)

return loss / len(X)
return loss
3 changes: 2 additions & 1 deletion multidms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def __init__(
latent_model = multidms.biophysical.additive_model
if latent_model == multidms.biophysical.additive_model:
n_beta_shift = len(self._data.mutations)
self._params["beta"] = jax.random.normal(shape=(n_beta_shift,), key=key)
# self._params["beta"] = jax.random.normal(shape=(n_beta_shift,), key=key)
self._params["beta"] = jnp.zeros(shape=(n_beta_shift,))
for condition in data.conditions:
self._params[f"shift_{condition}"] = jnp.zeros(shape=(n_beta_shift,))
self._params[f"alpha_{condition}"] = jnp.zeros(shape=(1,))
Expand Down
Loading

0 comments on commit feb386d

Please sign in to comment.