diff --git a/docsource/source/alonet/deformable_models.rst b/docsource/source/alonet/deformable_models.rst
index 9990d8fb..2b796a93 100644
--- a/docsource/source/alonet/deformable_models.rst
+++ b/docsource/source/alonet/deformable_models.rst
@@ -18,20 +18,28 @@ To instantiate a Deformable DETR R50 (resnet50 backbone) with iterative box refi
from alonet.deformable_detr import DeformableDetrR50Refinement
model = DeformableDetrR50Refinement(num_classes=NUM_CLASS)
-If you want to finetune from the model pretrained on COCO dataset:
+To load the pretrained weights on COCO dataset:
+
+ .. code-block:: python
+
+ model = DeformableDetrR50(num_classes=NUM_CLASS, weights="deformable-detr-r50")
+ # with iterative box refinement
+ model = DeformableDetrR50Refinement(num_classes=NUM_CLASS, weights="deformable-detr-r50-refinement")
+
+If you want to finetune from the model pretrained on COCO dataset (by default):
.. code-block:: python
from alonet.deformable_detr import DeformableDetrR50Finetune
# NUM_CLASS is the number of classes in your finetune
- model = DeformableDetrR50Finetune(num_classes=NUM_CLASS, weights="deformable-detr-r50")
+ model = DeformableDetrR50Finetune(num_classes=NUM_CLASS)
.. code-block:: python
# with iterative box refinement
from alonet.deformable_detr import DeformableDetrR50RefinementFinetune
# NUM_CLASS is the number of classes in your finetune
- model = DeformableDetrR50RefinementFinetune(num_classes=NUM_CLASS, weights="deformable-detr-r50-refinement")
+ model = DeformableDetrR50RefinementFinetune(num_classes=NUM_CLASS)
To run inference:
diff --git a/docsource/source/alonet/detr_models.rst b/docsource/source/alonet/detr_models.rst
index a6b508aa..4498965e 100644
--- a/docsource/source/alonet/detr_models.rst
+++ b/docsource/source/alonet/detr_models.rst
@@ -11,13 +11,25 @@ To instantiate a DETR R50 (resnet50 backbone):
from alonet.detr import DetrR50
model = DetrR50()
-If you want to finetune from the model pretrained on COCO dataset:
+To load pretrained weights on COCO dataset:
+
+ .. code-block:: python
+
+ model = DetrR50(num_classes=NUM_CLASS, weights='detr-r50')
+
+Or from trained-models:
+
+ .. code-block:: python
+
+ model = DetrR50(num_classes=NUM_CLASS, weights='path/to/weights.pth' or 'path/to/weights.ckpt')
+
+If you want to finetune from the model pretrained on COCO dataset (by default):
.. code-block:: python
from alonet.detr import DetrR50Finetune
- # NUM_CLASS is the number of classes in your finetune
- model = DetrR50Finetune(num_classes=NUM_CLASS, weights="detr-r50")
+ # NUM_CLASS is the desired number of classes in the new model
+ model = DetrR50Finetune(num_classes=NUM_CLASS)
To run inference:
diff --git a/docsource/source/tutorials/finetuning_deformable_detr.ipynb b/docsource/source/tutorials/finetuning_deformable_detr.ipynb
index d1846abc..9d697048 100644
--- a/docsource/source/tutorials/finetuning_deformable_detr.ipynb
+++ b/docsource/source/tutorials/finetuning_deformable_detr.ipynb
@@ -2,6 +2,8 @@
"cells": [
{
"cell_type": "markdown",
+ "id": "8390a00b",
+ "metadata": {},
"source": [
"# Finetuning Deformanble DETR\n",
"\n",
@@ -27,11 +29,12 @@
"[Deformable DetrR50 architecture]: https://arxiv.org/abs/2010.04159\n",
"[Mask Wearing Dataset]: https://public.roboflow.com/object-detection/mask-wearing\n",
"[Deformable DETR Finetune]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "9fab10df",
+ "metadata": {},
"source": [
"## 1. Train Deformable DETR50 Finetune\n",
"\n",
@@ -68,22 +71,25 @@
"[Deformable DETR R50 Finetune]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune\n",
"[Deformable DETR R50 Finetune with refinement]: ../alonet/deformable_models.rst#deformable-detr-r50-finetune-with-refinement\n",
"[Models]: ../alonet/deformable_models.rst"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "df3a4d07",
+ "metadata": {},
"source": [
"For training purposes, it is usual in [Aloception] to define a model on [Pytorch lightning module]. With a finetune model, the architecture definition changes, but the training process remains static:\n",
"\n",
"[Pytorch lightning module]: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html\n",
"[Aloception]: ../index.rst"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "e453f54e",
+ "metadata": {},
+ "outputs": [],
"source": [
"from argparse import Namespace, ArgumentParser\n",
"import alonet\n",
@@ -109,7 +115,7 @@
"# Architecture definition\n",
"deformabe_finetune = DeformableDetrR50RefinementFinetune(\n",
" num_classes = num_classes, \n",
- " weights = \"deformable-detr-r50-refinement\",\n",
+ " base_weights = \"deformable-detr-r50-refinement\", # Load by default\n",
" activation_fn = \"softmax\"\n",
")\n",
"lit_deformable = LitDeformableDetr(model = deformabe_finetune)\n",
@@ -123,12 +129,12 @@
" project = \"deformable_detr\",\n",
" expe_name = \"people_mask\"\n",
")"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "54ae4cc3",
+ "metadata": {},
"source": [
"
\n",
"\n",
@@ -162,21 +168,24 @@
"[How to setup your data]: ./data_setup.rst\n",
"[Train a Deformanble model]: training_deformable_detr.ipynb\n",
"[Train a DetrR50 finetune model]: finetuning_detr.ipynb"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "981b6f35",
+ "metadata": {},
"source": [
"## 2. Make inferences\n",
"\n",
"In order to make some inferences on the dataset using the trained model, we need to load the weights. For that, we can use one function in [Alonet](../alonet/alonet.rst) for this purpose. Also, we need to keep in mind **the project and run id that we used in training process**:"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "d9cd843f",
+ "metadata": {},
+ "outputs": [],
"source": [
"import torch\n",
"from alonet.common import load_training \n",
@@ -196,20 +205,43 @@
" args = args, \n",
" model = detr_finetune)\n",
"lit_deformable.model.to(device)"
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "044fc41a",
+ "metadata": {},
+ "source": [
+ "Or in another way"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "66593b0b",
+ "metadata": {},
"outputs": [],
- "metadata": {}
+ "source": [
+ "detr_finetune = DeformableDetrR50RefinementFinetune(\n",
+ " num_classes = num_classes, \n",
+ " weights = \"/path/trained/weights.ckpt\" or \"/path/trained/weights.pth\"\n",
+ ")"
+ ]
},
{
"cell_type": "markdown",
+ "id": "1a58dfd4",
+ "metadata": {},
"source": [
"This enables to use the valid dataset and show some results:"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "d4125230",
+ "metadata": {},
+ "outputs": [],
"source": [
"frames = next(iter(coco_loader.val_dataloader()))\n",
"frames = frames[0].batch_list(frames).to(device)\n",
@@ -220,12 +252,12 @@
" gt_boxes.get_view(frames[0], title=\"Ground truth boxes\"),\n",
" pred_boxes.get_view(frames[0], title=\"Predicted boxes\"),\n",
"]).render()"
- ],
- "outputs": [],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "674eb55f",
+ "metadata": {},
"source": [
"
\n",
"\n",
@@ -236,21 +268,24 @@
"
\n",
"\n",
"[DETR/Deformable DETR models to tensorRT]: tensort_inference.rst"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "markdown",
+ "id": "13e998c8",
+ "metadata": {},
"source": [
"## 3. Optional: Make prediction in camera\n",
"\n",
"If there is access to a local camera, the following code would allow you to take snapshots with the camera and make predictions at the same time:"
- ],
- "metadata": {}
+ ]
},
{
"cell_type": "code",
"execution_count": null,
+ "id": "83eb5696",
+ "metadata": {},
+ "outputs": [],
"source": [
"%matplotlib inline\n",
"import cv2\n",
@@ -283,9 +318,7 @@
" cv2.destroyAllWindows()\n",
"else:\n",
" print(\"[ERROR] Impossible to open camera\")"
- ],
- "outputs": [],
- "metadata": {}
+ ]
}
],
"metadata": {
@@ -309,4 +342,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/docsource/source/tutorials/finetuning_detr.ipynb b/docsource/source/tutorials/finetuning_detr.ipynb
index d3290edf..6005f13f 100644
--- a/docsource/source/tutorials/finetuning_detr.ipynb
+++ b/docsource/source/tutorials/finetuning_detr.ipynb
@@ -51,7 +51,7 @@
"id": "fcd240d4",
"metadata": {},
"source": [
- "Its statement is the same as [Detr R50 Finetune], with difference that now `num_classes` **attribute is mandatory**:\n",
+ "Its statement is the same as [Detr R50 Finetune], with the difference that now `num_classes` **attribute is mandatory**:\n",
"\n",
"[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune"
]
@@ -68,6 +68,35 @@
"detr_finetune = DetrR50Finetune(num_classes = 2)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "0ce4fbee",
+ "metadata": {},
+ "source": [
+ "By default, all *finetune* models load the pretrained weights of its base models trained on [COCO detection 2017 dataset]. For [Detr R50 Finetune], this weights corresponds to [Detr R50] weights. \n",
+ "\n",
+ "If required, the weights of the base model can be changed by specifying the file path through the `base_weights` parameter:\n",
+ "\n",
+ "[Detr R50]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50\n",
+ "[Detr R50 Finetune]: ../alonet/detr_models.rst#module-alonet.detr.detr_r50_finetune\n",
+ "[COCO detection 2017 dataset]: https://cocodataset.org/#detection-2017"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d0a75dd7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from alonet.detr import DetrR50Finetune\n",
+ "\n",
+ "detr_finetune = DetrR50Finetune(\n",
+ " num_classes = 2,\n",
+ " base_weights = \"/path/base/weights.pth\" or \"/path/base/weights.ckpt\" # By default \"detr-r50\"\n",
+ ")"
+ ]
+ },
{
"cell_type": "markdown",
"id": "07f5927f",
@@ -124,10 +153,10 @@
"\n",
"# Define COCO dataset as pl.LightningDataModule for only animals\n",
"pets = ['cat', 'dog']\n",
- "coco_loader = CocoDetection2Detr(classes = pets)\n",
+ "coco_loader = CocoDetection2Detr(classes = pets) # Load detr-r50 pretrained weights\n",
"\n",
- "# Define architecture as pl.LightningModule, using PRETRAINED WEIGHTS\n",
- "lit_detr = LitDetr(model = DetrR50Finetune(len(pets), weights = 'detr-r50'))\n",
+ "# Define architecture as pl.LightningModule\n",
+ "lit_detr = LitDetr(model = DetrR50Finetune(len(pets))\n",
"\n",
"# Start train loop\n",
"args.max_epochs = 5 # Due to finetune, we just need 5 epochs to train this model\n",
@@ -225,6 +254,27 @@
"lit_detr.model.to(device)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "baa2066d",
+ "metadata": {},
+ "source": [
+ "Also, we can load the weights from the absolute path:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "83c594aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "detr_finetune = DetrR50Finetune(\n",
+ " num_classes = len(pets), \n",
+ " weights = \"/path/trained/weights.ckpt\" or \"/path/trained/weights.pth\"\n",
+ ")"
+ ]
+ },
{
"cell_type": "markdown",
"id": "b3fb76ac",