Skip to content

Commit

Permalink
Update tutorials with base_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Johansmm committed Oct 19, 2021
1 parent 5c4355a commit f38c6b3
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 33 deletions.
91 changes: 62 additions & 29 deletions docsource/source/tutorials/finetuning_deformable_detr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"cells": [
{
"cell_type": "markdown",
"id": "8390a00b",
"metadata": {},
"source": [
"# Finetuning Deformanble DETR\n",
"\n",
Expand All @@ -27,11 +29,12 @@
"[Deformable DetrR50 architecture]: https://arxiv.org/abs/2010.04159\n",
"[Mask Wearing Dataset]: https://public.roboflow.com/object-detection/mask-wearing\n",
"[Deformable DETR Finetune]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "9fab10df",
"metadata": {},
"source": [
"## 1. Train Deformable DETR50 Finetune\n",
"\n",
Expand Down Expand Up @@ -68,22 +71,25 @@
"[Deformable DETR R50 Finetune]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune\n",
"[Deformable DETR R50 Finetune with refinement]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune-with-refinement\n",
"[Models]: ../alonet/deformable_models.rst"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "df3a4d07",
"metadata": {},
"source": [
"For training purposes, it is usual in [Aloception] to define a model on [Pytorch lightning module]. With a finetune model, the architecture definition changes, but the training process remains static:\n",
"\n",
"[Pytorch lightning module]: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html\n",
"[Aloception]: ../index.rst"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e453f54e",
"metadata": {},
"outputs": [],
"source": [
"from argparse import Namespace, ArgumentParser\n",
"import alonet\n",
Expand All @@ -109,7 +115,7 @@
"# Architecture definition\n",
"deformabe_finetune = DeformableDetrR50RefinementFinetune(\n",
" num_classes = num_classes, \n",
" weights = \"deformable-detr-r50-refinement\",\n",
" base_weights = \"deformable-detr-r50-refinement\", # Load by default\n",
" activation_fn = \"softmax\"\n",
")\n",
"lit_deformable = LitDeformableDetr(model = deformabe_finetune)\n",
Expand All @@ -123,12 +129,12 @@
" project = \"deformable_detr\",\n",
" expe_name = \"people_mask\"\n",
")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "54ae4cc3",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"\n",
Expand Down Expand Up @@ -162,21 +168,24 @@
"[How to setup your data]: ./data_setup.rst\n",
"[Train a Deformanble model]: training_deformable_detr.ipynb\n",
"[Train a DetrR50 finetune model]: finetuning_detr.ipynb"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "981b6f35",
"metadata": {},
"source": [
"## 2. Make inferences\n",
"\n",
"In order to make some inferences on the dataset using the trained model, we need to load the weights. For that, we can use one function in [Alonet](../alonet/alonet.rst) for this purpose. Also, we need to keep in mind **the project and run id that we used in training process**:"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9cd843f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from alonet.common import load_training \n",
Expand All @@ -196,20 +205,43 @@
" args = args, \n",
" model = detr_finetune)\n",
"lit_deformable.model.to(device)"
],
]
},
{
"cell_type": "markdown",
"id": "044fc41a",
"metadata": {},
"source": [
"Or in another way"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66593b0b",
"metadata": {},
"outputs": [],
"metadata": {}
"source": [
"detr_finetune = DeformableDetrR50RefinementFinetune(\n",
" num_classes = num_classes, \n",
" weights = \"/path/trained/weights.ckpt\" or \"/path/trained/weights.pth\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1a58dfd4",
"metadata": {},
"source": [
"This enables to use the valid dataset and show some results:"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4125230",
"metadata": {},
"outputs": [],
"source": [
"frames = next(iter(coco_loader.val_dataloader()))\n",
"frames = frames[0].batch_list(frames).to(device)\n",
Expand All @@ -220,12 +252,12 @@
" gt_boxes.get_view(frames[0], title=\"Ground truth boxes\"),\n",
" pred_boxes.get_view(frames[0], title=\"Predicted boxes\"),\n",
"]).render()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "674eb55f",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"\n",
Expand All @@ -236,21 +268,24 @@
"</div>\n",
"\n",
"[DETR/Deformable DETR models to tensorRT]: tensort_inference.rst"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "13e998c8",
"metadata": {},
"source": [
"## 3. Optional: Make prediction in camera\n",
"\n",
"If there is access to a local camera, the following code would allow you to take snapshots with the camera and make predictions at the same time:"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83eb5696",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import cv2\n",
Expand Down Expand Up @@ -283,9 +318,7 @@
" cv2.destroyAllWindows()\n",
"else:\n",
" print(\"[ERROR] Impossible to open camera\")"
],
"outputs": [],
"metadata": {}
]
}
],
"metadata": {
Expand All @@ -309,4 +342,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
58 changes: 54 additions & 4 deletions docsource/source/tutorials/finetuning_detr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"id": "fcd240d4",
"metadata": {},
"source": [
"Its statement is the same as [Detr R50 Finetune], with difference that now `num_classes` **attribute is mandatory**:\n",
"Its statement is the same as [Detr R50 Finetune], with the difference that now `num_classes` **attribute is mandatory**:\n",
"\n",
"[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune"
]
Expand All @@ -68,6 +68,35 @@
"detr_finetune = DetrR50Finetune(num_classes = 2)"
]
},
{
"cell_type": "markdown",
"id": "0ce4fbee",
"metadata": {},
"source": [
"By default, all *finetune* models load the pretrained weights of its base models trained on [COCO detection 2017 dataset]. For [Detr R50 Finetune], this weights corresponds to [Detr R50] weights. \n",
"\n",
"If required, the weights of the base model can be changed by specifying the file path through the `base_weights` parameter:\n",
"\n",
"[Detr R50]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50\n",
"[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune\n",
"[COCO detection 2017 dataset]: https://cocodataset.org/#detection-2017"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0a75dd7",
"metadata": {},
"outputs": [],
"source": [
"from alonet.detr import DetrR50Finetune\n",
"\n",
"detr_finetune = DetrR50Finetune(\n",
" num_classes = 2,\n",
" base_weights = \"/path/base/weights.pth\" or \"/path/base/weights.ckpt\" # By default \"detr-r50\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "07f5927f",
Expand Down Expand Up @@ -124,10 +153,10 @@
"\n",
"# Define COCO dataset as pl.LightningDataModule for only animals\n",
"pets = ['cat', 'dog']\n",
"coco_loader = CocoDetection2Detr(classes = pets)\n",
"coco_loader = CocoDetection2Detr(classes = pets) # Load detr-r50 pretrained weights\n",
"\n",
"# Define architecture as pl.LightningModule, using PRETRAINED WEIGHTS\n",
"lit_detr = LitDetr(model = DetrR50Finetune(len(pets), weights = 'detr-r50'))\n",
"# Define architecture as pl.LightningModule\n",
"lit_detr = LitDetr(model = DetrR50Finetune(len(pets))\n",
"\n",
"# Start train loop\n",
"args.max_epochs = 5 # Due to finetune, we just need 5 epochs to train this model\n",
Expand Down Expand Up @@ -225,6 +254,27 @@
"lit_detr.model.to(device)"
]
},
{
"cell_type": "markdown",
"id": "baa2066d",
"metadata": {},
"source": [
"Also, we can load the weights from the absolute path:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83c594aa",
"metadata": {},
"outputs": [],
"source": [
"detr_finetune = DetrR50Finetune(\n",
" num_classes = len(pets), \n",
" weights = \"/path/trained/weights.ckpt\" or \"/path/trained/weights.pth\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b3fb76ac",
Expand Down

0 comments on commit f38c6b3

Please sign in to comment.