Fix load model test #8
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Continuous-integration workflow: builds the conda environment, then runs a
# short pretrain -> finetune -> checkpoint-load smoke test on every push and
# pull request.
name: Tests

on:
  - push
  - pull_request

jobs:
  # Installs the conda environment and trains METL
  train:
    name: Test METL training
    runs-on: ${{ matrix.os }}
    strategy:
      # Keep the other operating systems running when one of them fails
      fail-fast: false
      matrix:
        # Could also test on the beta M1 macOS runner
        # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
        os:
          - macos-latest
          - ubuntu-latest
          - windows-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Can set up package caching later conda-incubator/setup-miniconda
      - name: Install conda environment
        uses: conda-incubator/setup-miniconda@v2
        with:
          activate-environment: metl
          environment-file: environment.yml
          auto-activate-base: false
          miniforge-variant: Mambaforge
          miniforge-version: 'latest'
          use-mamba: true

      # Installs latest commit from main branch.
      # Every step below uses a bash login shell; setup-miniconda documents
      # this as the way to get the conda environment activated in a step.
      - name: Install metl package from metl-pretrained repo
        shell: bash --login {0}
        run: pip install git+https://github.com/gitter-lab/metl-pretrained

      # Log conda environment contents
      - name: Log conda environment
        shell: bash --login {0}
        run: conda list

      # Pretrain source model on GFP Rosetta dataset.
      # Epochs and batch counts are capped so the step stays short.
      - name: Pretrain source METL model
        shell: bash --login {0}
        run: python code/train_source_model.py @args/pretrain_avgfp_local.txt --max_epochs 5 --limit_train_batches 5 --limit_val_batches 5 --limit_test_batches 5

      # Finetune target model on GFP DMS dataset
      - name: Finetune target METL model
        shell: bash --login {0}
        run: python code/train_target_model.py @args/finetune_avgfp_local.txt --enable_progress_bar false --enable_simple_progress_messages --max_epochs 50 --unfreeze_backbone_at_epoch 25

      # Load target model checkpoint and run inference on example variants.
      # NOTE(review): the run ID "DgLkMZxu" and the checkpoint file name are
      # hard-coded; presumably the finetune args make them deterministic -
      # confirm against args/finetune_avgfp_local.txt.
      - name: Load and test target METL model
        shell: bash --login {0}
        run: |
          python code/convert_ckpt.py output/training_logs/DgLkMZxu/checkpoints/epoch=49-step=50.ckpt
          # The variants argument must be quoted: bash would otherwise treat
          # each ';' as a command separator and split the command line.
          python code/tests.py --checkpoint_path output/training_logs/DgLkMZxu/checkpoints/DgLkMZxu.pt --variants 'E3K,G102S;T36P,S203T,K207R;V10A,D19G,F25S,E113V' --dataset avgfp