Skip to content

Commit

Permalink
feat: adding REST API server for prediction (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe authored Sep 7, 2023
1 parent 48d504c commit 8bb7516
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 36 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
python-version: "3.9"

- name: Install dependencies
run: |
Expand Down Expand Up @@ -56,7 +56,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
python-version: "3.9"

- name: Install dependencies
run: |
Expand All @@ -74,7 +74,6 @@ jobs:
strategy:
matrix:
python-version:
- '3.8'
- '3.9'
- '3.10'
- '3.11'
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,4 @@ Before you submit a pull request, check that it meets these guidelines:
1. The pull request should include tests.
2. If the pull request adds functionality, the docs should be updated.
Put your new functionality into a function with a docstring.
3. The pull request should work for Python 3.8 to 3.11.
3. The pull request should work for Python 3.9 to 3.11.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ flake8:
mypy: export MYPYPATH=stubs
mypy:
mypy cada_prio tests

.PHONY: serve
serve:
uvicorn cada_prio.rest_server:app --host 0.0.0.0 --port 8080 --reload
77 changes: 51 additions & 26 deletions cada_prio/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,25 @@

import json
import os
import pickle
import typing

import cattrs
from gensim.models import Word2Vec
from logzero import logger
import networkx as nx
import numpy as np

from cada_prio import train_model


def load_hgnc_info(path_hgnc_json: str) -> typing.List[train_model.GeneIds]:
result = []
with open(path_hgnc_json, "rt") as f:
for line in f:
result.append(cattrs.structure(json.loads(line), train_model.GeneIds))
return result


def run(
path_model: str,
orig_hpo_terms: typing.List[str],
genes: typing.Optional[typing.List[str]] = None,
) -> int:
# Load and prepare data
def load_hgnc_info(path_model):
logger.info("Loading HGNC info...")
logger.info("- parsing")
hgnc_infos = load_hgnc_info(os.path.join(path_model, "hgnc_info.jsonl"))
hgnc_infos = []
path_hgnc_jsonl = os.path.join(path_model, "hgnc_info.jsonl")
with open(path_hgnc_jsonl, "rt") as f:
for line in f:
hgnc_infos.append(cattrs.structure(json.loads(line), train_model.GeneIds))
logger.info("- create mapping")
all_to_hgnc = {}
for record in hgnc_infos:
Expand All @@ -39,21 +30,28 @@ def run(
if record.ensembl_gene_id:
all_to_hgnc[record.ensembl_gene_id] = record
hgnc_info_by_id = {record.hgnc_id: record for record in hgnc_infos}
hgnc_ids = []
for gene in genes or []:
if gene in all_to_hgnc:
hgnc_ids.append(all_to_hgnc[gene].hgnc_id)
else:
logger.warning("could not resolve HGNC ID for gene %s", gene)
logger.info("... done loading HGNC info")
return all_to_hgnc, hgnc_info_by_id


def load_graph_model(path_model):
logger.info("Loading graph...")
graph = nx.read_gpickle(os.path.join(path_model, "graph.gpickle"))
with open(os.path.join(path_model, "graph.gpickle"), "rb") as inputf:
graph = pickle.load(inputf)
logger.info("... done loading graph")
logger.info("Loading model...")
model = Word2Vec.load(os.path.join(path_model, "model"))
logger.info("... done loading model")
return graph, model


class NoValidHpoTerms(ValueError):
pass


def run_prediction(
orig_hpo_terms, orig_genes, all_to_hgnc, graph, model
) -> typing.Tuple[typing.List[str], typing.Dict[str, float]]:
# Lookup HPO term embeddings.
hpo_terms = {}
for hpo_term in orig_hpo_terms:
Expand All @@ -63,7 +61,16 @@ def run(
hpo_terms[hpo_term] = model.wv[hpo_term]
if not hpo_terms:
logger.error("no valid HPO terms in model")
return 1
raise NoValidHpoTerms("no valid HPO terms in query")

# Map gene IDs to HGNC IDs
genes = []
for orig_gene in orig_genes or []:
print(all_to_hgnc.get(orig_gene))
if orig_gene in all_to_hgnc:
genes.append(all_to_hgnc[orig_gene].hgnc_id)
else:
genes.append(orig_gene)

# Generate a score for each gene in the knowledge graph
logger.info("Generating scores...")
Expand All @@ -81,11 +88,29 @@ def run(
this_gene_scores.append(score)
gene_scores[hgnc_id] = sum(this_gene_scores) / len(hpo_terms)

sorted_scores = dict(sorted(gene_scores.items(), key=lambda x: x[1], reverse=True))
return list(hpo_terms.keys()), sorted_scores


def run(
path_model: str,
orig_hpo_terms: typing.List[str],
orig_genes: typing.Optional[typing.List[str]] = None,
) -> int:
# Load and prepare data
all_to_hgnc, hgnc_info_by_id = load_hgnc_info(path_model)
graph, model = load_graph_model(path_model)
try:
hpo_terms, sorted_scores = run_prediction(
orig_hpo_terms, orig_genes, all_to_hgnc, graph, model
)
except NoValidHpoTerms:
return 1

# Write out results to stdout, largest score first
sorted_scores = sorted(gene_scores.items(), key=lambda x: x[1], reverse=True)
print("# query (len=%d): %s" % (len(hpo_terms), ",".join(hpo_terms)))
print("\t".join(["rank", "score", "gene_symbol", "ncbi_gene_id", "hgnc_id"]))
for rank, (hgnc_id, score) in enumerate(sorted_scores, start=1):
for rank, (hgnc_id, score) in enumerate(sorted_scores.items(), start=1):
hgnc_info = hgnc_info_by_id[hgnc_id]
print(
"\t".join(
Expand Down
82 changes: 82 additions & 0 deletions cada_prio/rest_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""REST API for CADA using FastAPI."""
from contextlib import asynccontextmanager
import os
import typing

from dotenv import load_dotenv
from fastapi import FastAPI, Query
from pydantic import BaseModel
from starlette.responses import Response

import cada_prio
from cada_prio import predict

# Load environment
env = os.environ
load_dotenv()

#: Debug mode
DEBUG = env.get("CADA_DEBUG", "false").lower() in ("true", "1")
#: Path to data / model
PATH_DATA = env.get("CADA_PATH_DATA", "/data/cada")

#: The CADA models, to be loaded on startup.
GLOBAL_STATIC = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
# Load the models
all_to_hgnc, hgnc_info_by_id = predict.load_hgnc_info(PATH_DATA)
GLOBAL_STATIC["all_to_hgnc"] = all_to_hgnc
GLOBAL_STATIC["hgnc_info_by_id"] = hgnc_info_by_id
graph, model = predict.load_graph_model(PATH_DATA)
GLOBAL_STATIC["graph"] = graph
GLOBAL_STATIC["model"] = model

yield

GLOBAL_STATIC.clear()


app = FastAPI(lifespan=lifespan)


# Register endpoint for returning CADA version.
@app.get("/version")
async def version():
return Response(content=cada_prio.__version__)


class PredictionResult(BaseModel):
rank: int
score: float
gene_symbol: str
ncbi_gene_id: str
hgnc_id: str


# Register endpoint for the prediction
@app.get("/predict")
async def handle_predict(
hpo_terms: typing.Annotated[typing.List[str], Query()],
genes: typing.Annotated[typing.Optional[typing.List[str]], Query()] = [],
):
_, sorted_scores = predict.run_prediction(
hpo_terms,
genes,
GLOBAL_STATIC["all_to_hgnc"],
GLOBAL_STATIC["graph"],
GLOBAL_STATIC["model"],
)
hgnc_info_by_id = GLOBAL_STATIC["hgnc_info_by_id"]
return [
PredictionResult(
rank=rank,
score=score,
gene_symbol=hgnc_info_by_id[hgnc_id].symbol,
ncbi_gene_id=hgnc_info_by_id[hgnc_id].ncbi_gene_id,
hgnc_id=hgnc_info_by_id[hgnc_id].hgnc_id,
)
for rank, (hgnc_id, score) in enumerate(sorted_scores.items(), start=1)
]
4 changes: 3 additions & 1 deletion cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
import os
import pickle
import typing
import warnings

Expand Down Expand Up @@ -265,7 +266,8 @@ def write_graph_and_model(

path_graph = os.path.join(path_out, "graph.gpickle")
logger.info("Saving graph to %s...", path_graph)
nx.write_gpickle(training_graph, path_graph)
with open(path_graph, "wb") as outputf:
pickle.dump(training_graph, outputf)
logger.info("... done saving graph")

logger.info("Saving embedding to %s...", path_out)
Expand Down
3 changes: 3 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ tqdm >=4.0
pronto >=2.5, <3.0
networkx
node2vec >=0.4.6, <0.5
uvicorn >=0.23.2
fastapi >=0.103, <0.104
python-dotenv >=1.0, <2.0
2 changes: 2 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ flake8 >=5.0.4, <7.0
pytest
pytest-cov
pytest-snapshot
pytest-asyncio
httpx >=0.24, <0.25

mypy ==1.5.1
types-python-dateutil >=2.8.19.3
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def parse_requirements(path):
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from cada_prio import predict


def test_predict_smoke_test(tmpdir):
predict.run("tests/data/model_smoke", "HP:0008551")
def test_predict_smoke_test():
predict.run("tests/data/model_smoke", ["HP:0008551"])
30 changes: 30 additions & 0 deletions tests/test_rest_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest import mock

import pytest
from starlette.testclient import TestClient

from cada_prio import rest_server


@mock.patch("cada_prio.rest_server.PATH_DATA", "tests/data/model_smoke")
@pytest.mark.asyncio
async def test_version():
with TestClient(rest_server.app) as client:
response = client.get("/version")
assert response.status_code == 200


@mock.patch("cada_prio.rest_server.PATH_DATA", "tests/data/model_smoke")
@pytest.mark.asyncio
async def test_predict_with_gene():
with TestClient(rest_server.app) as client:
response = client.get("/predict/?hpo_terms=HP:0008551&hpo=HP:0000007&gene=MKS1")
assert response.status_code == 200


@mock.patch("cada_prio.rest_server.PATH_DATA", "tests/data/model_smoke")
@pytest.mark.asyncio
async def test_predict_without_gene():
with TestClient(rest_server.app) as client:
response = client.get("/predict/?hpo_terms=HP:0008551&hpo=HP:0000007")
assert response.status_code == 200
4 changes: 2 additions & 2 deletions utils/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ ENV PATH="/.venv/bin:$PATH"

COPY cada_prio cada_prio
COPY requirements requirements
COPY setup.py requirements.txt README.md CHANGELOG.md .
COPY setup.py requirements.txt README.md CHANGELOG.md ./
RUN pip install .

RUN useradd --create-home cada_prio
WORKDIR /home/cada_prio
USER cada_prio

CMD ["uvicorn", "cada_prio.server:app", "--host", "0.0.0.0", "--port", "8080"]
CMD ["uvicorn", "cada_prio.rest_server:app", "--host", "0.0.0.0", "--port", "8080"]
EXPOSE 8080

0 comments on commit 8bb7516

Please sign in to comment.