Skip to content

Commit

Permalink
Merge pull request #11 from haddocking/ntxxt-patch-1
Browse files Browse the repository at this point in the history
Ntxxt patch 1
  • Loading branch information
ntxxt authored May 24, 2024
2 parents a3511c4 + ff45355 commit 2e4c090
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ We provide a command-line interface for DeepRank-GNN-esm that can easily be used

```bash
$ deeprank-gnn-esm-predict -h
usage: deeprank-gnn-esm-predict [-h] pdb_file chain_id_1 chain_id_2
usage: deeprank-gnn-esm-predict [-h] pdb_file chain_id_1 chain_id_2 model_name

positional arguments:
pdb_file Path to the PDB file.
chain_id_1 First chain ID.
chain_id_2 Second chain ID.
model_name pre_trained model weight

optional arguments:
-h, --help show this help message and exit
Expand All @@ -62,7 +63,8 @@ $ wget https://files.rcsb.org/view/1B6C.pdb -q

# make sure the environment is activated
$ conda activate deeprank-gnn-esm-gpu-env
(deeprank-gnn-esm-gpu-env) $ deeprank-gnn-esm-predict 1B6C.pdb A B
(deeprank-gnn-esm-gpu-env) $ export MODEL=../paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar
(deeprank-gnn-esm-gpu-env) $ deeprank-gnn-esm-predict 1B6C.pdb A B $MODEL
2023-06-28 06:08:21,889 predict:64 INFO - Setting up workspace - /home/DeepRank-GNN-esm/1B6C-gnn_esm_pred_A_B
2023-06-28 06:08:21,945 predict:72 INFO - Renumbering PDB file.
2023-06-28 06:08:22,294 predict:104 INFO - Reading sequence of PDB 1B6C.pdb
Expand Down
20 changes: 9 additions & 11 deletions src/deeprank_gnn/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@
log.addHandler(ch)

# Constants
# TODO: Make these configurable
ESM_MODEL = "esm2_t33_650M_UR50D"
GNN_ESM_MODEL = "paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar"
#GNN_ESM_MODEL = "paper_pretrained_models/scoring_of_docking_models/gnn_esm/treg_yfnat_b64_e20_lr0.001_foldall_esm.pth.tar"

TOKS_PER_BATCH = 4096
REPR_LAYERS = [0, 32, 33]
Expand All @@ -49,11 +48,7 @@
BATCH_SIZE = 64
DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu" # configurable

"""
added two parameters in NeuralNet: num_workers and batch_size
default batch_size is 32, default num_workers is 1
for both, the higher the faster but depend on gpu capacity, should be configurable too
"""

###########################################################


Expand Down Expand Up @@ -212,7 +207,7 @@ def get_embedding(fasta_file: Path, output_dir: Path) -> None:


def create_graph(pdb_path: Path, workspace_path: Path) -> str:
"""Generate a graph ...?"""
"""Generate a graph """
log.info(f"Generating graph, using {NPROC} processors")

outfile = str(workspace_path / "graph.hdf5")
Expand All @@ -232,7 +227,7 @@ def create_graph(pdb_path: Path, workspace_path: Path) -> str:
return outfile


def predict(input: str, workspace_path: Path) -> str:
def predict(input: str, workspace_path: Path, model_path: str) -> str:
"""Predict the fnat of a protein complex."""
log.info("Predicting fnat of protein complex.")
gnn = GINet
Expand All @@ -256,7 +251,7 @@ def predict(input: str, workspace_path: Path) -> str:
num_workers=NPROC,
batch_size=BATCH_SIZE,
target=target,
pretrained_model=GNN_ESM_MODEL,
pretrained_model=model_path,
threshold=threshold,
)
model.test(hdf5=output)
Expand Down Expand Up @@ -307,11 +302,14 @@ def main():
parser.add_argument("pdb_file", help="Path to the PDB file.")
parser.add_argument("chain_id_1", help="First chain ID.")
parser.add_argument("chain_id_2", help="Second chain ID.")
parser.add_argument("model_path", help="Path to the pretrained model.")
args = parser.parse_args()

pdb_file = args.pdb_file
chain_id_1 = args.chain_id_1
chain_id_2 = args.chain_id_2
model_path = args.model_path


identificator = Path(pdb_file).stem + f"-gnn_esm_pred_{chain_id_1}_{chain_id_2}"
workspace_path = setup_workspace(identificator)
Expand All @@ -337,7 +335,7 @@ def main():
graph = create_graph(pdb_path=pdb_file.parent, workspace_path=workspace_path)

## Predict fnat
csv_output = predict(input=graph, workspace_path=workspace_path)
csv_output = predict(input=graph, workspace_path=workspace_path, model_path=model_path)

## Present the results
parse_output(csv_output=csv_output, workspace_path=workspace_path, chain_ids=[chain_id_1, chain_id_2])
Expand Down

0 comments on commit 2e4c090

Please sign in to comment.