syncing with epam update; better tensorboard; refactoring (#26)
* Adjusting to new `epam` branch length optimization return type (see the sketch below)
* Making tensorboard output more informative with memory usage and wall clock time
* additional sampling options
* refactoring so that we have predictions and loss for DNSMs
* saving crepes during training
matsen authored May 31, 2024
1 parent 40b23cd commit 2e820d5
Showing 2 changed files with 100 additions and 51 deletions.
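
To make the first bullet concrete: the `epam` branch length optimizer now reports a (branch_length, failed_to_converge) pair rather than a bare length, and the callers in the diffs below tally the failures and collect the lengths into a tensor. A minimal, self-contained sketch of that convention, using a stub in place of the real `_find_optimal_branch_length` (the stub and its numbers are illustrative, not from the diff):

import torch

def find_branch_length_stub(starting_length):
    # Stand-in for the real optimizer: returns (branch_length, failed_to_converge).
    return 1.1 * starting_length, False

optimal_lengths = []
failed_count = 0
for starting_length in [0.01, 0.05, 0.10]:
    branch_length, failed_to_converge = find_branch_length_stub(starting_length)
    optimal_lengths.append(branch_length)
    failed_count += failed_to_converge  # booleans count as 0/1
if failed_count > 0:
    print(f"Branch length optimization failed to converge for {failed_count} sequences.")
branch_lengths = torch.tensor(optimal_lengths)  # now a tensor rather than a numpy array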
66 changes: 32 additions & 34 deletions netam/dnsm.py
@@ -109,7 +109,7 @@ def branch_lengths(self, new_branch_lengths):
f"Expected {len(self._branch_lengths)} branch lengths, "
f"got {len(new_branch_lengths)}"
)
assert np.all(np.isfinite(new_branch_lengths) & (new_branch_lengths > 0))
assert torch.all(torch.isfinite(new_branch_lengths) & (new_branch_lengths > 0))
self._branch_lengths = new_branch_lengths
self.update_neutral_aa_mut_probs()

@@ -122,7 +122,7 @@ def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values

def update_neutral_aa_mut_probs(self):
print("consolidating shmple rates into substitution probabilities...")
print("consolidating neutral rates into substitution probabilities...")

neutral_aa_mut_prob_l = []

@@ -250,41 +250,32 @@ def load_branch_lengths(self, in_csv_prefix):
in_csv_prefix + ".val_branch_lengths.csv"
)

def loss_of_batch(self, batch):
def predictions_of_batch(self, batch):
"""
Make predictions for a batch of data.
Note that the mask is passed to the transformer as part of its input,
though we don't mask the predictions themselves.
"""
aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
aa_subs_indicator = batch["subs_indicator"].to(self.device)
mask = batch["mask"].to(self.device)
log_neutral_aa_mut_probs = batch["log_neutral_aa_mut_probs"].to(self.device)

if not torch.isfinite(log_neutral_aa_mut_probs[mask]).all():
raise ValueError(
f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}"
)

log_selection_factors = self.model(aa_parents_idxs, mask)
return self.complete_loss_fn(
log_neutral_aa_mut_probs,
log_selection_factors,
aa_subs_indicator,
mask,
)

def complete_loss_fn(
self,
log_neutral_aa_mut_probs,
log_selection_factors,
aa_subs_indicator,
mask,
):
# Take the product of the neutral mutation probabilities and the selection factors.
predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)

predictions = predictions.masked_select(mask)
aa_subs_indicator = aa_subs_indicator.masked_select(mask)

assert torch.isfinite(predictions).all()
predictions = clamp_probability(predictions)
return predictions

def loss_of_batch(self, batch):
aa_subs_indicator = batch["subs_indicator"].to(self.device)
mask = batch["mask"].to(self.device)
aa_subs_indicator = aa_subs_indicator.masked_select(mask)
predictions = self.predictions_of_batch(batch).masked_select(mask)
return self.bce_loss(predictions, aa_subs_indicator)

def _find_optimal_branch_length(
@@ -307,6 +298,7 @@ def _find_optimal_branch_length(

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

for parent, child, rates, subs_probs, starting_length in tqdm(
zip(
@@ -319,18 +311,24 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
total=len(dataset.nt_parents),
desc="Finding optimal branch lengths",
):
optimal_lengths.append(
self._find_optimal_branch_length(
parent,
child,
rates[: len(parent)],
subs_probs[: len(parent), :],
starting_length,
**optimization_kwargs,
)
branch_length, failed_to_converge = self._find_optimal_branch_length(
parent,
child,
rates[: len(parent)],
subs_probs[: len(parent), :],
starting_length,
**optimization_kwargs,
)

optimal_lengths.append(branch_length)
failed_count += failed_to_converge

if failed_count > 0:
print(
f"Branch length optimization failed to converge for {failed_count} of {len(dataset)} sequences."
)

return np.array(optimal_lengths)
return torch.tensor(optimal_lengths)

def to_crepe(self):
training_hyperparameters = {
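As the predictions_of_batch docstring above notes, the mask goes into the transformer but the predictions themselves are left unmasked; masking happens only in loss_of_batch. A small self-contained sketch of that split, using toy tensors and torch.clamp as a stand-in for netam's clamp_probability (the toy shapes and values are illustrative, not from the diff):

import torch

# Toy per-site tensors with shape (batch, sites); real batches come from the dataset.
log_neutral_aa_mut_probs = torch.log(torch.tensor([[0.02, 0.10, 0.05]]))
log_selection_factors = torch.log(torch.tensor([[1.5, 0.5, 1.0]]))  # model output
aa_subs_indicator = torch.tensor([[1.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, False]])  # e.g. the last site is padding

# predictions_of_batch: neutral mutation probabilities times selection factors,
# clamped to a valid probability range, returned without masking.
predictions = torch.exp(log_neutral_aa_mut_probs + log_selection_factors)
predictions = torch.clamp(predictions, min=1e-6, max=1.0 - 1e-6)

# loss_of_batch: apply the mask to predictions and indicators, then take BCE.
loss = torch.nn.functional.binary_cross_entropy(
    predictions.masked_select(mask),
    aa_subs_indicator.masked_select(mask),
)
print(loss.item())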
85 changes: 68 additions & 17 deletions netam/framework.py
@@ -1,6 +1,7 @@
import copy
from abc import ABC, abstractmethod
import os
from time import time

import pandas as pd
import numpy as np
@@ -344,11 +345,19 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
return trimmed_rates, trimmed_csps


def load_and_add_shm_model_outputs_to_pcp_df(pcp_df_path_gz, crepe_prefix, device=None):
def load_and_add_shm_model_outputs_to_pcp_df(
pcp_df_path_gz, crepe_prefix, sample_count=None, chosen_v_families=None
):
pcp_df = pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0).reset_index(
drop=True
)
crepe = load_crepe(crepe_prefix, device)
pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0]
if chosen_v_families is not None:
chosen_v_families = set(chosen_v_families)
pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)]
if sample_count is not None:
pcp_df = pcp_df.sample(sample_count)
crepe = load_crepe(crepe_prefix)
rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
pcp_df["rates"] = rates
pcp_df["subs_probs"] = csps
@@ -428,7 +437,21 @@ def multi_train(self, epochs, max_tries=3):
return train_history

def write_loss(self, loss_name, loss, step):
self.writer.add_scalar(loss_name, loss, step)
self.writer.add_scalar(loss_name, loss, step, walltime=time())

def write_cuda_memory_info(self):
megabyte_scaling_factor = 1 / 1024**2
if self.device.type == "cuda":
self.writer.add_scalar(
"CUDA memory allocated",
torch.cuda.memory_allocated() * megabyte_scaling_factor,
self.global_epoch,
)
self.writer.add_scalar(
"CUDA max memory allocated",
torch.cuda.max_memory_allocated() * megabyte_scaling_factor,
self.global_epoch,
)

def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None):
"""
@@ -512,7 +535,12 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
self.write_loss("Validation loss", average_loss, self.global_epoch)
return loss_reduction(average_loss)

def train(self, epochs):
def train(self, epochs, out_prefix=None):
"""
Train the model for the given number of epochs.
If out_prefix is provided, then a crepe will be saved to that location.
"""
assert self.train_loader is not None, "No training data provided."

train_losses = []
@@ -537,6 +565,10 @@ def record_losses(train_loss, val_loss):
if current_lr < self.min_learning_rate:
break

if self.device.type == "cuda":
# Clear cache for accurate memory usage tracking.
torch.cuda.empty_cache()

train_loss = self.process_data_loader(
self.train_loader, train_mode=True
).item()
@@ -561,11 +593,15 @@ def record_losses(train_loss, val_loss):
lr=current_lr,
refresh=True,
)
self.write_cuda_memory_info()
self.writer.flush()

if best_model_state is not None:
self.model.load_state_dict(best_model_state)

if out_prefix is not None:
self.save_crepe(out_prefix)

return pd.DataFrame({"train_loss": train_losses, "val_loss": val_losses})

def evaluate(self):
@@ -639,9 +675,13 @@ def standardize_and_use_yun_approx_branch_lengths(self):
dataset.branch_lengths = torch.tensor(lengths)

def mark_branch_lengths_optimized(self, cycle):
self.writer.add_scalar("branch length optimization", cycle, self.global_epoch)
self.writer.add_scalar(
"branch length optimization", cycle, self.global_epoch, walltime=time()
)

def joint_train(self, epochs=100, cycle_count=2, training_method="full"):
def joint_train(
self, epochs=100, cycle_count=2, training_method="full", out_prefix=None
):
"""
Do joint optimization of model and branch lengths.
@@ -858,6 +898,7 @@ def log_pcp_probability(log_branch_length):

def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

self.model.eval()
self.model.freeze()
@@ -879,15 +920,21 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
total=len(dataset.encoded_parents),
desc="Finding optimal branch lengths",
):
optimal_lengths.append(
self._find_optimal_branch_length(
encoded_parent,
mask,
mutation_indicator,
wt_base_modifier,
starting_branch_length,
**optimization_kwargs,
)
branch_length, failed_to_converge = self._find_optimal_branch_length(
encoded_parent,
mask,
mutation_indicator,
wt_base_modifier,
starting_branch_length,
**optimization_kwargs,
)

optimal_lengths.append(branch_length)
failed_count += failed_to_converge

if failed_count > 0:
print(
f"Branch length optimization failed to converge for {failed_count} of {len(dataset)} sequences."
)

self.model.unfreeze()
@@ -896,8 +943,12 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):

def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar("Rate " + loss_name, rate_loss.item(), step)
self.writer.add_scalar("CSP " + loss_name, csp_loss.item(), step)
self.writer.add_scalar(
"Rate " + loss_name, rate_loss.item(), step, walltime=time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=time()
)


def burrito_class_of_model(model):
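The write_loss and write_cuda_memory_info changes above follow a general TensorBoard pattern: pass an explicit walltime to add_scalar and log CUDA memory scaled to megabytes. A self-contained sketch of that pattern, with an illustrative log directory and tag names (only the add_scalar walltime argument, the empty_cache call, and the megabyte scaling are taken from the diff):

from time import time

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/example")  # illustrative log directory
for epoch in range(3):
    fake_loss = 1.0 / (epoch + 1)  # placeholder value for the sketch
    # Record the scalar with an explicit wall-clock timestamp, as in write_loss above.
    writer.add_scalar("Training loss", fake_loss, epoch, walltime=time())
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # as in train() above, clear the cache before logging memory
        writer.add_scalar(
            "CUDA memory allocated (MB)",
            torch.cuda.memory_allocated() / 1024**2,
            epoch,
        )
        writer.add_scalar(
            "CUDA max memory allocated (MB)",
            torch.cuda.max_memory_allocated() / 1024**2,
            epoch,
        )
    writer.flush()
writer.close()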
