Skip to content

Commit

Permalink
Chunked model evaluation (#91)
Browse files Browse the repository at this point in the history
Companion PR for matsengrp/dnsm-experiments-1#46

* Chunks evaluation of `ModelBase.evaluate_sequences`. Adds a method wrapper in `netam.common` which can easily be re-used on similar methods in the future (such as the equivalent method on AbstractBinarySelectionModel, if it ever gets vectorized). This is useful for avoiding loading the whole input vector's worth of intermediate results into memory, which was causing memory errors on the GPU when evaluating neutral models on large datasets.
* fixes branch length setter
* documents `Crepe.__call__` and enforces type of argument.
  • Loading branch information
willdumm authored Dec 4, 2024
1 parent 469420e commit b59adbd
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
58 changes: 58 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import itertools
import resource
import subprocess
from tqdm import tqdm
from functools import wraps
from itertools import islice

import numpy as np
import torch
Expand Down Expand Up @@ -380,3 +383,58 @@ def encode_sequences(sequences, encoder):
torch.stack(masks),
torch.stack(wt_base_modifiers),
)


# Adapted from the `batched` recipe in the itertools documentation:
# https://docs.python.org/3.11/library/itertools.html#itertools-recipes
# Rewritten here without the walrus operator (`:=`).
def chunked(iterable, n):
    """Chunk data into lists of length n. The last chunk may be shorter.

    Args:
        iterable: Any iterable; it is consumed lazily, one chunk at a time.
        n: Maximum number of items per chunk. Must be at least one. ``None``
            is passed straight to ``itertools.islice`` and therefore yields
            a single chunk containing every item.

    Yields:
        Lists of up to ``n`` consecutive items. Nothing is yielded for an
        empty iterable.

    Raises:
        ValueError: If ``n`` is smaller than one. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # Without this guard n == 0 would silently produce no chunks at all,
    # which looks exactly like an empty input to the caller.
    if n is not None and n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk


def chunk_method(default_chunk_size=2048, progress_bar_name=None):
    """Decorator to chunk the input to a method.

    Expects that all positional arguments are iterables of the same length,
    and that outputs are tuples of tensors whose first dimension
    corresponds to the first dimension of the input iterables.
    If method returns just one item, it must not be a tuple.

    Chunking is done along the first dimension of all inputs. Each chunk is
    handed to the method as a plain ``list`` of items, and the per-chunk
    outputs are joined with ``torch.cat``. Keyword arguments are not chunked;
    they are forwarded unchanged to every call. Chunks are paired with
    ``zip``, so iteration stops at the shortest positional argument.

    If chunking produces no chunks at all (empty input, or the data was
    passed only by keyword), the method is called once, un-chunked, with the
    original arguments so that it decides what the result for that input is.

    Args:
        default_chunk_size: The default chunk size. The decorated method
            also accepts a ``chunk_size`` keyword argument which overrides
            this default for a single call. That keyword is consumed by the
            wrapper and is not forwarded to the wrapped method.
        progress_bar_name: The name of the progress bar. If None, no progress bar is shown.

    Returns:
        A decorator that wraps a method ``method(self, *args, **kwargs)``.
    """

    def decorator(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            chunk_size = kwargs.pop("chunk_size", default_chunk_size)
            if progress_bar_name is None:
                progargs = {"disable": True}
            else:
                progargs = {"desc": progress_bar_name}
            # The first positional argument defines the progress total; guard
            # against calls that pass everything by keyword.
            total = len(args[0]) if args else 0
            results = []
            # The context manager closes the bar even if the method raises.
            with tqdm(total=total, delay=2.0, **progargs) as bar:
                for chunked_args in zip(*(chunked(arg, chunk_size) for arg in args)):
                    results.append(method(self, *chunked_args, **kwargs))
                    # Advance only after the chunk has actually been evaluated.
                    bar.update(len(chunked_args[0]))
            if not results:
                # Nothing to concatenate: defer to the un-chunked method
                # rather than failing on results[0].
                return method(self, *args, **kwargs)
            if isinstance(results[0], tuple):
                # Regroup per-chunk tuples into per-output sequences, then
                # concatenate each output across chunks.
                return tuple(torch.cat(tensors) for tensors in zip(*results))
            return torch.cat(results)

        return wrapper

    return decorator
9 changes: 8 additions & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def export_branch_lengths(self, out_csv_path):
)

def load_branch_lengths(self, in_csv_path):
    """Load branch lengths from a CSV file into ``self.branch_lengths``.

    Args:
        in_csv_path: Path of a CSV file containing a ``branch_length`` column.

    The column is stored as a ``torch.Tensor``; other code on this object
    reads ``self.branch_lengths.device``, which a raw NumPy array does not
    provide.
    """
    # Read the file once; the earlier plain-ndarray assignment was
    # immediately overwritten and only cost a second CSV parse.
    self.branch_lengths = torch.Tensor(
        pd.read_csv(in_csv_path)["branch_length"].values
    )

def __repr__(self):
    """Summarize the object as ``ClassName(Size: N) on <device>``.

    ``N`` is ``len(self)`` and the device is taken from the
    ``branch_lengths`` tensor.
    """
    class_name = self.__class__.__name__
    size = len(self)
    device = self.branch_lengths.device
    return f"{class_name}(Size: {size}) on {device}"
Expand Down Expand Up @@ -252,6 +254,11 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
    """Evaluate the wrapped model on a list of sequences.

    Args:
        sequences: A collection of sequences. A bare ``str`` is rejected,
            since it would otherwise be iterated character by character.

    Returns:
        Whatever ``self.model.evaluate_sequences`` returns when called with
        ``sequences`` and this object's encoder.

    Raises:
        ValueError: If ``sequences`` is a single string.
    """
    if not isinstance(sequences, str):
        return self.model.evaluate_sequences(sequences, encoder=self.encoder)
    raise ValueError(
        "Expected a list of sequences for call on crepe, but got a single string instead."
    )

@property
Expand Down
4 changes: 3 additions & 1 deletion netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
chunk_method,
)

warnings.filterwarnings(
Expand Down Expand Up @@ -64,7 +65,8 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

def evaluate_sequences(self, sequences, encoder=None):
@chunk_method(progress_bar_name="Evaluating model")
def evaluate_sequences(self, sequences, encoder=None, chunk_size=2048):
if encoder is None:
raise ValueError("An encoder must be provided.")
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
Expand Down

0 comments on commit b59adbd

Please sign in to comment.