-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Release of thrifty models; pretrained module; demo notebook; weights_…
…only=True (#81) * Using `weights_only=True` for model loading * Load models using `pretrained`module * Demo notebook * Cleaning out old trained SHM model * Update dependencies
- Loading branch information
Showing
14 changed files
with
351 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -157,3 +157,4 @@ _ignore/ | |
.DS_Store | ||
_logs/ | ||
_checkpoints | ||
_pretrained |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""This module provides a simple interface for downloading pre-trained models. | ||
It was inspired by the `load_model` module of [AbLang2](https://github.com/oxpig/AbLang2). | ||
""" | ||
|
||
import os | ||
import zipfile | ||
import pkg_resources | ||
|
||
import requests | ||
|
||
from netam.framework import load_crepe | ||
|
||
PRETRAINED_DIR = pkg_resources.resource_filename(__name__, "_pretrained") | ||
|
||
PACKAGE_LOCATIONS_AND_CONTENTS = ( | ||
# Order of entries: | ||
# * Local file name | ||
# * Remote URL | ||
# * Directory in which the models appear after extraction | ||
# * List of models in the package | ||
[ | ||
"thrifty-1.0.zip", | ||
"https://github.com/matsengrp/thrifty-models/archive/refs/heads/release/1.0.zip", | ||
"thrifty-models-release-1.0/models", | ||
[ | ||
"ThriftyHumV1.0-20", | ||
"ThriftyHumV1.0-45", | ||
"ThriftyHumV1.0-59", | ||
], | ||
], | ||
) | ||
|
||
LOCAL_TO_REMOTE = {} | ||
MODEL_TO_LOCAL = {} | ||
LOCAL_TO_DIR = {} | ||
|
||
for local_file, remote, models_dir, models in PACKAGE_LOCATIONS_AND_CONTENTS: | ||
LOCAL_TO_REMOTE[local_file] = remote | ||
|
||
for model in models: | ||
MODEL_TO_LOCAL[model] = local_file | ||
|
||
|
||
def local_path_for_model(model_name: str): | ||
"""Return the local path for a model, downloading it if necessary.""" | ||
|
||
if model_name not in MODEL_TO_LOCAL: | ||
raise ValueError(f"Model {model_name} not found in pre-trained models.") | ||
|
||
os.makedirs(PRETRAINED_DIR, exist_ok=True) | ||
|
||
local_package = MODEL_TO_LOCAL[model_name] | ||
local_package_path = os.path.join(PRETRAINED_DIR, local_package) | ||
|
||
if not os.path.exists(local_package_path): | ||
url = LOCAL_TO_REMOTE[local_package] | ||
print(f"Fetching models: downloading {url} to {local_package_path}") | ||
response = requests.get(url) | ||
response.raise_for_status() | ||
with open(local_package_path, "wb") as f: | ||
f.write(response.content) | ||
if local_package.endswith(".zip"): | ||
with zipfile.ZipFile(local_package_path, "r") as zip_ref: | ||
zip_ref.extractall(PRETRAINED_DIR) | ||
else: | ||
raise ValueError(f"Unknown file type for {local_package}") | ||
|
||
local_crepe_path = os.path.join(PRETRAINED_DIR, models_dir, model_name) | ||
|
||
if not os.path.exists(local_crepe_path + ".yml"): | ||
raise ValueError(f"Model {model_name} not found in pre-trained models.") | ||
if not os.path.exists(local_crepe_path + ".pth"): | ||
raise ValueError(f"Model {model_name} missing model weights.") | ||
|
||
return local_crepe_path | ||
|
||
|
||
def load(model_name: str): | ||
"""Load a pre-trained model. | ||
If the model is not already downloaded, it will be downloaded from the appropriate | ||
URL and stashed in the PRETRAINED_DIR. | ||
""" | ||
|
||
local_crepe_path = local_path_for_model(model_name) | ||
return load_crepe(local_crepe_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Notebooks | ||
|
||
This is a minimal set of notebooks to demonstrate `netam` in action. | ||
Most relevant notebooks are stored in separate repositories. | ||
See the main README for links to those repositories. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
black | ||
docformatter | ||
fire | ||
nbconvert | ||
pytest | ||
snakemake | ||
tensorboardX | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
"optuna", | ||
"pandas", | ||
"pyyaml", | ||
"requests", | ||
"tensorboardX", | ||
"torch", | ||
"tqdm", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters