Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ntxxt patch 1 #11

Merged
merged 2 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ We provide a command-line interface for DeepRank-GNN-esm that can easily be used

```bash
$ deeprank-gnn-esm-predict -h
usage: deeprank-gnn-esm-predict [-h] pdb_file chain_id_1 chain_id_2
usage: deeprank-gnn-esm-predict [-h] pdb_file chain_id_1 chain_id_2 model_name

positional arguments:
pdb_file Path to the PDB file.
chain_id_1 First chain ID.
chain_id_2 Second chain ID.
model_name pre_trained model weight

optional arguments:
-h, --help show this help message and exit
Expand All @@ -62,7 +63,8 @@ $ wget https://files.rcsb.org/view/1B6C.pdb -q

# make sure the environment is activated
$ conda activate deeprank-gnn-esm-gpu-env
(deeprank-gnn-esm-gpu-env) $ deeprank-gnn-esm-predict 1B6C.pdb A B
(deeprank-gnn-esm-gpu-env) $ export MODEL=../paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar
(deeprank-gnn-esm-gpu-env) $ deeprank-gnn-esm-predict 1B6C.pdb A B $MODEL
2023-06-28 06:08:21,889 predict:64 INFO - Setting up workspace - /home/DeepRank-GNN-esm/1B6C-gnn_esm_pred_A_B
2023-06-28 06:08:21,945 predict:72 INFO - Renumbering PDB file.
2023-06-28 06:08:22,294 predict:104 INFO - Reading sequence of PDB 1B6C.pdb
Expand Down
20 changes: 9 additions & 11 deletions src/deeprank_gnn/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@
log.addHandler(ch)

# Constants
# TODO: Make these configurable
ESM_MODEL = "esm2_t33_650M_UR50D"
GNN_ESM_MODEL = "paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar"
#GNN_ESM_MODEL = "paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar"

TOKS_PER_BATCH = 4096
REPR_LAYERS = [0, 32, 33]
Expand All @@ -49,11 +48,7 @@
BATCH_SIZE = 64
DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu" # configurable

"""
added two parameters in NeuralNet: num_workers and batch_size
default batch_size is 32, default num_workers is 1
for both, the higher the faster but depend on gpu capacity, should be configurable too
"""

###########################################################


Expand Down Expand Up @@ -212,7 +207,7 @@ def get_embedding(fasta_file: Path, output_dir: Path) -> None:


def create_graph(pdb_path: Path, workspace_path: Path) -> str:
"""Generate a graph ...?"""
"""Generate a graph """
log.info(f"Generating graph, using {NPROC} processors")

outfile = str(workspace_path / "graph.hdf5")
Expand All @@ -232,7 +227,7 @@ def create_graph(pdb_path: Path, workspace_path: Path) -> str:
return outfile


def predict(input: str, workspace_path: Path) -> str:
def predict(input: str, workspace_path: Path, model_path: str) -> str:
"""Predict the fnat of a protein complex."""
log.info("Predicting fnat of protein complex.")
gnn = GINet
Expand All @@ -256,7 +251,7 @@ def predict(input: str, workspace_path: Path) -> str:
num_workers=NPROC,
batch_size=BATCH_SIZE,
target=target,
pretrained_model=GNN_ESM_MODEL,
pretrained_model=model_path,
threshold=threshold,
)
model.test(hdf5=output)
Expand Down Expand Up @@ -307,11 +302,14 @@ def main():
parser.add_argument("pdb_file", help="Path to the PDB file.")
parser.add_argument("chain_id_1", help="First chain ID.")
parser.add_argument("chain_id_2", help="Second chain ID.")
parser.add_argument("model_path", help="Path to the pretrained model.")
args = parser.parse_args()

pdb_file = args.pdb_file
chain_id_1 = args.chain_id_1
chain_id_2 = args.chain_id_2
model_path = args.model_path


identificator = Path(pdb_file).stem + f"-gnn_esm_pred_{chain_id_1}_{chain_id_2}"
workspace_path = setup_workspace(identificator)
Expand All @@ -337,7 +335,7 @@ def main():
graph = create_graph(pdb_path=pdb_file.parent, workspace_path=workspace_path)

## Predict fnat
csv_output = predict(input=graph, workspace_path=workspace_path)
csv_output = predict(input=graph, workspace_path=workspace_path, model_path=model_path)

## Present the results
parse_output(csv_output=csv_output, workspace_path=workspace_path, chain_ids=[chain_id_1, chain_id_2])
Expand Down
Loading