From 4eda4112f134d3d4e75d6e710005ba0cefc80dc0 Mon Sep 17 00:00:00 2001 From: Joseph Xu Date: Mon, 16 Dec 2024 17:02:24 -0800 Subject: [PATCH] Add cell to evaluate fine-tuned model to assessment notebook. PiperOrigin-RevId: 706881343 --- src/colab/skai_assessment_notebook.ipynb | 460 ++++++++++++++++++----- src/colab/skai_assessment_notebook.py | 378 ++++++++++++++++--- src/colab/sync_notebook_source.py | 3 + 3 files changed, 709 insertions(+), 132 deletions(-) diff --git a/src/colab/skai_assessment_notebook.ipynb b/src/colab/skai_assessment_notebook.ipynb index 04ae3f4f..ddb8f5cb 100644 --- a/src/colab/skai_assessment_notebook.ipynb +++ b/src/colab/skai_assessment_notebook.ipynb @@ -3,7 +3,57 @@ { "cell_type": "code", "execution_count": null, - "id": "07ff36dd", + "id": "eae43051", + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "# @title Install Libraries\n", + "# @markdown This will take approximately 1 minute to run. After completing, you\n", + "# @markdown may be prompted to restart the kernel. Select \"Restart\" and then\n", + "# @markdown proceed to run the next cell.\n", + "\"\"\"Notebook for running SKAI assessments.\"\"\"\n", + "\n", + "# pylint: disable=g-statement-before-imports\n", + "SKAI_REPO = 'https://github.com/google-research/skai.git'\n", + "SKAI_CODE_DIR = '/content/skai_src'\n", + "\n", + "\n", + "def install_requirements():\n", + " \"\"\"Installs necessary Python libraries.\"\"\"\n", + " !rm -rf {SKAI_CODE_DIR}\n", + " !git clone {SKAI_REPO} {SKAI_CODE_DIR}\n", + " !pip install {SKAI_CODE_DIR}/src/.\n", + "\n", + " requirements = [\n", + " 'apache_beam[gcp]==2.54.0',\n", + " 'fiona',\n", + " # https://github.com/apache/beam/issues/32169\n", + " 'google-cloud-storage>=2.18.2',\n", + " 'ml-collections',\n", + " 'openlocationcode',\n", + " 'rasterio',\n", + " 'rio-cogeo',\n", + " 'rtree',\n", + " 'tensorflow==2.14.0',\n", + " 'tensorflow_addons',\n", + " 'tensorflow_text',\n", + " 'xmanager',\n", + " ]\n", + "\n", + " requirements_file = '/content/requirements.txt'\n", + " with open(requirements_file, 'w') as f:\n", + " f.write('\\n'.join(requirements))\n", + " !pip install -r {requirements_file}\n", + "\n", + "install_requirements()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04a304e2", "metadata": { "cellView": "form" }, @@ -18,7 +68,6 @@ "\n", "# @markdown You must re-run this cell every time you make a change.\n", "import os\n", - "import textwrap\n", "import ee\n", "from google.colab import auth\n", "\n", @@ -58,14 +107,15 @@ "AFTER_IMAGE_9 = '' # @param {type:\"string\"}\n", "\n", "# Constants\n", - "SKAI_REPO = 'https://github.com/google-research/skai.git'\n", "OPEN_BUILDINGS_FEATURE_COLLECTION = 'GOOGLE/Research/open-buildings/v3/polygons'\n", "OSM_OVERPASS_URL = 'https://lz4.overpass-api.de/api/interpreter'\n", "TRAIN_TFRECORD_NAME = 'labeled_examples_train.tfrecord'\n", "TEST_TFRECORD_NAME = 'labeled_examples_test.tfrecord'\n", + "HIGH_RECALL = 0.7\n", + "HIGH_PRECISION = 0.7\n", + "INFERENCE_BATCH_SIZE = 8\n", "\n", "# Derived variables\n", - "SKAI_CODE_DIR = '/content/skai_src'\n", "AOI_PATH = os.path.join(OUTPUT_DIR, 'aoi.geojson')\n", "BUILDINGS_FILE_LOG = os.path.join(OUTPUT_DIR, 'buildings_file_log.txt')\n", "EXAMPLE_GENERATION_CONFIG_PATH = os.path.join(\n", @@ -123,7 +173,7 @@ }, { "cell_type": "markdown", - "id": "c43c271a", + "id": "ac665ca1", "metadata": {}, "source": [ "#Initialization" @@ -132,49 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66b30da4", - "metadata": { - 
"cellView": "form", - "lines_to_next_cell": 1 - }, - "outputs": [], - "source": [ - "# @title Install Libraries\n", - "# @markdown This will take approximately 1 minute to run. After completing, you\n", - "# @markdown may be prompted to restart the kernel. Select \"Restart\" and then\n", - "# @markdown proceed to run the next cell.\n", - "def install_requirements():\n", - " \"\"\"Installs necessary Python libraries.\"\"\"\n", - " !rm -rf {SKAI_CODE_DIR}\n", - " !git clone {SKAI_REPO} {SKAI_CODE_DIR}\n", - " !pip install {SKAI_CODE_DIR}/src/.\n", - "\n", - " requirements = textwrap.dedent('''\n", - " apache_beam[gcp]==2.54.0\n", - " google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169\n", - " ml-collections\n", - " openlocationcode\n", - " rasterio\n", - " rio-cogeo\n", - " rtree\n", - " tensorflow==2.14.0\n", - " tensorflow_addons\n", - " tensorflow_text\n", - " xmanager\n", - " ''')\n", - "\n", - " requirements_file = '/content/requirements.txt'\n", - " with open(requirements_file, 'w') as f:\n", - " f.write(requirements)\n", - " !pip install -r {requirements_file}\n", - "\n", - "install_requirements()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c6ccd086", + "id": "8db9f358", "metadata": { "cellView": "form" }, @@ -192,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e1a73acf", + "id": "4e48e166", "metadata": { "cellView": "form" }, @@ -207,6 +215,7 @@ "import math\n", "import shutil\n", "import subprocess\n", + "import textwrap\n", "import time\n", "import warnings\n", "\n", @@ -220,10 +229,13 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", + "import seaborn as sns\n", "import shapely.wkt\n", "from skai import earth_engine as skai_ee\n", "from skai import labeling\n", "from skai import open_street_map\n", + "from skai.model import inference_lib\n", + "import sklearn.metrics\n", "import tensorflow as tf\n", "import tqdm.notebook\n", "\n", @@ -403,6 +415,26 @@ " return sorted(model_dirs, reverse=True)\n", "\n", "\n", + "def get_best_checkpoint(model_dir: str):\n", + " \"\"\"Finds the checkpoint subdirectory with the highest AUPRC.\n", + "\n", + " Args:\n", + " model_dir: Model directory.\n", + "\n", + " Returns:\n", + " Checkpoint directory path.\n", + " \"\"\"\n", + " checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*'))\n", + " best_checkpoint = None\n", + " best_aucpr = 0\n", + " for checkpoint in checkpoint_dirs:\n", + " aucpr = float(checkpoint.split('-')[-1])\n", + " if aucpr > best_aucpr:\n", + " best_checkpoint = checkpoint\n", + " best_aucpr = aucpr\n", + " return best_checkpoint\n", + "\n", + "\n", "def find_labeling_image_metadata_files(labeling_images_dir: str):\n", " return tf.io.gfile.glob(os.path.join(\n", " labeling_images_dir, '*', 'image_metadata.csv'))\n", @@ -442,7 +474,7 @@ }, { "cell_type": "markdown", - "id": "46e19485", + "id": "d6c3d749", "metadata": {}, "source": [ "# Check Assessment Status\n", @@ -454,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3ce47df", + "id": "3c560e8c", "metadata": { "cellView": "form" }, @@ -514,7 +546,7 @@ }, { "cell_type": "markdown", - "id": "4c1b8072", + "id": "dc765bfa", "metadata": {}, "source": [ "# Example Generation" @@ -523,7 +555,7 @@ { "cell_type": "code", "execution_count": null, - "id": "197a4696", + "id": "72d3d0c9", "metadata": { "cellView": "form" }, @@ -561,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba0ebc4e", + "id": "e835a795", 
"metadata": { "cellView": "form" }, @@ -689,7 +721,7 @@ { "cell_type": "code", "execution_count": null, - "id": "384190bb", + "id": "04e89c2a", "metadata": { "cellView": "form" }, @@ -743,7 +775,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ae7c376f", + "id": "78721d9c", "metadata": { "cellView": "form" }, @@ -777,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bb227715", + "id": "c3c9ce0a", "metadata": { "cellView": "form" }, @@ -803,7 +835,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97df7b72", + "id": "555f2a5f", "metadata": { "cellView": "form" }, @@ -838,7 +870,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c613f80f", + "id": "57f8ce7e", "metadata": { "cellView": "form" }, @@ -859,7 +891,7 @@ }, { "cell_type": "markdown", - "id": "ddb8d0a9", + "id": "3d73947f", "metadata": {}, "source": [ "# Labeling" @@ -868,7 +900,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca92a4eb", + "id": "cc44a3be", "metadata": { "cellView": "form" }, @@ -970,7 +1002,7 @@ }, { "cell_type": "markdown", - "id": "fe8d056c", + "id": "4b371fa1", "metadata": {}, "source": [ "When the labeling project is complete, download the CSV from the labeling tool\n", @@ -983,7 +1015,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ada766c8", + "id": "d706fbff", "metadata": { "cellView": "form" }, @@ -1018,7 +1050,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7d899c68", + "id": "25e51ffc", "metadata": { "cellView": "form" }, @@ -1082,7 +1114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b256bd32", + "id": "14bcc064", "metadata": { "cellView": "form" }, @@ -1188,7 +1220,7 @@ }, { "cell_type": "markdown", - "id": "9d72f91a", + "id": "933ba4fc", "metadata": {}, "source": [ "# Fine Tuning" @@ -1197,7 +1229,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3556419c", + "id": "c5442ad5", "metadata": { "cellView": "form" }, @@ -1289,9 +1321,10 @@ { "cell_type": "code", "execution_count": null, - "id": "abb8d601", + "id": "520e5852", "metadata": { - "cellView": "form" + "cellView": "form", + "lines_to_next_cell": 1 }, "outputs": [], "source": [ @@ -1337,28 +1370,284 @@ { "cell_type": "code", "execution_count": null, - "id": "7224a5e1", + "id": "94699857", + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "# @title Evaluate Fine-Tuned Model\n", + "\n", + "def plot_precision_recall(labels: np.ndarray, scores: np.ndarray) -> None:\n", + " \"\"\"Plots distinct precision and recall curves in a single graph.\n", + "\n", + " The X-axis of the graph is the threshold value. This graph shows the\n", + " trade-off between precision and recall for a specific threshold value more\n", + " clearly than the usual PR curve.\n", + "\n", + " Args:\n", + " labels: True labels array.\n", + " scores: Model scores array.\n", + " \"\"\"\n", + " sklearn.metrics.PrecisionRecallDisplay.from_predictions(labels, scores)\n", + " plt.title('Precision and Recall vs. Threshold')\n", + " plt.grid()\n", + " plt.show()\n", + "\n", + " precision, recall, thresholds = sklearn.metrics.precision_recall_curve(\n", + " labels, scores)\n", + " x = pd.DataFrame({\n", + " 'threshold': thresholds,\n", + " 'precision': precision[:-1],\n", + " 'recall': recall[:-1],\n", + " })\n", + " sns.lineplot(data=x.set_index('threshold'))\n", + " plt.title('Precision/Recall vs. 
Threshold')\n", + " plt.grid()\n", + " plt.show()\n", + "\n", + "\n", + "def get_recall_at_precision(\n", + " thresholds: np.ndarray,\n", + " precisions: np.ndarray,\n", + " recalls: np.ndarray,\n", + " min_precision: float) -> tuple[float, float, float]:\n", + " \"\"\"Finds threshold that maximizes recall with a minimum precision value.\n", + "\n", + " Args:\n", + " thresholds: List of threshold values returned by\n", + " sklearn.metrics.precision_recall_curve. Length N.\n", + " precisions: List of precision values returned by\n", + " sklearn.metrics.precision_recall_curve. Length N + 1.\n", + " recalls: List of recall values returned by\n", + " sklearn.metrics.precision_recall_curve. Length N + 1.\n", + " min_precision: Minimum precision value to maintain.\n", + "\n", + " Returns:\n", + " Tuple of (threshold, precision, recall).\n", + " \"\"\"\n", + " precisions = precisions[:-1]\n", + " recalls = recalls[:-1]\n", + " eligible = (precisions > min_precision)\n", + " if not any(eligible):\n", + " # If precision never exceeds the minimum value desired, return the threshold\n", + " # where it is highest.\n", + " eligible = (precisions == np.max(precisions))\n", + " i = np.argmax(recalls[eligible])\n", + " return thresholds[eligible][i], precisions[eligible][i], recalls[eligible][i]\n", + "\n", + "\n", + "def get_precision_at_recall(\n", + " thresholds: np.ndarray,\n", + " precisions: np.ndarray,\n", + " recalls: np.ndarray,\n", + " min_recall: float) -> tuple[float, float, float]:\n", + " \"\"\"Finds threshold that maximizes precision with a minimum recall value.\n", + "\n", + " Args:\n", + " thresholds: List of threshold values returned by\n", + " sklearn.metrics.precision_recall_curve. Length N.\n", + " precisions: List of precision values returned by\n", + " sklearn.metrics.precision_recall_curve. Length N + 1.\n", + " recalls: List of recall values returned by\n", + " sklearn.metrics.precision_recall_curve. 
Length N + 1.\n", + " min_recall: Minimum recall value to maintain.\n", + "\n", + " Returns:\n", + " Tuple of (threshold, precision, recall).\n", + " \"\"\"\n", + " precisions = precisions[:-1]\n", + " recalls = recalls[:-1]\n", + " eligible = (recalls > min_recall)\n", + " if not any(eligible):\n", + " # If recall never exceeds the minimum value desired, return the threshold\n", + " # where it is highest.\n", + " eligible = (recalls == np.max(recalls))\n", + " i = np.argmax(precisions[eligible])\n", + " return thresholds[eligible][i], precisions[eligible][i], recalls[eligible][i]\n", + "\n", + "\n", + "def get_max_f1_threshold(\n", + " scores: np.ndarray, labels: np.ndarray\n", + ") -> tuple[float, float, float, float]:\n", + " \"\"\"Finds the threshold that maximizes F1 score.\n", + "\n", + " Args:\n", + " scores: Prediction scores assigned by the model.\n", + " labels: True labels.\n", + "\n", + " Returns:\n", + " Tuple of best threshold and F1-score, Precision, Recall at that threshold.\n", + " \"\"\"\n", + " best_f1 = 0\n", + " best_threshold = 0\n", + " best_precision = 0\n", + " best_recall = 0\n", + " for threshold in scores:\n", + " predictions = (scores >= threshold)\n", + " if (f1 := sklearn.metrics.f1_score(labels, predictions)) > best_f1:\n", + " best_f1 = f1\n", + " best_threshold = threshold\n", + " best_precision = sklearn.metrics.precision_score(labels, predictions)\n", + " best_recall = sklearn.metrics.recall_score(labels, predictions)\n", + " return best_threshold, best_f1, best_precision, best_recall\n", + "\n", + "\n", + "def plot_score_distribution(labels: np.ndarray, scores: np.ndarray) -> None:\n", + " df = {'score': scores, 'label': labels}\n", + " sns.displot(data=df, x='score', col='label')\n", + " plt.show()\n", + "\n", + "\n", + "def print_model_metrics(scores: np.ndarray, labels: np.ndarray) -> None:\n", + " \"\"\"Prints evaluation metrics.\"\"\"\n", + " precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(\n", + " labels, scores\n", + " )\n", + " auprc = sklearn.metrics.auc(recalls, precisions)\n", + " auroc = sklearn.metrics.roc_auc_score(labels, scores)\n", + " print(f'AUPRC: {auprc:.4g}')\n", + " print(f'AUROC: {auroc:.4g}')\n", + "\n", + " threshold, f1, precision, recall = get_max_f1_threshold(scores, labels)\n", + " print('\\nFor maximum F1-score')\n", + " print(f' Threshold: {threshold}')\n", + " print(f' F1-score: {f1}')\n", + " print(f' Precision: {precision}')\n", + " print(f' Recall: {recall}')\n", + "\n", + " threshold, precision, recall = get_precision_at_recall(\n", + " thresholds, precisions, recalls, HIGH_RECALL\n", + " )\n", + " print(f'\\nFor recall >= {HIGH_RECALL}')\n", + " print(f' Threshold: {threshold}')\n", + " print(f' Precision: {precision}')\n", + " print(f' Recall: {recall}')\n", + "\n", + " threshold, precision, recall = get_recall_at_precision(\n", + " thresholds, precisions, recalls, HIGH_PRECISION\n", + " )\n", + " print(f'\\nFor precision >= {HIGH_PRECISION}')\n", + " print(f' Threshold: {threshold}')\n", + " print(f' Precision: {precision}')\n", + " print(f' Recall: {recall}')\n", + "\n", + " plot_precision_recall(labels, scores)\n", + " plot_score_distribution(labels, scores)\n", + "\n", + "\n", + "def _read_examples(path: str) -> list[tf.train.Example]:\n", + " examples = []\n", + " for record in tf.data.TFRecordDataset([path]):\n", + " example = tf.train.Example()\n", + " example.ParseFromString(record.numpy())\n", + " examples.append(example)\n", + " return examples\n", + "\n", + "\n", + "def 
_get_label(example: tf.train.Example) -> float:\n", + " return example.features.feature['label'].float_list.value[0]\n", + "\n", + "\n", + "def _evaluate_model(model_dir: str, examples_path: str) -> None:\n", + " \"\"\"Evaluates model on examples and prints metrics.\"\"\"\n", + "\n", + " print('Reading examples ...')\n", + " examples = _read_examples(examples_path)\n", + " print('Done reading examples')\n", + " if not examples:\n", + " raise ValueError('No examples')\n", + "\n", + " print('Loading model ...')\n", + " model = inference_lib.TF2InferenceModel(\n", + " model_dir,\n", + " 224,\n", + " False,\n", + " inference_lib.ModelType.CLASSIFICATION,\n", + " )\n", + " model.prepare_model()\n", + " print('Done loading model')\n", + "\n", + " print('Running inference ...')\n", + " scores = []\n", + " labels = []\n", + " for batch_start in tqdm.notebook.tqdm(\n", + " range(0, len(examples), INFERENCE_BATCH_SIZE)\n", + " ):\n", + " batch = examples[batch_start:batch_start+INFERENCE_BATCH_SIZE]\n", + " scores.extend(model.predict_scores(batch).numpy())\n", + " labels.extend(_get_label(e) for e in batch)\n", + " scores = np.array(scores)\n", + " labels = np.array(labels)\n", + " print_model_metrics(scores, labels)\n", + "\n", + "\n", + "def evaluate_model_on_test_examples():\n", + " \"\"\"Lets user evaluate a model on chosen trained model and test examples.\n", + " \"\"\"\n", + " labeled_example_dirs = find_labeled_examples_dirs()\n", + " examples_select = widgets.Dropdown(\n", + " options=labeled_example_dirs,\n", + " description='Choose a labeled examples dir:',\n", + " layout={'width': 'initial'},\n", + " )\n", + " examples_select.style.description_width = 'initial'\n", + "\n", + " model_dirs = find_model_dirs()\n", + " if not model_dirs:\n", + " print('No trained model directories found. 
Please train a model first.')\n", + " return\n", + "\n", + " model_select = widgets.Dropdown(\n", + " options=model_dirs,\n", + " description='Choose a model:',\n", + " layout={'width': 'initial'},\n", + " )\n", + " model_select.style.description_width = 'initial'\n", + " run_button = widgets.Button(description='Run')\n", + "\n", + " def run_button_clicked(_):\n", + " run_button.disabled = True\n", + " test_path = os.path.join(examples_select.value, TEST_TFRECORD_NAME)\n", + " model_dir = os.path.join(model_select.value, 'model')\n", + " checkpoint = get_best_checkpoint(model_dir)\n", + " if not checkpoint:\n", + " print('Model directory does not contain a valid checkpoint directory.')\n", + " return\n", + " _evaluate_model(checkpoint, test_path)\n", + "\n", + " run_button.on_click(run_button_clicked)\n", + "\n", + " display(model_select)\n", + " display(examples_select)\n", + " display(run_button)\n", + "\n", + "\n", + "evaluate_model_on_test_examples()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "905a140b", "metadata": { "cellView": "form" }, "outputs": [], "source": [ "# @title Run inference\n", - "def get_best_checkpoint(model_dir: str):\n", - " checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*'))\n", - " best_checkpoint = None\n", - " best_aucpr = 0\n", - " for checkpoint in checkpoint_dirs:\n", - " aucpr = float(checkpoint.split('-')[-1])\n", - " if aucpr > best_aucpr:\n", - " best_checkpoint = checkpoint\n", - " best_aucpr = aucpr\n", - " return best_checkpoint\n", + "# @markdown These should be changed to the thresholds chosen in the eval cell.\n", + "DEFAULT_THRESHOLD = 0.5 # @param {\"type\":\"number\"}\n", + "HIGH_PRECISION_THRESHOLD = 0.6 # @param {\"type\":\"number\"}\n", + "HIGH_RECALL_THRESHOLD = 0.4 # @param {\"type\":\"number\"}\n", "\n", "\n", "def run_inference(\n", " examples_pattern: str,\n", " model_dir: str,\n", + " default_threshold: float,\n", + " high_precision_threshold: float,\n", + " high_recall_threshold: float,\n", " output_dir: str,\n", " output_path: str,\n", " cloud_project: str,\n", @@ -1398,9 +1687,9 @@ " --cloud_region='{cloud_region}' \\\n", " --dataflow_temp_dir='{temp_dir}' \\\n", " --worker_service_account='{service_account}' \\\n", - " --threshold=0.5 \\\n", - " --high_precision_threshold=0.75 \\\n", - " --high_recall_threshold=0.4 \\\n", + " --threshold={default_threshold} \\\n", + " --high_precision_threshold={high_precision_threshold} \\\n", + " --high_recall_threshold={high_recall_threshold} \\\n", " --max_dataflow_workers=4 {accelerator_flags}\n", " ''')\n", "\n", @@ -1435,6 +1724,9 @@ " run_inference(\n", " UNLABELED_TFRECORD_PATTERN,\n", " checkpoint,\n", + " DEFAULT_THRESHOLD,\n", + " HIGH_PRECISION_THRESHOLD,\n", + " HIGH_RECALL_THRESHOLD,\n", " OUTPUT_DIR,\n", " INFERENCE_CSV,\n", " GCP_PROJECT,\n", @@ -1454,7 +1746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6c28b8fd", + "id": "6f1c7b44", "metadata": { "cellView": "form" }, diff --git a/src/colab/skai_assessment_notebook.py b/src/colab/skai_assessment_notebook.py index 30a49ac0..9be61984 100644 --- a/src/colab/skai_assessment_notebook.py +++ b/src/colab/skai_assessment_notebook.py @@ -12,6 +12,47 @@ # name: python3 # --- +# %% cellView="form" +# @title Install Libraries +# @markdown This will take approximately 1 minute to run. After completing, you +# @markdown may be prompted to restart the kernel. Select "Restart" and then +# @markdown proceed to run the next cell. 
+"""Notebook for running SKAI assessments.""" + +# pylint: disable=g-statement-before-imports +SKAI_REPO = 'https://github.com/google-research/skai.git' +SKAI_CODE_DIR = '/content/skai_src' + + +def install_requirements(): + """Installs necessary Python libraries.""" + # !rm -rf {SKAI_CODE_DIR} + # !git clone {SKAI_REPO} {SKAI_CODE_DIR} + # !pip install {SKAI_CODE_DIR}/src/. + + requirements = [ + 'apache_beam[gcp]==2.54.0', + 'fiona', + # https://github.com/apache/beam/issues/32169 + 'google-cloud-storage>=2.18.2', + 'ml-collections', + 'openlocationcode', + 'rasterio', + 'rio-cogeo', + 'rtree', + 'tensorflow==2.14.0', + 'tensorflow_addons', + 'tensorflow_text', + 'xmanager', + ] + + requirements_file = '/content/requirements.txt' + with open(requirements_file, 'w') as f: + f.write('\n'.join(requirements)) + # !pip install -r {requirements_file} + +install_requirements() + # %% cellView="form" # @title Configure Assessment Parameters @@ -22,7 +63,6 @@ # @markdown You must re-run this cell every time you make a change. import os -import textwrap import ee from google.colab import auth @@ -62,14 +102,15 @@ AFTER_IMAGE_9 = '' # @param {type:"string"} # Constants -SKAI_REPO = 'https://github.com/google-research/skai.git' OPEN_BUILDINGS_FEATURE_COLLECTION = 'GOOGLE/Research/open-buildings/v3/polygons' OSM_OVERPASS_URL = 'https://lz4.overpass-api.de/api/interpreter' TRAIN_TFRECORD_NAME = 'labeled_examples_train.tfrecord' TEST_TFRECORD_NAME = 'labeled_examples_test.tfrecord' +HIGH_RECALL = 0.7 +HIGH_PRECISION = 0.7 +INFERENCE_BATCH_SIZE = 8 # Derived variables -SKAI_CODE_DIR = '/content/skai_src' AOI_PATH = os.path.join(OUTPUT_DIR, 'aoi.geojson') BUILDINGS_FILE_LOG = os.path.join(OUTPUT_DIR, 'buildings_file_log.txt') EXAMPLE_GENERATION_CONFIG_PATH = os.path.join( @@ -128,38 +169,6 @@ def process_image_entries(entries: list[str]) -> list[str]: # %% [markdown] # #Initialization -# %% cellView="form" -# @title Install Libraries -# @markdown This will take approximately 1 minute to run. After completing, you -# @markdown may be prompted to restart the kernel. Select "Restart" and then -# @markdown proceed to run the next cell. -def install_requirements(): - """Installs necessary Python libraries.""" - # !rm -rf {SKAI_CODE_DIR} - # !git clone {SKAI_REPO} {SKAI_CODE_DIR} - # !pip install {SKAI_CODE_DIR}/src/. 
- - requirements = textwrap.dedent(''' - apache_beam[gcp]==2.54.0 - google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169 - ml-collections - openlocationcode - rasterio - rio-cogeo - rtree - tensorflow==2.14.0 - tensorflow_addons - tensorflow_text - xmanager - ''') - - requirements_file = '/content/requirements.txt' - with open(requirements_file, 'w') as f: - f.write(requirements) - # !pip install -r {requirements_file} - -install_requirements() - # %% cellView="form" # @title Authenticate with Google Cloud def authenticate(): @@ -179,6 +188,7 @@ def authenticate(): import math import shutil import subprocess +import textwrap import time import warnings @@ -192,10 +202,13 @@ def authenticate(): import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns import shapely.wkt from skai import earth_engine as skai_ee from skai import labeling from skai import open_street_map +from skai.model import inference_lib +import sklearn.metrics import tensorflow as tf import tqdm.notebook @@ -375,6 +388,26 @@ def find_model_dirs(): return sorted(model_dirs, reverse=True) +def get_best_checkpoint(model_dir: str): + """Finds the checkpoint subdirectory with the highest AUPRC. + + Args: + model_dir: Model directory. + + Returns: + Checkpoint directory path. + """ + checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*')) + best_checkpoint = None + best_aucpr = 0 + for checkpoint in checkpoint_dirs: + aucpr = float(checkpoint.split('-')[-1]) + if aucpr > best_aucpr: + best_checkpoint = checkpoint + best_aucpr = aucpr + return best_checkpoint + + def find_labeling_image_metadata_files(labeling_images_dir: str): return tf.io.gfile.glob(os.path.join( labeling_images_dir, '*', 'image_metadata.csv')) @@ -1163,24 +1196,270 @@ def run_tensorboard(_): start_tensorboard() +# %% cellView="form" +# @title Evaluate Fine-Tuned Model + +def plot_precision_recall(labels: np.ndarray, scores: np.ndarray) -> None: + """Plots distinct precision and recall curves in a single graph. + + The X-axis of the graph is the threshold value. This graph shows the + trade-off between precision and recall for a specific threshold value more + clearly than the usual PR curve. + + Args: + labels: True labels array. + scores: Model scores array. + """ + sklearn.metrics.PrecisionRecallDisplay.from_predictions(labels, scores) + plt.title('Precision and Recall vs. Threshold') + plt.grid() + plt.show() + + precision, recall, thresholds = sklearn.metrics.precision_recall_curve( + labels, scores) + x = pd.DataFrame({ + 'threshold': thresholds, + 'precision': precision[:-1], + 'recall': recall[:-1], + }) + sns.lineplot(data=x.set_index('threshold')) + plt.title('Precision/Recall vs. Threshold') + plt.grid() + plt.show() + + +def get_recall_at_precision( + thresholds: np.ndarray, + precisions: np.ndarray, + recalls: np.ndarray, + min_precision: float) -> tuple[float, float, float]: + """Finds threshold that maximizes recall with a minimum precision value. + + Args: + thresholds: List of threshold values returned by + sklearn.metrics.precision_recall_curve. Length N. + precisions: List of precision values returned by + sklearn.metrics.precision_recall_curve. Length N + 1. + recalls: List of recall values returned by + sklearn.metrics.precision_recall_curve. Length N + 1. + min_precision: Minimum precision value to maintain. + + Returns: + Tuple of (threshold, precision, recall). 
+ """ + precisions = precisions[:-1] + recalls = recalls[:-1] + eligible = (precisions > min_precision) + if not any(eligible): + # If precision never exceeds the minimum value desired, return the threshold + # where it is highest. + eligible = (precisions == np.max(precisions)) + i = np.argmax(recalls[eligible]) + return thresholds[eligible][i], precisions[eligible][i], recalls[eligible][i] + + +def get_precision_at_recall( + thresholds: np.ndarray, + precisions: np.ndarray, + recalls: np.ndarray, + min_recall: float) -> tuple[float, float, float]: + """Finds threshold that maximizes precision with a minimum recall value. + + Args: + thresholds: List of threshold values returned by + sklearn.metrics.precision_recall_curve. Length N. + precisions: List of precision values returned by + sklearn.metrics.precision_recall_curve. Length N + 1. + recalls: List of recall values returned by + sklearn.metrics.precision_recall_curve. Length N + 1. + min_recall: Minimum recall value to maintain. + + Returns: + Tuple of (threshold, precision, recall). + """ + precisions = precisions[:-1] + recalls = recalls[:-1] + eligible = (recalls > min_recall) + if not any(eligible): + # If recall never exceeds the minimum value desired, return the threshold + # where it is highest. + eligible = (recalls == np.max(recalls)) + i = np.argmax(precisions[eligible]) + return thresholds[eligible][i], precisions[eligible][i], recalls[eligible][i] + + +def get_max_f1_threshold( + scores: np.ndarray, labels: np.ndarray +) -> tuple[float, float, float, float]: + """Finds the threshold that maximizes F1 score. + + Args: + scores: Prediction scores assigned by the model. + labels: True labels. + + Returns: + Tuple of best threshold and F1-score, Precision, Recall at that threshold. + """ + best_f1 = 0 + best_threshold = 0 + best_precision = 0 + best_recall = 0 + for threshold in scores: + predictions = (scores >= threshold) + if (f1 := sklearn.metrics.f1_score(labels, predictions)) > best_f1: + best_f1 = f1 + best_threshold = threshold + best_precision = sklearn.metrics.precision_score(labels, predictions) + best_recall = sklearn.metrics.recall_score(labels, predictions) + return best_threshold, best_f1, best_precision, best_recall + + +def plot_score_distribution(labels: np.ndarray, scores: np.ndarray) -> None: + df = {'score': scores, 'label': labels} + sns.displot(data=df, x='score', col='label') + plt.show() + + +def print_model_metrics(scores: np.ndarray, labels: np.ndarray) -> None: + """Prints evaluation metrics.""" + precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve( + labels, scores + ) + auprc = sklearn.metrics.auc(recalls, precisions) + auroc = sklearn.metrics.roc_auc_score(labels, scores) + print(f'AUPRC: {auprc:.4g}') + print(f'AUROC: {auroc:.4g}') + + threshold, f1, precision, recall = get_max_f1_threshold(scores, labels) + print('\nFor maximum F1-score') + print(f' Threshold: {threshold}') + print(f' F1-score: {f1}') + print(f' Precision: {precision}') + print(f' Recall: {recall}') + + threshold, precision, recall = get_precision_at_recall( + thresholds, precisions, recalls, HIGH_RECALL + ) + print(f'\nFor recall >= {HIGH_RECALL}') + print(f' Threshold: {threshold}') + print(f' Precision: {precision}') + print(f' Recall: {recall}') + + threshold, precision, recall = get_recall_at_precision( + thresholds, precisions, recalls, HIGH_PRECISION + ) + print(f'\nFor precision >= {HIGH_PRECISION}') + print(f' Threshold: {threshold}') + print(f' Precision: {precision}') + print(f' Recall: 
{recall}') + + plot_precision_recall(labels, scores) + plot_score_distribution(labels, scores) + + +def _read_examples(path: str) -> list[tf.train.Example]: + examples = [] + for record in tf.data.TFRecordDataset([path]): + example = tf.train.Example() + example.ParseFromString(record.numpy()) + examples.append(example) + return examples + + +def _get_label(example: tf.train.Example) -> float: + return example.features.feature['label'].float_list.value[0] + + +def _evaluate_model(model_dir: str, examples_path: str) -> None: + """Evaluates model on examples and prints metrics.""" + + print('Reading examples ...') + examples = _read_examples(examples_path) + print('Done reading examples') + if not examples: + raise ValueError('No examples') + + print('Loading model ...') + model = inference_lib.TF2InferenceModel( + model_dir, + 224, + False, + inference_lib.ModelType.CLASSIFICATION, + ) + model.prepare_model() + print('Done loading model') + + print('Running inference ...') + scores = [] + labels = [] + for batch_start in tqdm.notebook.tqdm( + range(0, len(examples), INFERENCE_BATCH_SIZE) + ): + batch = examples[batch_start:batch_start+INFERENCE_BATCH_SIZE] + scores.extend(model.predict_scores(batch).numpy()) + labels.extend(_get_label(e) for e in batch) + scores = np.array(scores) + labels = np.array(labels) + print_model_metrics(scores, labels) + + +def evaluate_model_on_test_examples(): + """Lets user evaluate a model on chosen trained model and test examples. + """ + labeled_example_dirs = find_labeled_examples_dirs() + examples_select = widgets.Dropdown( + options=labeled_example_dirs, + description='Choose a labeled examples dir:', + layout={'width': 'initial'}, + ) + examples_select.style.description_width = 'initial' + + model_dirs = find_model_dirs() + if not model_dirs: + print('No trained model directories found. Please train a model first.') + return + + model_select = widgets.Dropdown( + options=model_dirs, + description='Choose a model:', + layout={'width': 'initial'}, + ) + model_select.style.description_width = 'initial' + run_button = widgets.Button(description='Run') + + def run_button_clicked(_): + run_button.disabled = True + test_path = os.path.join(examples_select.value, TEST_TFRECORD_NAME) + model_dir = os.path.join(model_select.value, 'model') + checkpoint = get_best_checkpoint(model_dir) + if not checkpoint: + print('Model directory does not contain a valid checkpoint directory.') + return + _evaluate_model(checkpoint, test_path) + + run_button.on_click(run_button_clicked) + + display(model_select) + display(examples_select) + display(run_button) + + +evaluate_model_on_test_examples() # %% cellView="form" # @title Run inference -def get_best_checkpoint(model_dir: str): - checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*')) - best_checkpoint = None - best_aucpr = 0 - for checkpoint in checkpoint_dirs: - aucpr = float(checkpoint.split('-')[-1]) - if aucpr > best_aucpr: - best_checkpoint = checkpoint - best_aucpr = aucpr - return best_checkpoint +# @markdown These should be changed to the thresholds chosen in the eval cell. 
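+# @markdown Typically: DEFAULT_THRESHOLD comes from the "For maximum F1-score"
+# @markdown block of the evaluation output, HIGH_PRECISION_THRESHOLD from its
+# @markdown "For precision >= 0.7" block, and HIGH_RECALL_THRESHOLD from its
+# @markdown "For recall >= 0.7" block.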
+DEFAULT_THRESHOLD = 0.5 # @param {"type":"number"} +HIGH_PRECISION_THRESHOLD = 0.6 # @param {"type":"number"} +HIGH_RECALL_THRESHOLD = 0.4 # @param {"type":"number"} def run_inference( examples_pattern: str, model_dir: str, + default_threshold: float, + high_precision_threshold: float, + high_recall_threshold: float, output_dir: str, output_path: str, cloud_project: str, @@ -1220,9 +1499,9 @@ def run_inference( --cloud_region='{cloud_region}' \ --dataflow_temp_dir='{temp_dir}' \ --worker_service_account='{service_account}' \ - --threshold=0.5 \ - --high_precision_threshold=0.75 \ - --high_recall_threshold=0.4 \ + --threshold={default_threshold} \ + --high_precision_threshold={high_precision_threshold} \ + --high_recall_threshold={high_recall_threshold} \ --max_dataflow_workers=4 {accelerator_flags} ''') @@ -1257,6 +1536,9 @@ def start_clicked(_): run_inference( UNLABELED_TFRECORD_PATTERN, checkpoint, + DEFAULT_THRESHOLD, + HIGH_PRECISION_THRESHOLD, + HIGH_RECALL_THRESHOLD, OUTPUT_DIR, INFERENCE_CSV, GCP_PROJECT, diff --git a/src/colab/sync_notebook_source.py b/src/colab/sync_notebook_source.py index 497924c8..15b6bfde 100644 --- a/src/colab/sync_notebook_source.py +++ b/src/colab/sync_notebook_source.py @@ -60,6 +60,9 @@ 'AFTER_IMAGE_8': '', 'AFTER_IMAGE_9': '', 'DAMAGE_SCORE_THRESHOLD': 0.5, + 'DEFAULT_THRESHOLD': 0.5, + 'HIGH_PRECISION_THRESHOLD': 0.6, + 'HIGH_RECALL_THRESHOLD': 0.4, }
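
---

Notes on choosing the inference thresholds:

The three values entered in the "Run inference" form are meant to come from
the numbers printed by the new "Evaluate Fine-Tuned Model" cell. Below is a
minimal, self-contained sketch of that selection logic, using synthetic
labels/scores in place of the arrays that _evaluate_model() builds from the
test TFRecord; the 0.7 floors mirror the notebook's HIGH_RECALL and
HIGH_PRECISION constants, and the empty-mask fallback in the notebook
helpers is omitted for brevity.

    import numpy as np
    import sklearn.metrics

    HIGH_RECALL = 0.7     # mirrors the notebook constant
    HIGH_PRECISION = 0.7  # mirrors the notebook constant

    # Synthetic stand-ins for the (labels, scores) arrays that
    # _evaluate_model() produces; real runs use the model's predictions.
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=1000)
    scores = 0.3 * labels + 0.7 * rng.random(1000)

    precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(
        labels, scores)
    # Drop the final (precision=1, recall=0) point so both arrays line up
    # with the N thresholds, as the notebook helpers do.
    p, r = precisions[:-1], recalls[:-1]

    # HIGH_RECALL_THRESHOLD: maximize precision subject to a recall floor
    # (the rule implemented by get_precision_at_recall).
    eligible = r > HIGH_RECALL
    high_recall_threshold = thresholds[eligible][np.argmax(p[eligible])]

    # HIGH_PRECISION_THRESHOLD: maximize recall subject to a precision floor
    # (the rule implemented by get_recall_at_precision).
    eligible = p > HIGH_PRECISION
    high_precision_threshold = thresholds[eligible][np.argmax(r[eligible])]

    print(f'HIGH_RECALL_THRESHOLD    ~ {high_recall_threshold:.3f}')
    print(f'HIGH_PRECISION_THRESHOLD ~ {high_precision_threshold:.3f}')

DEFAULT_THRESHOLD is most naturally taken from the "For maximum F1-score"
block of the same output, since that is the single threshold that balances
precision and recall.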
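Separately, the new get_best_checkpoint() helper assumes checkpoint
subdirectories are named epoch-<n>-aucpr-<value> and compares only the field
after the last '-'. A quick illustration of the same selection rule, with
hypothetical paths:

    # Hypothetical checkpoint directories matching the 'epoch-*-aucpr-*'
    # glob that get_best_checkpoint() scans; the trailing field is AUPRC.
    checkpoint_dirs = [
        'gs://example-bucket/model/epoch-01-aucpr-0.62',
        'gs://example-bucket/model/epoch-02-aucpr-0.71',
        'gs://example-bucket/model/epoch-03-aucpr-0.68',
    ]
    best = max(checkpoint_dirs, key=lambda d: float(d.split('-')[-1]))
    print(best)  # gs://example-bucket/model/epoch-02-aucpr-0.71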