Release of thrifty models; pretrained module; demo notebook; weights_only=True (#81)

* Using `weights_only=True` for model loading (see the sketch below)
* Load models using the `pretrained` module
* Demo notebook
* Cleaning out old trained SHM model
* Update dependencies
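As context for the first bullet: a minimal sketch of safe checkpoint loading with `weights_only=True` (a generic `torch.load` round trip, not code from this diff; the path is illustrative):

    import torch

    # Round-trip a state dict through disk.
    model = torch.nn.Linear(4, 2)
    torch.save(model.state_dict(), "/tmp/demo.pth")

    # weights_only=True restricts unpickling to tensors and other safe
    # primitives, so loading a checkpoint cannot execute arbitrary code.
    state = torch.load("/tmp/demo.pth", map_location="cpu", weights_only=True)
    model.load_state_dict(state)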
matsen authored Nov 11, 2024
1 parent 39f1cde commit 63a10bb
Showing 14 changed files with 351 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -157,3 +157,4 @@ _ignore/
.DS_Store
_logs/
_checkpoints
_pretrained
6 changes: 6 additions & 0 deletions Makefile
@@ -21,4 +21,10 @@ lint:
docs:
	make -C docs html

notebooks:
	mkdir -p notebooks/_ignore
	for nb in notebooks/*.ipynb; do \
		jupyter nbconvert --to notebook --execute "$$nb" --output notebooks/_ignore/"$$(basename $$nb)"; \
	done

.PHONY: install test notebooks format lint docs
41 changes: 39 additions & 2 deletions README.md
@@ -1,8 +1,42 @@
# netam

-Neural NETworks for antibody Affinity Maturation
+Neural NETworks for antibody Affinity Maturation.

-## Installation
+## pip installation

TODO

This will allow you to use the models.

However, if you wish to work with the models at a more detailed level, you will want to do a developer installation (see below).


## Pretrained models

Pretrained models will be downloaded on demand, so you will not need to install them separately.

The models are named according to the following convention:

    ModeltypeSpeciesVXX-YY

where:

* `Modeltype` is the type of model, such as `Thrifty` for the "thrifty" SHM model
* `Species` is the species, such as `Hum` for human
* `XX` is the version of the model
* `YY` is any model-specific information, such as the number of parameters

For example, `ThriftyHumV1.0-45` is version 1.0 of the "thrifty" human SHM model, with `45` as its model-specific suffix distinguishing it from the other variants in the same release.

If you need to clear out the cache of pretrained models, you can use the command-line call:

    netam clear_model_cache


## Usage

See the examples in the `notebooks` directory.
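For a quick start, a minimal sketch of the pretrained API (the model name comes from the catalog in this release; the toy sequence and variable names are illustrative):

    from netam import pretrained

    # The first call downloads and caches the model package;
    # subsequent calls load straight from the cache.
    crepe = pretrained.load("ThriftyHumV1.0-45")

    # Per-site rates and substitution probabilities for a toy nucleotide
    # sequence, mirroring the call used in this repo's tests.
    [rates], [subs] = crepe(["ATGGTACCCGGC"])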


## Developer installation

From a clone of this repository, install using:

@@ -13,6 +13,9 @@ From a clone of this repository, install using:
Note that you should be fine with an earlier version of Python.
We target Python 3.9, but 3.11 is faster.


## Experiments

If you are running one of the experiment repos, such as:

* [thrifty-experiments-1](https://github.com/matsengrp/thrifty-experiments-1/)
Binary file removed data/cnn_joi_sml-shmoof_small.pth
16 changes: 0 additions & 16 deletions data/cnn_joi_sml-shmoof_small.yml

This file was deleted.

12 changes: 12 additions & 0 deletions netam/cli.py
@@ -1,6 +1,18 @@
import os
import shutil

import fire
import pandas as pd

from netam.pretrained import PRETRAINED_DIR


def clear_model_cache():
    """This function clears the cache of pre-trained models."""
    if os.path.exists(PRETRAINED_DIR):
        print(f"Removing {PRETRAINED_DIR}")
        shutil.rmtree(PRETRAINED_DIR)


def concatenate_csvs(
    input_csvs_str: str,
87 changes: 87 additions & 0 deletions netam/pretrained.py
@@ -0,0 +1,87 @@
"""This module provides a simple interface for downloading pre-trained models.
It was inspired by the `load_model` module of [AbLang2](https://github.com/oxpig/AbLang2).
"""

import os
import zipfile
import pkg_resources

import requests

from netam.framework import load_crepe

PRETRAINED_DIR = pkg_resources.resource_filename(__name__, "_pretrained")

PACKAGE_LOCATIONS_AND_CONTENTS = (
    # Order of entries:
    # * Local file name
    # * Remote URL
    # * Directory in which the models appear after extraction
    # * List of models in the package
    [
        "thrifty-1.0.zip",
        "https://github.com/matsengrp/thrifty-models/archive/refs/heads/release/1.0.zip",
        "thrifty-models-release-1.0/models",
        [
            "ThriftyHumV1.0-20",
            "ThriftyHumV1.0-45",
            "ThriftyHumV1.0-59",
        ],
    ],
)

LOCAL_TO_REMOTE = {}
MODEL_TO_LOCAL = {}
LOCAL_TO_DIR = {}

for local_file, remote, models_dir, models in PACKAGE_LOCATIONS_AND_CONTENTS:
    LOCAL_TO_REMOTE[local_file] = remote
    # Directory inside the extracted archive where this package's models live.
    LOCAL_TO_DIR[local_file] = models_dir

    for model in models:
        MODEL_TO_LOCAL[model] = local_file


def local_path_for_model(model_name: str):
    """Return the local path for a model, downloading it if necessary."""

    if model_name not in MODEL_TO_LOCAL:
        raise ValueError(f"Model {model_name} not found in pre-trained models.")

    os.makedirs(PRETRAINED_DIR, exist_ok=True)

    local_package = MODEL_TO_LOCAL[model_name]
    local_package_path = os.path.join(PRETRAINED_DIR, local_package)

    if not os.path.exists(local_package_path):
        url = LOCAL_TO_REMOTE[local_package]
        print(f"Fetching models: downloading {url} to {local_package_path}")
        response = requests.get(url)
        response.raise_for_status()
        with open(local_package_path, "wb") as f:
            f.write(response.content)
        if local_package.endswith(".zip"):
            with zipfile.ZipFile(local_package_path, "r") as zip_ref:
                zip_ref.extractall(PRETRAINED_DIR)
        else:
            raise ValueError(f"Unknown file type for {local_package}")

    local_crepe_path = os.path.join(PRETRAINED_DIR, LOCAL_TO_DIR[local_package], model_name)

    if not os.path.exists(local_crepe_path + ".yml"):
        raise ValueError(f"Model {model_name} not found in pre-trained models.")
    if not os.path.exists(local_crepe_path + ".pth"):
        raise ValueError(f"Model {model_name} missing model weights.")

    return local_crepe_path


def load(model_name: str):
    """Load a pre-trained model.

    If the model is not already downloaded, it will be downloaded from the
    appropriate URL and stashed in PRETRAINED_DIR.
    """

    local_crepe_path = local_path_for_model(model_name)
    return load_crepe(local_crepe_path)
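A short sanity check of the two entry points above (assumes network access on first use; the exact printed path depends on the install location):

    from netam import pretrained

    # Resolve the on-disk files behind a model name, downloading and
    # extracting the package zip on first use.
    path = pretrained.local_path_for_model("ThriftyHumV1.0-45")
    print(path)  # ends with .../models/ThriftyHumV1.0-45

    # load() wraps the same resolution and returns a ready-to-use crepe.
    crepe = pretrained.load("ThriftyHumV1.0-45")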
5 changes: 5 additions & 0 deletions notebooks/README.md
@@ -0,0 +1,5 @@
# Notebooks

This is a minimal set of notebooks to demonstrate `netam` in action.
Most of the relevant notebooks are stored in separate repositories.
See the main README for links to those repositories.
193 changes: 193 additions & 0 deletions notebooks/thrifty_demo.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
black
docformatter
fire
nbconvert
pytest
snakemake
tensorboardX
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
"optuna",
"pandas",
"pyyaml",
"requests",
"tensorboardX",
"torch",
"tqdm",
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -3,6 +3,7 @@
    load_pcp_df,
    add_shm_model_outputs_to_pcp_df,
)
from netam.pretrained import local_path_for_model


@pytest.fixture(scope="module")
@@ -12,6 +13,6 @@ def pcp_df():
    )
    df = add_shm_model_outputs_to_pcp_df(
        df,
-        "data/cnn_joi_sml-shmoof_small",
+        local_path_for_model("ThriftyHumV1.0-45"),
    )
    return df
4 changes: 2 additions & 2 deletions tests/test_molevol.py
@@ -3,6 +3,7 @@

import netam.molevol as molevol
from netam import framework
from netam import pretrained

from netam.sequences import (
    nt_idx_tensor_of_str,
@@ -143,8 +144,7 @@ def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps):


def test_aaprob_of_mut_and_sub():
-    crepe_path = "data/cnn_joi_sml-shmoof_small"
-    crepe = framework.load_crepe(crepe_path)
+    crepe = pretrained.load("ThriftyHumV1.0-45")
    [rates], [subs] = crepe([parent_nt_seq])
    mut_probs = 1.0 - torch.exp(-rates.squeeze().clone().detach())
    parent_codon = parent_nt_seq[0:3]
3 changes: 2 additions & 1 deletion tests/test_multihit.py
@@ -7,6 +7,7 @@
    codon_probs_of_parent_scaled_nt_rates_and_csps,
    reshape_for_codons,
)
from netam import pretrained
from netam.sequences import nt_idx_tensor_of_str
import pytest
import pandas as pd
@@ -63,7 +64,7 @@
@pytest.fixture
def mini_multihit_train_val_datasets():
    df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
-    crepe = framework.load_crepe("data/cnn_joi_sml-shmoof_small")
+    crepe = pretrained.load("ThriftyHumV1.0-45")
    df = multihit.prepare_pcp_df(df, crepe, 500)
    return multihit.train_test_datasets_of_pcp_df(df)

