Commit a4bf8fa
fit passing tests
jgallowa07 committed May 10, 2024
1 parent feb386d commit a4bf8fa
Showing 9 changed files with 2,012 additions and 1,302 deletions.
152 changes: 83 additions & 69 deletions multidms/biophysical.py
@@ -35,7 +35,9 @@

import jax.numpy as jnp
from jaxopt.loss import huber_loss
from jaxopt.prox import prox_lasso

from multidms.utils import transform # TODO namespace for utils?
import pyproximal
import jax

# jax.config.update("jax_enable_x64", True)
@@ -76,11 +78,7 @@ def additive_model(d_params: dict, X_d: jnp.array):
jnp.array
Predicted latent phenotypes for each row in ``X_d``
"""
return (
d_params["beta_naught"]
+ d_params["alpha_d"]
+ (X_d @ (d_params["beta_m"] + d_params["s_md"]))
)
return d_params["beta0"] + X_d @ d_params["beta"]


r"""
@@ -281,7 +279,8 @@ def softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwar
hinge_scale
* (jnp.logaddexp(0, (act - (lower_bound + d_params["gamma_d"])) / hinge_scale))
+ lower_bound
+ d_params["gamma_d"]
# TODO GAMMA
# + d_params["gamma_d"]
)


@@ -320,56 +319,76 @@ def _abstract_epistasis(
return t(d_params, g(d_params["theta"], additive_model(d_params, X_h)), **kwargs)


def _lasso_lock_prox(
params,
hyperparams_prox=dict(
lasso_params=None, lock_params=None, upper_bound_theta_ge_scale=None
),
scaling=1.0,
):
"""
Apply lasso and lock constraints to parameters
Parameters
----------
params : dict
Dictionary of parameters to constrain
hyperparams_prox : dict
Dictionary of hyperparameters for proximal operators
scaling : float
Scaling factor for lasso penalty
"""
# enforce monotonic epistasis and constrain ge_scale upper limit
if "ge_scale" in params["theta"]:
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, hyperparams_prox["upper_bound_theta_ge_scale"]
)

if "p_weights_1" in params["theta"]:
params["theta"]["p_weights_1"] = params["theta"]["p_weights_1"].clip(0)
params["theta"]["p_weights_2"] = params["theta"]["p_weights_2"].clip(0)

if hyperparams_prox["lasso_params"] is not None:
for key, value in hyperparams_prox["lasso_params"].items():
params[key] = prox_lasso(params[key], value, scaling)
def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
"""ADMM generalized lasso optimization."""
(
scale_coeff_lasso_shift,
admm_niter,
admm_tau,
admm_mu,
ge_scale_upper_bound,
lock_params,
bundle_idxs,
# Dop,
) = hyperparameters
# apply prox
beta_ravel = jnp.vstack(params["beta"].values()).ravel(order="F")

# see https://pyproximal.readthedocs.io/en/stable/index.html
beta_ravel, shift_ravel = pyproximal.optimization.primal.LinearizedADMM(
pyproximal.L2(b=beta_ravel),
pyproximal.L1(sigma=scaling * scale_coeff_lasso_shift),
Dop,
niter=admm_niter,
tau=admm_tau,
mu=admm_mu,
x0=beta_ravel,
show=False,
)

beta = beta_ravel.reshape(-1, len(beta_ravel) // len(params["beta"]), order="F")
shift = shift_ravel.reshape(-1, len(shift_ravel) // len(params["beta"]), order="F")

# update beta dict
for i, d in enumerate(params["beta"]):
params["beta"][d] = beta[i]

# update shifts
for i, d in enumerate(params["shift"]):
params["shift"][d] = shift[i]

# clamp beta0 for reference condition in non-scaled parameterization
# (where it's a box constraint)
params = transform(params, bundle_idxs)

# should the following two conditions be within the transform?
# I'm pretty sure it doesn't matter since the post-latent
# stuff doesn't interfere with the betas' transformation.
#
# Though I do wonder if the betas should be transformed before
# being passed to the predictive function? HMM?
# clamp the theta GE scale to be non-negative (monotonic epistasis), with an optional upper bound
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
# Any params to constrain during fit
if hyperparams_prox["lock_params"] is not None:
for key, value in hyperparams_prox["lock_params"].items():
params[key] = value
if lock_params is not None:
for (param, subparam), value in lock_params.items():
params[param][subparam] = value

# params["beta0"][params["beta0"].keys()] = 0.0
params = transform(params, bundle_idxs)

return params
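
The LinearizedADMM call above solves a generalized-lasso subproblem of the form min_x 0.5 * ||x - b||_2^2 + sigma * ||D x||_1, with D the condition/mutation difference operator Dop. Below is a minimal toy sketch of the same pyproximal call (not part of this commit; the data, the first-difference operator, and the ADMM settings are hypothetical stand-ins):

import numpy as np
import pylops
import pyproximal

# toy "betas" with one jump; the L1 penalty on the differences fuses them
b = np.array([0.0, 0.1, 2.0, 2.1, 2.2])
Dop_toy = pylops.FirstDerivative(len(b))  # hypothetical stand-in for Dop
x, z = pyproximal.optimization.primal.LinearizedADMM(
    pyproximal.L2(b=b),          # smooth data-fidelity term 0.5 * ||x - b||_2^2
    pyproximal.L1(sigma=0.5),    # sparsity penalty sigma * ||Dop_toy @ x||_1
    Dop_toy,
    niter=100,
    tau=0.1,   # kept small; convergence roughly needs tau <= mu / ||Dop||_2^2
    mu=1.0,
    x0=b.copy(),
    show=False,
)
# x is the fused estimate of b; Dop_toy @ x is sparse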


def _gamma_corrected_cost_smooth(
# TODO, add back gamma correction
def smooth_objective(
f,
params,
data,
huber_scale=1,
scale_coeff_ridge_shift=0,
scale_coeff_ridge_beta=0,
scale_coeff_ridge_gamma=0,
scale_coeff_ridge_alpha_d=0,
huber_scale=1,
**kwargs,
):
"""
@@ -386,14 +405,8 @@ def _gamma_corrected_cost_smooth(
return the respective binarymap and the row associated target functional scores
huber_scale : float
Scale parameter for Huber loss function
scale_coeff_ridge_shift : float
Ridge penalty coefficient for shift parameters
scale_coeff_ridge_beta : float
Ridge penalty coefficient for beta parameters
scale_coeff_ridge_gamma : float
Ridge penalty coefficient for gamma parameters
scale_coeff_ridge_alpha_d : float
Ridge penalty coefficient for alpha parameters
Ridge penalty coefficient for shift parameters
kwargs : dict
Additional keyword arguments to pass to the biophysical model function
@@ -403,36 +416,37 @@
Summed loss across all conditions.
"""
X, y = data
loss = 0
huber_cost = 0
beta_ridge_penalty = 0

# Sum the huber loss across all conditions
# shift_ridge_penalty = 0
for condition, X_d in X.items():
# Subset the params for condition-specific prediction
d_params = {
"beta0": params["beta0"][condition],
"beta": params["beta"][condition],
# TODO GAMMA
# "gamma": params["gamma"][condition],
"theta": params["theta"],
"beta_m": params["beta"],
"beta_naught": params["beta_naught"],
"s_md": params[f"shift_{condition}"],
"alpha_d": params[f"alpha_{condition}"],
"gamma_d": params[f"gamma_{condition}"],
}

# compute predictions
y_d_predicted = f(d_params, X_d, **kwargs)

# compute the Huber loss between observed and predicted
# functional scores
loss += huber_loss(
y[condition] + d_params["gamma_d"], y_d_predicted, huber_scale
huber_cost += huber_loss(
# TODO GAMMA
# y[condition] + d_params["gamma"], y_d_predicted, huber_scale
y[condition],
y_d_predicted,
huber_scale,
).mean()

# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
# loss += scale_coeff_ridge_shift * (d_params["s_md"] ** 2).sum()
# loss += scale_coeff_ridge_alpha_d * (d_params["alpha_d"] ** 2).sum()
# loss += scale_coeff_ridge_gamma * (d_params["gamma_d"] ** 2).sum()
loss /= len(X)
loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2)
beta_ridge_penalty += scale_coeff_ridge_beta * (d_params["beta"] ** 2).sum()

huber_cost /= len(X)

return loss
return huber_cost + beta_ridge_penalty
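
As a toy illustration of how the cost above is assembled (made-up numbers, not part of this commit): jaxopt's Huber loss is quadratic for small residuals and linear for large ones, and the ridge term adds the squared beta magnitudes scaled by scale_coeff_ridge_beta.

import jax.numpy as jnp
from jaxopt.loss import huber_loss

y_obs = jnp.array([0.0, 1.0, -3.0])
y_pred = jnp.array([0.1, 0.8, 0.5])   # the last residual is large
huber_scale = 1.0
huber_cost = huber_loss(y_obs, y_pred, huber_scale).mean()  # mean within a condition

beta_toy = jnp.array([0.0, 0.5, -2.0])
scale_coeff_ridge_beta = 1e-3
ridge = scale_coeff_ridge_beta * (beta_toy ** 2).sum()

total = huber_cost + ridge   # mirrors smooth_objective's return value
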
77 changes: 72 additions & 5 deletions multidms/data.py
@@ -20,6 +20,7 @@
from tqdm.auto import tqdm

from multidms import AAS
from multidms.utils import rereference, difference_matrix

[CI check failure on line 23 in multidms/data.py: GitHub Actions / build-and-test (macos-latest, 3.10), Ruff (F401): `multidms.utils.difference_matrix` imported but unused]

import jax
import jax.numpy as jnp
@@ -54,6 +55,11 @@ def split_subs(subs_string, parser=split_sub):
return wts, sites, muts


# TODO add bundle_idxs property
# TODO add validation split
# TODO could compute the Difference matrix
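
One way to picture the difference matrix mentioned in the TODO above (a hypothetical sketch only; the real multidms.utils.difference_matrix is not shown in this diff and may be laid out differently, e.g. to match the column-major ravel used in biophysical.proximal_objective): with the reference condition first, each block row maps the stacked per-condition betas to a shift beta_d - beta_reference, which is what the generalized-lasso penalty acts on.

import numpy as np

def toy_difference_matrix(n_conditions, n_mutations):
    # hypothetical layout: betas stacked condition-by-condition, reference first
    eye = np.eye(n_mutations)
    zero = np.zeros((n_mutations, n_mutations))
    rows = []
    for d in range(1, n_conditions):
        rows.append(np.hstack(
            [-eye if c == 0 else (eye if c == d else zero) for c in range(n_conditions)]
        ))
    return np.vstack(rows)

# D @ stacked_betas stacks (beta_d - beta_ref) for every non-reference condition d
D = toy_difference_matrix(n_conditions=3, n_mutations=2)  # shape (4, 6)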


class Data:
r"""
Prep and store one-hot encoding of
@@ -234,16 +240,22 @@ def __init__(
"condition column looks to be numeric type, converting to string",
UserWarning,
)
self._conditions = tuple(sorted(variants_df["condition"].astype(str).unique()))

if str(reference) not in self._conditions:
unique_conditions = set(variants_df["condition"].astype(str))
if str(reference) not in unique_conditions:
if not isinstance(reference, str):
raise ValueError(
"reference must be a string, note that if your "
"condition names are numeric, they are being "
"converted to string"
)
raise ValueError("reference must be in condition factor levels")

# set the reference as the first condition. It must come first because
# the difference matrix is constructed with that assumption, so the
# parameters are inserted in the correct order during Model initialization
non_reference_conditions = unique_conditions - set([reference])
self._conditions = tuple([reference] + list(non_reference_conditions))

self._reference = str(reference)

self._collapse_identical_variants = collapse_identical_variants
@@ -464,6 +476,9 @@ def get_nis_from_site_map(site_map):
axis=1,
)

df.drop(["wts", "sites", "muts"], axis=1, inplace=True)
self._variants_df = df

# Make BinaryMap representations for each condition
allowed_subs = {s for subs in df.var_wrt_ref for s in subs.split()}
binmaps, X, y, w = {}, {}, {}, {}
@@ -481,12 +496,48 @@
if "weight" in condition_func_score_df.columns:
w[condition] = jnp.array(condition_func_score_df["weight"].values)

df.drop(["wts", "sites", "muts"], axis=1, inplace=True)
self._variants_df = df
# set training data properties
self._training_data = {"X": X, "y": y, "w": w}
self._binarymaps = binmaps
self._mutations = tuple(ref_bmap.all_subs)

# next, we need to create a "scaled" dataset
# where the bits are flipped in the one-hot encoding
# for all non-identical (bundle) mutations
# see TODO for more
self._non_identical_idxs = {}
self._scaled_training_data = {"X": {}, "y": y, "w": w}
for condition in self._conditions:
self._non_identical_idxs[condition] = jnp.array(
[
idx
in ref_bmap.sub_str_to_indices(non_identical_mutations[condition])
for idx in range(len(ref_bmap.all_subs))
]
)
self._scaled_training_data["X"][condition] = rereference(
X[condition], self._non_identical_idxs[condition]
)

# make boolean jax array true at each of nis_idxs and false elsewhere

# self._non_identical_idxs = {
# condition: jnp.array(
# )
# for condition in self._conditions
# }

# self._scaled_training_data = {
# "X": {
# condition: rereference(
# self._training_data["X"][condition],
# self.non_identical_idxs[condition],
# )
# for condition in self._conditions
# },
# "y": self._training_data["y"],
# }
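
To make the bit flip described above concrete, here is a plausible dense-array sketch (not the actual multidms.utils.rereference, which works on the sparse binary maps and may differ): columns flagged as bundle (non-identical) mutations have their 0/1 encoding inverted, so variants are effectively re-encoded relative to the non-reference wild type.

import jax.numpy as jnp

def toy_rereference(X_dense, bundle_mask):
    # flip the encoding only in the bundle (non-identical) columns
    return jnp.where(bundle_mask[None, :], 1 - X_dense, X_dense)

X_toy = jnp.array([[1, 0, 1],
                   [0, 1, 0]])
bundle_mask_toy = jnp.array([True, False, False])
X_scaled = toy_rereference(X_toy, bundle_mask_toy)
# X_scaled == [[0, 0, 1],
#              [1, 1, 0]]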

# initialize single mutational effects df
mut_df = pd.DataFrame({"mutation": self._mutations})

@@ -507,6 +558,7 @@ def get_nis_from_site_map(site_map):

self._mutations_df = mut_df
self._name = name if isinstance(name, str) else f"Data-{Data.counter}"

Data.counter += 1

def __repr__(self):
@@ -578,6 +630,16 @@ def non_identical_sites(self) -> dict:
"""
return self._non_identical_sites

# TODO should we rename "non_identical" -> "bundle" everywhere?
@property
def non_identical_idxs(self) -> dict:
"""
A dictionary keyed by condition names with values
being the indices into the binarymap representing
bundle (non_identical) mutations
"""
return self._non_identical_idxs

@property
def reference_sequence_conditions(self) -> list:
"""
Expand All @@ -591,6 +653,11 @@ def training_data(self) -> dict:
"""A dictionary with keys 'X' and 'y' for the training data."""
return self._training_data

@property
def scaled_training_data(self) -> dict:
"""A dictionary with keys 'X' and 'y' for the scaled training data."""
return self._scaled_training_data

@property
def binarymaps(self) -> dict:
"""