Skip to content

Commit

Permalink
a few fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
melanibe committed Feb 16, 2024
1 parent 9fa8a5e commit 5668c05
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 820 deletions.
4 changes: 2 additions & 2 deletions configs/data/chexpert.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
defaults:
- base.yaml
- _self_
_target_: data_handling.padchest.CheXpertDataModule
dataset: padchest
_target_: data_handling.xray.CheXpertDataModule
dataset: chexpert
batch_size: 24
num_workers: 12
augmentations:
Expand Down
2 changes: 1 addition & 1 deletion configs/data/padchest.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- base.yaml
- _self_
_target_: data_handling.padchest.PadChestDataModule
_target_: data_handling.xray.PadChestDataModule
dataset: padchest
batch_size: 16
num_workers: 14
Expand Down
70 changes: 26 additions & 44 deletions evaluation/chexpert_pneumo.ipynb
Original file line number Diff line number Diff line change
@@ -1,40 +1,27 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CheXpert evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import seaborn as sns\n",
"\n",
"sys.path.append(\"/vol/biomedic3/mb121/causal-contrastive\")\n",
"\n",
"from sklearn.metrics import roc_auc_score, balanced_accuracy_score\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from classification.classification_module import ClassificationModule\n",
"\n",
"from hydra import compose, initialize\n",
"import pandas as pd\n",
"from data_handling.padchest import CheXpertDataModule\n",
"\n",
"import matplotlib\n",
"from evaluation.helper_functions import (\n",
" extract_train_label_prop,\n",
" extract_run_type,\n",
" extract_pretraining_type,\n",
" extract_finetuning_type,\n",
" run_inference,\n",
")\n",
"\n",
"sns.set_theme(context=\"paper\", style=\"whitegrid\", font_scale=1.5)\n",
"matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
"import os\n",
"\n",
"import pandas as pd\n",
"from hydra import compose, initialize\n",
"from sklearn.metrics import roc_auc_score\n",
"from data_handling.xray import CheXpertDataModule\n",
"from classification.classification_module import ClassificationModule\n",
"from evaluation.helper_functions import run_inference\n",
"os.chdir(\"/vol/biomedic3/mb121/causal-contrastive/evaluation\")"
]
},
Expand All @@ -44,7 +31,15 @@
"metadata": {},
"outputs": [],
"source": [
"model_dict_normal = {\n",
"# Mapping from human readable run name to Weights&Biases run_id. \n",
"\n",
"# Human readable name should be in format:\n",
"# for finetuning:\n",
"# {simclr/simclrcf/simclrcfaug}-{train_prop}-{seed}\n",
"# for linear probing\n",
"# {simclr/simclrcf/simclrcfaug}head-{train_prop}-{seed}\n",
"\n",
"model_dict_normal: dict[str, str] = {\n",
" \"simclr-1.0-33\": \"84lv0t6h\",\n",
" \"simclr-1.0-22\": \"yp8kgxkn\",\n",
" \"simclr-1.0-11\": \"wmbptk3z\",\n",
Expand Down Expand Up @@ -105,6 +100,9 @@
" \"simclrcfhead-0.25-11\": \"h92wu2up\",\n",
" \"simclrhead-0.1-33\": \"x7afc3ic\",\n",
" \"simclrhead-1.0-33\": \"gj75gyod\",\n",
" \"supervised-0.25-11\": 'af029hmt',\n",
" 'supervised-0.25-22': 'r5rzknzo',\n",
" 'supervised-0.25-33': 'yms2a9pj'\n",
"}"
]
},
Expand Down Expand Up @@ -150,11 +148,6 @@
" res = {}\n",
" res[\"N_test\"] = [inference_results[\"targets\"].shape[0]]\n",
" res[\"Scanner\"] = [\"CheXpert\"]\n",
" res[\"Bal_Acc\"] = [\n",
" balanced_accuracy_score(\n",
" inference_results[\"targets\"], np.argmax(inference_results[\"confs\"], 1)\n",
" )\n",
" ]\n",
" res[\"run_name\"] = run_name\n",
" res[\"ROC\"] = [\n",
" roc_auc_score(\n",
Expand All @@ -173,18 +166,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df[\"ctrain_label_prop\"] = df.run_name.apply(extract_train_label_prop)\n",
"df[\"group\"] = df.run_name.apply(extract_run_type)\n",
"df[\"Pretraining\"] = df.group.apply(lambda x: extract_pretraining_type(x))\n",
"df[\"Classifier\"] = df.group.apply(lambda x: extract_finetuning_type(x))\n",
"tmp = (\n",
" df.groupby([\"Pretraining\", \"Classifier\", \"ctrain_label_prop\"])\n",
" .run_name.unique()\n",
" .apply(lambda x: len(x))\n",
")\n",
"tmp.loc[tmp < 3]"
]
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 5668c05

Please sign in to comment.