Skip to content

Commit

Permalink
Show inference results in confusion matrix and closing words
Browse files Browse the repository at this point in the history
Demonstrate how the trained model can be used to produce classification results in a new column of the GeoDataFrame, and compare those predictions with the groundtruth in a confusion matrix plot. Added some closing words about data-centric and model-centric ways of improving the results, and added the Petty et al. 2021 paper to the citation list as credit to Alek's help.
  • Loading branch information
weiji14 committed Aug 19, 2024
1 parent 30fba15 commit afeeee2
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 37 deletions.
262 changes: 233 additions & 29 deletions book/tutorials/machine-learning/point_cloud_classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,6 @@
"- `item_collection` - Sentinel-2 optical satellite images"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b6ffe42",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "369c2a5c",
Expand Down Expand Up @@ -14169,7 +14161,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "3bfafafb",
"metadata": {},
"outputs": [],
Expand All @@ -14181,12 +14173,49 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"id": "f7ff6874",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 67%|██████▋ | 2/3 [00:00<00:00, 8.63it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 7949.509277 [ 8140/ 9454]\n",
"loss: 3678.343994 [ 8140/ 9454]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3/3 [00:00<00:00, 8.64it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 1735.765503 [ 8140/ 9454]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Main training loop\n",
"max_epochs: int = 3\n",
Expand Down Expand Up @@ -14217,17 +14246,15 @@
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" # Report metrics\n",
" current = (i + 1) * len(x)\n",
" print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")"
" # Report metrics\n",
" current = (i + 1) * len(x)\n",
" print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")"
]
},
{
"cell_type": "markdown",
"id": "66358b37",
"metadata": {
"lines_to_next_cell": 2
},
"metadata": {},
"source": [
"Did the model learn something? A good sign to check is if the loss value is\n",
"decreasing, which means the error between the predicted and groundtruth value is\n",
Expand All @@ -14236,25 +14263,202 @@
},
{
"cell_type": "markdown",
"id": "709f7a1b",
"metadata": {
"lines_to_next_cell": 2
},
"id": "b82c7138",
"metadata": {},
"source": [
"## References\n",
"- Koo, Y., Xie, H., Kurtz, N. T., Ackley, S. F., & Wang, W. (2023).\n",
" Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven\n",
" machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of\n",
" Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726"
"\n",
"### Inference results\n",
"\n",
"Besides monitoring the loss value, it is also good to calculate a metric like\n",
"Precision, Recall or F1-score. Let's first run the model in 'inference' mode to get\n",
"predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5014ed28",
"execution_count": 30,
"id": "fd38a5fa",
"metadata": {},
"outputs": [],
"source": []
"source": [
"gdf[\"predicted_surface_type\"] = None # create new column with NaN to store results\n",
"with torch.inference_mode():\n",
" for i, batch in enumerate(dataloader):\n",
" minibatch: torch.Tensor = batch[0]\n",
" x = minibatch[:, :6]\n",
" prediction = model(x=x) # one-hot encoded predictions\n",
" prediction_labels = torch.argmax(input=prediction, dim=1) # 0/1/2 labels\n",
"\n",
" start_index = i * dataloader.batch_size\n",
" stop_index = start_index + len(minibatch) - 1\n",
" gdf.loc[start_index:stop_index, \"predicted_surface_type\"] = prediction_labels"
]
},
{
"cell_type": "markdown",
"id": "fbfa6348",
"metadata": {},
"source": [
"\n",
"```{caution}\n",
"Ideally, you would want to run inference on a hold-out validation or test set, rather\n",
"than the points the model was trained on! See e.g.\n",
"[`sklearn.model_selection.train_test_split`](https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.train_test_split.html)\n",
"on how this can be done.\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "a578ef64",
"metadata": {},
"source": [
"\n",
"Now that we have the predicted results in the `predicted_surface_type` column, we can\n",
"compare it with the 'groundtruth' labels in the 'surface_type' column by visualizing\n",
"it in a confusion matrix."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "61bd172e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>Predicted</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>All</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Actual</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>725</td>\n",
" <td>725</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>577</td>\n",
" <td>578</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>8149</td>\n",
" <td>8151</td>\n",
" </tr>\n",
" <tr>\n",
" <th>All</th>\n",
" <td>3</td>\n",
" <td>9451</td>\n",
" <td>9454</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Predicted 1 2 All\n",
"Actual \n",
"0 0 725 725\n",
"1 1 577 578\n",
"2 2 8149 8151\n",
"All 3 9451 9454"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.crosstab(\n",
" index=gdf.surface_type,\n",
" columns=gdf.predicted_surface_type,\n",
" rownames=[\"Actual\"],\n",
" colnames=[\"Predicted\"],\n",
" margins=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "737847e3",
"metadata": {},
"source": [
"\n",
"```{attention}\n",
"Oo, it looks like our model isn't producing very good results! It's practically only\n",
"predicting thick sea ice (class: 2). There could be many reasons for this, and it's up\n",
"to you to figure out a solution, either by changing the data, or adjusting the model.\n",
"\n",
"Data-centric approaches:\n",
"- Add more data! Maybe <10000 points isn't enough, try getting more!\n",
"- Check the labels! Are there wrongly labelled points? Is the Sentinel-2\n",
" dark/gray/bright bins above too simplistic? Investigate!\n",
"- Normalize the data value range. The original paper by Koo et al., 2023 applied\n",
" min-max normalization on the 6 input columns, try and apply that too!\n",
"\n",
"Model-centric appraoches:\n",
"- Manage class imbalance. There are a lot more thick sea ice points than thin sea ice\n",
" or water points, could we modify the loss function to weigh rare classes higher?\n",
"- Adjust the model hyperparameters, try adjusting the learning rate, train the model\n",
" for more epochs, etc.\n",
"- Tweak the model architecture. The original paper by Koo et al., 2023 used a\n",
" [`tanh`](https://pytorch.org/docs/2.4/generated/torch.nn.Tanh.html) activation\n",
" function in the neural network layers. Will adding that help?\n",
"\n",
"The list above isn't exhaustive, and different machine learning practicioners may have\n",
"other suggestions on what to try next. That said, you now have a Machine Learning\n",
"ready GeoParquet dataset to iterate on ideas more quickly. Good luck!\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "709f7a1b",
"metadata": {},
"source": [
"## References\n",
"- Koo, Y., Xie, H., Kurtz, N. T., Ackley, S. F., & Wang, W. (2023).\n",
" Sea ice surface type classification of ICESat-2 ATL07 data by using data-driven\n",
" machine learning model: Ross Sea, Antarctic as an example. Remote Sensing of\n",
" Environment, 296, 113726. https://doi.org/10.1016/j.rse.2023.113726\n",
"- Petty, A. A., Bagnardi, M., Kurtz, N. T., Tilling, R., Fons, S., Armitage, T.,\n",
" Horvat, C., & Kwok, R. (2021). Assessment of ICESat‐2 Sea Ice Surface\n",
" Classification with Sentinel‐2 Imagery: Implications for Freeboard and New Estimates\n",
" of Lead and Floe Geometry. Earth and Space Science, 8(3), e2020EA001491.\n",
" https://doi.org/10.1029/2020EA001491"
]
}
],
"metadata": {
Expand Down
Loading

0 comments on commit afeeee2

Please sign in to comment.