Add more strict typing #1253

Merged · 10 commits · Nov 4, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion src/bioregistry/__main__.py
@@ -2,7 +2,7 @@
 
 """Command line interface for the bioregistry."""
 
-from .cli import main
+from .cli import main  # type:ignore
 
 if __name__ == "__main__":
     main()
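
Note on the # type:ignore above: a bare ignore suppresses every type error on that line. If the import only trips one diagnostic, mypy's per-error-code form keeps other problems visible. A minimal sketch; the attr-defined code is an assumption about what mypy reports here, not something taken from this PR:

    # Hypothetical narrower suppression: only the named error code is silenced,
    # so any new, unrelated type error on this line still surfaces.
    from .cli import main  # type: ignore[attr-defined]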
24 changes: 17 additions & 7 deletions src/bioregistry/analysis/bioregistry_diff.py
@@ -1,7 +1,10 @@
 """Given two dates, analyzes and visualizes changes in the Bioregistry."""
 
-import json
+from __future__ import annotations
+
+import datetime
+import logging
+from typing import Any
 
 import click
 import matplotlib.pyplot as plt
@@ -21,7 +24,12 @@
 FILE_PATH = "src/bioregistry/data/bioregistry.json"
 
 
-def get_commit_before_date(date, owner=REPO_OWNER, name=REPO_NAME, branch=BRANCH):
+def get_commit_before_date(
+    date: datetime.date,
+    owner: str = REPO_OWNER,
+    name: str = REPO_NAME,
+    branch: str = BRANCH,
+) -> str | None:
     """Return the last commit before a given date.
 
     :param date: The date to get the commit before.
@@ -46,7 +54,9 @@ def get_commit_before_date(date, owner=REPO_OWNER, name=REPO_NAME, branch=BRANCH
     return None
 
 
-def get_file_at_commit(file_path, commit_sha, owner=REPO_OWNER, name=REPO_NAME):
+def get_file_at_commit(
+    file_path: str, commit_sha: str, owner: str = REPO_OWNER, name: str = REPO_NAME
+) -> dict[str, Any]:
     """Return the content of a given file at a specific commit.
 
     :param file_path: The file path in the repository.
@@ -61,9 +71,9 @@ def get_file_at_commit(file_path, commit_sha, owner=REPO_OWNER, name=REPO_NAME):
     response.raise_for_status()
     file_info = response.json()
     download_url = file_info["download_url"]
-    file_content_response = requests.get(download_url)
-    file_content_response.raise_for_status()
-    return json.loads(file_content_response.text)
+    res = requests.get(download_url)
+    res.raise_for_status()
+    return res.json()
 
 
 def compare_bioregistry(old_data, new_data):
@@ -246,7 +256,7 @@ def compare_dates(date1, date2):
     :param date1: The starting date in the format YYYY-MM-DD.
     :param date2: The ending date in the format YYYY-MM-DD.
     """
-    added, deleted, updated, update_details, old_data, new_data, all_mapping_keys = get_data(
+    added, deleted, updated, update_details, _old_data, _new_data, all_mapping_keys = get_data(
         date1, date2
     )
    if added is not None and updated is not None:
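
Reviewer note: the hunks above show only the new signature of get_commit_before_date, with the body collapsed. For orientation, here is a minimal sketch of how such a lookup can work against the GitHub REST API, whose commit-list endpoint accepts an ISO 8601 "until" bound. The helper name, unauthenticated call, and timeout are illustrative assumptions, not the module's actual code:

    from __future__ import annotations

    import datetime

    import requests


    def _last_commit_before(
        date: datetime.date, owner: str, name: str, branch: str
    ) -> str | None:
        # GitHub lists commits newest-first; "until" bounds the commit date.
        url = f"https://api.github.com/repos/{owner}/{name}/commits"
        params = {"sha": branch, "until": date.isoformat(), "per_page": 1}
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        commits = response.json()
        # The first entry is the latest commit on or before the bound, if any.
        return commits[0]["sha"] if commits else None

The res.json() change in get_file_at_commit above follows the same idea: let requests handle decoding instead of json.loads(file_content_response.text).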
139 changes: 76 additions & 63 deletions src/bioregistry/analysis/paper_ranking.py
@@ -1,38 +1,48 @@
 """Train a TF-IDF classifier and use it to score the relevance of new PubMed papers to the Bioregistry."""
 
+from __future__ import annotations
+
+import datetime
 import json
+from collections import defaultdict
 from pathlib import Path
 
 import click
 import indra.literature.pubmed_client as pubmed_client
+import numpy as np
 import pandas as pd
+from numpy.typing import NDArray
+from sklearn.base import ClassifierMixin
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import matthews_corrcoef, roc_auc_score
 from sklearn.model_selection import cross_val_predict, train_test_split
 from sklearn.svm import SVC, LinearSVC
 from sklearn.tree import DecisionTreeClassifier
-from tabulate import tabulate
 
-DIRECTORY = Path("exports/analyses/paper_ranking")
+HERE = Path(__file__).parent.resolve()
+ROOT = HERE.parent.parent.parent.resolve()
+
+BIOREGISTRY_PATH = ROOT.joinpath("src", "bioregistry", "data", "bioregistry.json")
+
+DIRECTORY = ROOT.joinpath("exports", "analyses", "paper_ranking")
 DIRECTORY.mkdir(exist_ok=True, parents=True)
 
 URL = "https://docs.google.com/spreadsheets/d/e/2PACX-1vRPtP-tcXSx8zvhCuX6fqz_\
 QvHowyAoDahnkixARk9rFTe0gfBN9GfdG6qTNQHHVL0i33XGSp_nV9XM/pub?output=csv"
 
 
-def load_bioregistry_json(file_path):
+def load_bioregistry_json(path: Path | None = None) -> pd.DataFrame:
     """Load bioregistry data from a JSON file, extracting publication details and fetching abstracts if missing.
 
-    :param file_path: Path to the bioregistry JSON file.
-    :type file_path: str
+    :param path: Path to the bioregistry JSON file.
     :return: DataFrame containing publication details.
-    :rtype: pd.DataFrame
     """
+    if path is None:
+        path = BIOREGISTRY_PATH
     try:
-        with open(file_path, "r") as f:
-            data = json.load(f)
+        data = json.loads(path.read_text())
     except json.JSONDecodeError as e:
         click.echo(f"JSONDecodeError: {e.msg}")
         click.echo(f"Error at line {e.lineno}, column {e.colno}")
@@ -58,67 +68,62 @@ def load_bioregistry_json(file_path)
         if pub["pubmed"] in fetched_metadata:
             pub["abstract"] = fetched_metadata[pub["pubmed"]].get("abstract", "")
 
-    click.echo(f"Got {len(publications)} publications from the bioregistry")
+    click.echo(f"Got {len(publications):,} publications from the bioregistry")
 
     return pd.DataFrame(publications)
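
The json.loads(path.read_text()) change keeps the new Path-typed argument idiomatic and drops the manual file handle. A quick self-contained illustration (the file name and payload are hypothetical):

    import json
    from pathlib import Path

    path = Path("example.json")  # hypothetical file for demonstration
    path.write_text(json.dumps({"prefix": "chebi"}))
    data = json.loads(path.read_text())  # read, decode, and close in one call
    assert data == {"prefix": "chebi"}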


-def fetch_pubmed_papers():
+def fetch_pubmed_papers() -> pd.DataFrame:
     """Fetch PubMed papers from the last 30 days using specific search terms.
 
     :return: DataFrame containing PubMed paper details.
-    :rtype: pd.DataFrame
     """
     click.echo("Starting fetch_pubmed_papers")
 
     search_terms = ["database", "ontology", "resource", "vocabulary", "nomenclature"]
-    paper_to_terms = {}
+    paper_to_terms: defaultdict[str, list[str]] = defaultdict(list)
 
     for term in search_terms:
-        pmids = pubmed_client.get_ids(term, use_text_word=True, reldate=30)
-        for pmid in pmids:
-            if pmid in paper_to_terms:
-                paper_to_terms[pmid].append(term)
-            else:
-                paper_to_terms[pmid] = [term]
+        pubmed_ids = pubmed_client.get_ids(term, use_text_word=True, reldate=30)
+        for pubmed_id in pubmed_ids:
+            paper_to_terms[pubmed_id].append(term)
 
     all_pmids = list(paper_to_terms.keys())
-    click.echo(f"{len(all_pmids)} PMIDs found")
+    click.echo(f"{len(all_pmids):,} articles found")
     if not all_pmids:
-        click.echo(f"No PMIDs found for the last 30 days with the search terms: {search_terms}")
+        click.echo(f"No articles found for the last 30 days with the search terms: {search_terms}")
         return pd.DataFrame()
 
     papers = {}
     for chunk in [all_pmids[i : i + 200] for i in range(0, len(all_pmids), 200)]:
         papers.update(pubmed_client.get_metadata_for_ids(chunk, get_abstracts=True))
 
     records = []
-    for pmid, paper in papers.items():
+    for pubmed_id, paper in papers.items():
         title = paper.get("title")
         abstract = paper.get("abstract", "")
 
         if title and abstract:
             records.append(
                 {
-                    "pubmed": pmid,
+                    "pubmed": pubmed_id,
                     "title": title,
                     "abstract": abstract,
                     "year": paper.get("publication_date", {}).get("year"),
-                    "search_terms": paper_to_terms.get(pmid),
+                    "search_terms": paper_to_terms.get(pubmed_id),
                 }
             )
 
-    click.echo(f"{len(records)} records fetched from PubMed")
+    click.echo(f"{len(records):,} records fetched from PubMed")
     return pd.DataFrame(records)
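
The paper_to_terms rewrite above is the standard defaultdict(list) idiom: the empty list is created on first access, so the explicit membership check disappears. A self-contained demonstration with made-up IDs:

    from collections import defaultdict

    paper_to_terms: defaultdict[str, list[str]] = defaultdict(list)
    pairs = [("database", "101"), ("ontology", "101"), ("resource", "202")]
    for term, pubmed_id in pairs:
        paper_to_terms[pubmed_id].append(term)  # no explicit membership branch

    assert paper_to_terms["101"] == ["database", "ontology"]
    assert paper_to_terms["202"] == ["resource"]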


-def load_curation_data():
+def load_curation_data() -> pd.DataFrame:
     """Download and load curation data from a Google Sheets URL.
 
     :return: DataFrame containing curated publication details.
-    :rtype: pd.DataFrame
     """
-    click.echo("Downloading curation")
+    click.echo("Downloading curation sheet")
     df = pd.read_csv(URL)
     df["label"] = df["relevant"].map(_map_labels)
     df = df[["pubmed", "title", "abstract", "label"]]
@@ -136,13 +141,11 @@ def load_curation_data():
     return df
 
 
-def _map_labels(s: str):
+def _map_labels(s: str) -> int | None:
     """Map labels to binary values.
 
     :param s: Label value.
-    :type s: str
     :return: Mapped binary label value.
-    :rtype: int
     """
     if s in {"1", "1.0", 1}:
         return 1
@@ -151,15 +154,15 @@ def _map_labels(s: str)
     return None
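
The new int | None return type documents that unrecognized labels come back as None, which pandas stores as NaN once load_curation_data applies the function with Series.map. A small demonstration with made-up values, using a plain dict in place of _map_labels to keep it self-contained:

    import pandas as pd

    df = pd.DataFrame({"relevant": ["1", "0", "maybe"]})
    df["label"] = df["relevant"].map({"1": 1, "0": 0})  # unmapped values become NaN
    print(df["label"].tolist())  # [1.0, 0.0, nan]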


-def train_classifiers(x_train, y_train):
+Classifiers = list[tuple[str, ClassifierMixin]]
+
+
+def train_classifiers(x_train: NDArray[np.float64], y_train: NDArray[np.str_]) -> Classifiers:
     """Train multiple classifiers on the training data.
 
     :param x_train: Training features.
-    :type x_train: array-like
     :param y_train: Training labels.
-    :type y_train: array-like
     :return: List of trained classifiers.
-    :rtype: list
     """
     classifiers = [
         ("rf", RandomForestClassifier()),
@@ -173,17 +176,15 @@ def train_classifiers(x_train, y_train)
     return classifiers


-def generate_meta_features(classifiers, x_train, y_train):
+def generate_meta_features(
+    classifiers: Classifiers, x_train: NDArray[np.float64], y_train: NDArray[np.str_]
+) -> pd.DataFrame:
     """Generate meta-features for training a meta-classifier using cross-validation predictions.
 
     :param classifiers: List of trained classifiers.
-    :type classifiers: list
     :param x_train: Training features.
-    :type x_train: array-like
     :param y_train: Training labels.
-    :type y_train: array-like
     :return: DataFrame containing meta-features.
-    :rtype: pd.DataFrame
     """
     meta_features = pd.DataFrame()
     for name, clf in classifiers:
@@ -197,42 +198,42 @@ def generate_meta_features(classifiers, x_train, y_train)
     return meta_features


-def evaluate_meta_classifier(meta_clf, x_test_meta, y_test):
+def evaluate_meta_classifier(
+    meta_clf: ClassifierMixin, x_test_meta: NDArray[np.float64], y_test: NDArray[np.str_]
+) -> tuple[float, float]:
     """Evaluate meta-classifier using MCC and AUC-ROC scores.
 
     :param meta_clf: Trained meta-classifier.
-    :type meta_clf: classifier
     :param x_test_meta: Test meta-features.
-    :type x_test_meta: array-like
     :param y_test: Test labels.
-    :type y_test: array-like
     :return: MCC and AUC-ROC scores.
-    :rtype: tuple
     """
     y_pred = meta_clf.predict(x_test_meta)
     mcc = matthews_corrcoef(y_test, y_pred)
     roc_auc = roc_auc_score(y_test, meta_clf.predict_proba(x_test_meta)[:, 1])
     return mcc, roc_auc
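
For readers skimming the new types: MCC condenses the whole confusion matrix into one score in [-1, 1], while AUC-ROC is computed from positive-class probabilities, which is why the function passes predict_proba(...)[:, 1] rather than hard labels. A toy computation with invented labels:

    from sklearn.metrics import matthews_corrcoef, roc_auc_score

    y_true = [1, 0, 1, 1, 0, 0]
    y_pred = [1, 0, 1, 0, 0, 1]  # hard labels for MCC
    y_prob = [0.9, 0.2, 0.8, 0.4, 0.3, 0.6]  # positive-class scores for AUC-ROC

    print(matthews_corrcoef(y_true, y_pred))  # ~0.33
    print(roc_auc_score(y_true, y_prob))  # ~0.89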


-def truncate_text(text, max_length):
+def truncate_text(text: str, max_length: int) -> str:
     """Truncate text to a specified maximum length."""
+    # FIXME replace with builtin textwrap function
     return text if len(text) <= max_length else text[:max_length] + "..."
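
On the FIXME: the closest builtin is textwrap.shorten, but it is not a drop-in replacement, since it collapses runs of whitespace and truncates at word boundaries so the placeholder fits within width, rather than cutting at an exact character count. A sketch of the difference:

    import textwrap

    text = "Train a TF-IDF classifier and use it to score new PubMed papers."
    print(textwrap.shorten(text, width=40, placeholder="..."))
    # -> 'Train a TF-IDF classifier and use it...'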


-def predict_and_save(df, vectorizer, classifiers, meta_clf, filename):
+def predict_and_save(
+    df: pd.DataFrame,
+    vectorizer: TfidfVectorizer,
+    classifiers: Classifiers,
+    meta_clf: ClassifierMixin,
+    filename: str | Path,
+) -> None:
     """Predict and save scores for new data using trained classifiers and meta-classifier.
 
     :param df: DataFrame containing new data.
-    :type df: pd.DataFrame
     :param vectorizer: Trained TF-IDF vectorizer.
-    :type vectorizer: TfidfVectorizer
     :param classifiers: List of trained classifiers.
-    :type classifiers: list
     :param meta_clf: Trained meta-classifier.
-    :type meta_clf: classifier
     :param filename: Filename to save the predictions.
-    :type filename: str
     """
     x_meta = pd.DataFrame()
     x_transformed = vectorizer.transform(df.title + " " + df.abstract)
@@ -249,23 +250,35 @@ def predict_and_save(df, vectorizer, classifiers, meta_clf, filename)
     click.echo(f"Wrote predicted scores to {DIRECTORY.joinpath(filename)}")
 
 
+def _first_of_month() -> str:
+    today = datetime.date.today()
+    return datetime.date(today.year, today.month, 1).isoformat()
+
+
 @click.command()
 @click.option(
     "--bioregistry-file",
-    default="src/bioregistry/data/bioregistry.json",
+    type=Path,
     help="Path to the bioregistry.json file",
 )
-@click.option("--start-date", required=True, help="Start date of the period")
-@click.option("--end-date", required=True, help="End date of the period")
-def main(bioregistry_file, start_date, end_date):
+@click.option(
+    "--start-date",
+    required=True,
+    help="Start date of the period",
+    default=_first_of_month,
+)
+@click.option(
+    "--end-date",
+    required=True,
+    help="End date of the period",
+    default=lambda: datetime.date.today().isoformat(),
+)
+def main(bioregistry_file: Path, start_date: str, end_date: str) -> None:
     """Load data, train classifiers, evaluate models, and predict new data.
 
     :param bioregistry_file: Path to the bioregistry JSON file.
-    :type bioregistry_file: str
     :param start_date: The start date of the period for which papers are being ranked.
-    :type start_date: str
     :param end_date: The end date of the period for which papers are being ranked.
-    :type end_date: str
     """
     publication_df = load_bioregistry_json(bioregistry_file)
     curation_df = load_curation_data()
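
Two details of the new options are easy to miss: _first_of_month is passed as the function itself, so click computes the default at each invocation rather than once at import, and click invokes callable defaults with no arguments (hence the zero-argument lambda for --end-date). A standalone sketch of the pattern, with an illustrative command and option name:

    import datetime

    import click


    @click.command()
    @click.option("--as-of", default=lambda: datetime.date.today().isoformat())
    def report(as_of: str) -> None:
        # The default is evaluated when the command runs, not at import time.
        click.echo(f"running as of {as_of}")


    if __name__ == "__main__":
        report()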
@@ -295,7 +308,7 @@ def main(bioregistry_file, start_date, end_date):
         try:
             mcc = matthews_corrcoef(y_test, y_pred)
         except ValueError as e:
-            click.secho(f"{clf} failed to calculate MCC: {e}", fg="yellow")
+            click.secho(f"{clf} failed to calculate MCC: {e:.2f}", fg="yellow")
             mcc = None
         try:
             if hasattr(clf, "predict_proba"):
@@ -310,7 +323,7 @@ def main(bioregistry_file, start_date, end_date):
         scores.append((name, mcc or float("nan"), roc_auc or float("nan")))
 
     evaluation_df = pd.DataFrame(scores, columns=["classifier", "mcc", "auc_roc"]).round(3)
-    click.echo(tabulate(evaluation_df, showindex=False, headers=evaluation_df.columns))
+    click.echo(evaluation_df.to_markdown(index=False))
 
     meta_features = generate_meta_features(classifiers, x_train, y_train)
     meta_clf = LogisticRegression()
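
Swapping tabulate(...) for DataFrame.to_markdown(index=False) removes the direct import, although pandas implements to_markdown by calling the tabulate package, so it remains an indirect dependency. A minimal demonstration with illustrative numbers:

    import pandas as pd

    evaluation_df = pd.DataFrame(
        {"classifier": ["rf", "lr"], "mcc": [0.61, 0.58], "auc_roc": [0.88, 0.86]}
    )
    print(evaluation_df.to_markdown(index=False))  # renders a Markdown pipe table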
@@ -323,8 +336,8 @@ def main(bioregistry_file, start_date, end_date):
         else:
             x_test_meta[name] = clf.decision_function(x_test)
 
-    mcc, roc_auc = evaluate_meta_classifier(meta_clf, x_test_meta, y_test)
-    click.echo(f"Meta-Classifier MCC: {mcc}, AUC-ROC: {roc_auc}")
+    mcc, roc_auc = evaluate_meta_classifier(meta_clf, x_test_meta.to_numpy(), y_test)
+    click.echo(f"Meta-Classifier MCC: {mcc:.2f}, AUC-ROC: {roc_auc:.2f}")
     new_row = {"classifier": "meta_classifier", "mcc": mcc, "auc_roc": roc_auc}
     evaluation_df = pd.concat([evaluation_df, pd.DataFrame([new_row])], ignore_index=True)
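
This block completes a classic stacking setup: out-of-fold predictions from the base classifiers (produced in generate_meta_features via cross_val_predict) become the features of a logistic-regression meta-classifier, and x_test_meta.to_numpy() aligns the test input with the NDArray-typed signature above. A compact sketch of the same pattern on synthetic data:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_predict
    from sklearn.svm import SVC

    x, y = make_classification(n_samples=200, random_state=0)
    base = [("rf", RandomForestClassifier(random_state=0)), ("svc", SVC(probability=True))]

    # Out-of-fold probabilities avoid leaking each base model's training fit
    # into the meta-classifier's inputs.
    meta_x = np.column_stack(
        [cross_val_predict(clf, x, y, cv=5, method="predict_proba")[:, 1] for _, clf in base]
    )
    meta_clf = LogisticRegression().fit(meta_x, y)
    print(f"meta-classifier training accuracy: {meta_clf.score(meta_x, y):.2f}")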

@@ -349,7 +362,7 @@ def main(bioregistry_file, start_date, end_date):
         .sort_values("rf_importance", ascending=False, key=abs)
         .round(4)
     )
-    click.echo(tabulate(importances_df.head(15), showindex=False, headers=importances_df.columns))
+    click.echo(importances_df.head(15).to_markdown(index=False))
 
     importance_path = DIRECTORY.joinpath("importances.tsv")
     click.echo(f"Writing feature (word) importances to {importance_path}")