Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor edits #92

Merged
merged 8 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions colabdesign/mpnn/legacy/example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.0/colabdesign/mpnn/legacy/example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7nofzcgaP96j",
"metadata": {
"id": "7nofzcgaP96j"
},
"outputs": [],
"source": [
"#@title install\n",
"%%bash\n",
"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.0\n",
"# for debugging\n",
"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "dcea2f7e",
"metadata": {
"id": "dcea2f7e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"py3Dmol not installed\n"
]
}
],
"source": [
"import numpy as np\n",
"import os, sys\n",
"import joblib\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import re\n",
"import copy\n",
"import random\n",
"import haiku as hk\n",
"from tqdm import tqdm\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from colabdesign.mpnn.legacy.wrapper import MPNN_wrapper"
]
},
{
"cell_type": "markdown",
"id": "931140c7",
"metadata": {
"id": "931140c7"
},
"source": [
"# Initialize model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "23ef08a0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "23ef08a0",
"outputId": "d5e1a8bb-bac9-4b0a-d99a-b94eb10eeef0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of edges: 48\n",
"Training noise level: 0.02A\n"
]
}
],
"source": [
"params_path = '/content/colabdesign/mpnn/jax_weights'\n",
"model = MPNN_wrapper(params_path=params_path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b3896466",
"metadata": {},
"outputs": [],
"source": [
"# set the provided pdb\n",
"pdb_path='1P3J.pdb'\n",
"input = model.prep_input(pdb_path=pdb_path,\n",
" target_chain='A')"
]
},
{
"cell_type": "markdown",
"id": "1350bfbd",
"metadata": {
"id": "1350bfbd"
},
"source": [
"## Get the outputs from MPNN"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "228701fa",
"metadata": {
"id": "228701fa"
},
"outputs": [],
"source": [
"L = len(input['dataset_valid'][0]['seq'])\n",
"seed = random.randint(0,2147483647)\n",
"order = jax.random.normal(jax.random.PRNGKey(seed), (L,))\n",
"logits, log_probs = model.score(input, order=order)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "024d0ad9",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "024d0ad9",
"outputId": "4420ce1c-6f62-495c-8bec-30410838d5b7"
},
"outputs": [
{
"data": {
"text/plain": [
"((1, 212, 21), (1, 212, 21))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits.shape, log_probs.shape"
]
},
{
"cell_type": "markdown",
"id": "3acf7c73",
"metadata": {},
"source": [
"## Generate sequences"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "T_fMvaY2bANv",
"metadata": {
"id": "T_fMvaY2bANv"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:23<00:00, 23.37s/it]\n"
]
}
],
"source": [
"# generate sequences\n",
"seqs = model.sampling(input, 1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "39616c44",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['MNIVLLGLPGSGKSTIGELICKDLGVPLISIDDIYVKAIKEKTPYGKEAEKYILKGKLVPNELTNGIIEKELSKEECKNGFVLDGYPRTVEEAEALEKILEKRGRPIDLVIYLECEEEVLRERLLTRLVCSKCFRSYNLVYRPPKTPGVCDECGAKLVVPKWDKPEVVEVRLKEYKERVEPLLEYFKEKGKLVKVDANKNEEEVYEDVKKLL']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seqs"
]
},
{
"cell_type": "markdown",
"id": "791d40a0",
"metadata": {},
"source": [
"## Genrate sequences for homomer"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "668ec046",
"metadata": {},
"outputs": [],
"source": [
"pdb_path='1O91.pdb'\n",
"input = model.prep_input(pdb_path=pdb_path,\n",
" target_chain='A B C',\n",
" ishomomer=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "34518ad9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1/1 [00:17<00:00, 17.29s/it]\n"
]
}
],
"source": [
"seqs = model.sampling(input, 1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "58db1d7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['EKEAFTALLTTPYPPVGEPIKFDKLLYNGQNVYDPATGIFTCKTPGVYFFSWNLNVYGKDLHVQLYKNDEAIQSSYMEYIEGKLSLTSGSAVLKLEKGDKVYLECPTEEANGLYAGEDVHSSFSGFLLYET/EKEAFTALLTTPYPPVGEPIKFDKLLYNGQNVYDPATGIFTCKTPGVYFFSWNLNVYGKDLHVQLYKNDEAIQSSYMEYIEGKLSLTSGSAVLKLEKGDKVYLECPTEEANGLYAGEDVHSSFSGFLLYET/EKEAFTALLTTPYPPVGEPIKFDKLLYNGQNVYDPATGIFTCKTPGVYFFSWNLNVYGKDLHVQLYKNDEAIQSSYMEYIEGKLSLTSGSAVLKLEKGDKVYLECPTEEANGLYAGEDVHSSFSGFLLYET']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seqs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f12f535",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.8.13 ('jax038')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "d39b7156cdbdfdeaeb5cea9c1b6bf180b493eeb3b22ef2423895aed018ecbde9"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
12 changes: 6 additions & 6 deletions colabdesign/mpnn/legacy/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import joblib

from colabdesign.shared.prng import SafeKey
from colabdesign.mpnn.utils import gather_edges, gather_nodes, cat_neighbors_nodes, scatter, get_ar_mask
from colabdesign.mpnn.sample import mpnn_sample
from .utils import gather_edges, gather_nodes, cat_neighbors_nodes, scatter, get_ar_mask
from .sample import mpnn_sample

Gelu = functools.partial(jax.nn.gelu, approximate=False)

Expand Down Expand Up @@ -166,14 +166,14 @@ def _forward_score(inputs):
def _forward_sample(inputs):
model = ProteinMPNN(**self.config)
return model.sample(**inputs)
self.sample = hk.transform(_forward_sample).apply
self.init_sample = hk.transform(_forward_sample).init
self.sample = jax.jit(hk.transform(_forward_sample).apply)
self.init_sample = jax.jit(hk.transform(_forward_sample).init)

def _forward_tsample(inputs):
model = ProteinMPNN(**self.config)
return model.tied_sample(**inputs)
self.tied_sample = hk.transform(_forward_tsample).apply
self.init_tsample = hk.transform(_forward_tsample).init
self.tied_sample = jax.jit(hk.transform(_forward_tsample).apply)
self.init_tsample = jax.jit(hk.transform(_forward_tsample).init)

def load_params(self, path):
self.params = joblib.load(path)
Expand Down
2 changes: 1 addition & 1 deletion colabdesign/mpnn/legacy/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import itertools

from colabdesign.mpnn.utils import gather_nodes, cat_neighbors_nodes, scatter, get_ar_mask
from .utils import gather_nodes, cat_neighbors_nodes, scatter, get_ar_mask

class mpnn_sample:
def sample(self, key, X, randn, S_true,
Expand Down
Loading