# GitHub Actions run: "Bump connectorx version #9"
# (commit message: "Bump connectorx version")
# Workflow file for this run:

name: Tests
# Trigger the test workflow on every push and on every pull request.
on: [push, pull_request]
jobs:
  # Installs the conda environment and trains METL
  train:
    name: Test METL training
    runs-on: ${{ matrix.os }}
    strategy:
      # Let every OS finish even if another fails, so all platform-specific
      # failures are reported in a single run
      fail-fast: false
      matrix:
        # Could also test on the beta M1 macOS runner
        # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
        # NOTE(review): macos-latest may already resolve to an Apple silicon
        # image; confirm whether the comment above is still relevant.
        os:
          - macos-latest
          - ubuntu-latest
          - windows-latest
    defaults:
      run:
        # Login shell (-l) so the conda environment activated by
        # setup-miniconda is on PATH in every `run` step; -e aborts a
        # multi-line step at the first failing command instead of continuing.
        shell: bash -el {0}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      # Can set up package caching later conda-incubator/setup-miniconda
      - name: Install conda environment
        uses: conda-incubator/setup-miniconda@v3
        with:
          activate-environment: metl
          environment-file: environment.yml
          auto-activate-base: false
          # The Mambaforge variant is deprecated upstream; Miniforge3 also
          # ships mamba, so use-mamba keeps working.
          miniforge-variant: Miniforge3
          miniforge-version: 'latest'
          use-mamba: true
      # Installs latest commit from main branch
      - name: Install metl package from metl-pretrained repo
        run: pip install git+https://github.com/gitter-lab/metl-pretrained
      # Log conda environment contents
      - name: Log conda environment
        run: conda list
      # Pretrain source model on GFP Rosetta dataset (few batches, smoke test)
      - name: Pretrain source METL model
        run: >-
          python code/train_source_model.py @args/pretrain_avgfp_local.txt
          --max_epochs 5
          --limit_train_batches 5
          --limit_val_batches 5
          --limit_test_batches 5
      # Finetune target model on GFP DMS dataset
      - name: Finetune target METL model
        run: >-
          python code/train_target_model.py @args/finetune_avgfp_local.txt
          --enable_progress_bar false
          --enable_simple_progress_messages
          --max_epochs 50
          --unfreeze_backbone_at_epoch 25
      # Load target model checkpoint and run inference on example variants.
      # NOTE(review): the run directory name DgLkMZxu is hard-coded; it is
      # presumably derived deterministically from the finetuning arguments —
      # confirm it still matches if args/finetune_avgfp_local.txt changes.
      # The variant list is quoted because `;` would otherwise be parsed by
      # bash as a command separator, truncating the argument.
      - name: Load and test target METL model
        run: |
          python code/convert_ckpt.py output/training_logs/DgLkMZxu/checkpoints/epoch=49-step=50.ckpt
          python code/tests.py \
            --checkpoint_path output/training_logs/DgLkMZxu/checkpoints/DgLkMZxu.pt \
            --variants "E3K,G102S;T36P,S203T,K207R;V10A,D19G,F25S,E113V" \
            --dataset avgfp