Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to allow for a wider model in TIMM #3976

Merged
merged 53 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
d1bd1d5
update for releases 2.2.0rc0
yunchu Aug 20, 2024
c16f985
Fix Classification explain forward issue (#3867)
harimkang Aug 21, 2024
cba5120
Fix e2e code error (#3871)
chuneuny-emily Aug 21, 2024
b807c9d
Add documentation about configurable input size (#3870)
eunwoosh Aug 21, 2024
2835aba
Fix zero-shot e2e (#3876)
sungchul2 Aug 23, 2024
ccf2d50
Fix DeiT for multi-label classification (#3881)
harimkang Aug 23, 2024
e577b6a
Fix Semi-SL for ViT accuracy drop (#3883)
harimkang Aug 23, 2024
d1dd2b0
Update docs for 2.2 (#3884)
harimkang Aug 23, 2024
c17a923
Fix mean and scale for segmentation task (#3885)
kprokofi Aug 23, 2024
d72feeb
Update MAPI in 2.2 (#3889)
sovrasov Aug 26, 2024
00ed3a0
Improve Semi-SL for LiteHRNet (small-medium case) (#3891)
kprokofi Aug 26, 2024
2c6b4de
Improve h-cls for eff models (#3893)
sooahleex Aug 26, 2024
0dc7a29
Fix maskrcnn swin nncf acc drop (#3900)
eugene123tw Aug 27, 2024
0d6799c
Add keypoint detection recipe for single object cases (#3903)
wonjuleee Aug 28, 2024
8115b52
Improve acc drop of efficientnetv2 for h-label cls (#3907)
sooahleex Aug 29, 2024
4c8555e
Fix pretrained weight cached dir for timm (#3909)
harimkang Aug 29, 2024
52221e3
Fix keypoint detection single obj recipe (#3915)
wonjuleee Aug 30, 2024
9265c59
Fix cached dir for timm & hugging-face (#3914)
harimkang Aug 30, 2024
5170736
Fix wrong template id mapping for anomaly (#3916)
yunchu Aug 30, 2024
f611cc1
Update script to allow setting otx version using env. variable (#3913)
yunchu Aug 30, 2024
425a479
Fix Datamodule creation for OV in AutoConfigurator (#3920)
harimkang Sep 2, 2024
7f1c7da
Update tpp file for 2.2.0 (#3921)
yunchu Sep 2, 2024
51d1adf
Fix names for ignored scope [HOT-FIX, 2.2.0] (#3924)
kprokofi Sep 3, 2024
2bcf1b2
Fix classification rt_info (#3922)
sovrasov Sep 3, 2024
112b2b2
Update label info (#3925)
ashwinvaidya17 Sep 4, 2024
929132d
Fix binary classification metric task (#3928)
harimkang Sep 5, 2024
706f99b
Improve MaskRCNN SwinT NNCF (#3929)
eugene123tw Sep 5, 2024
53a7d9a
Fix get_item for Chained Tasks in Classification (#3931)
harimkang Sep 5, 2024
c3749e3
Correct Keyerror for h-label cls in label_groups for dm_label_categor…
sooahleex Sep 5, 2024
98a9cac
Remove datumaro attribute id from tiling, add subset names (#3933)
eugene123tw Sep 6, 2024
d8e6454
Fix soft predictions for Semantic Segmentation (#3934)
kprokofi Sep 6, 2024
c2705df
Update STFPM config (#3935)
ashwinvaidya17 Sep 6, 2024
c2ccfc9
Add missing pretrained weights when creating a docker image (#3938)
harimkang Sep 6, 2024
8b747f9
Change default option 'full' to 'base' in otx install (#3937)
harimkang Sep 9, 2024
d43226e
Fix auto adapt batch size in Converter (#3939)
harimkang Sep 9, 2024
1d319cd
Fix hpo converter (#3940)
eunwoosh Sep 9, 2024
aaa2765
Fix tiling XAI out of range (#3943)
eugene123tw Sep 9, 2024
ac87b49
enable model export (#3952)
ashwinvaidya17 Sep 12, 2024
8f96f27
Move templates from OTX1.X to OTX2.X (#3951)
kprokofi Sep 12, 2024
0f87c86
Add missing tile recipes and various tile recipe changes (#3942)
eugene123tw Sep 12, 2024
c7efcbc
Support ImageFromBytes (#3948)
ashwinvaidya17 Sep 12, 2024
ecef545
Change categories mapping logic (#3946)
kprokofi Sep 13, 2024
b1ec8e7
Update for 2.2.0rc1 (#3956)
yunchu Sep 13, 2024
aa31dca
Include Geti arrow dataset subset names (#3962)
eugene123tw Sep 20, 2024
93f1a55
Include full image with anno in case there's no tile in tile dataset …
eugene123tw Sep 20, 2024
45f9a24
Add type checker in converter for callable functions (optimizer, sche…
harimkang Sep 20, 2024
51fcb73
Update for 2.2.0rc2 (#3969)
yunchu Sep 20, 2024
e25448b
Refactor TIMM
harimkang Sep 24, 2024
2416aea
Remove experimental recipes
harimkang Sep 24, 2024
2e18fba
Revert timm version
harimkang Sep 24, 2024
7f573e9
Fix conflict
harimkang Sep 26, 2024
de02ab5
Fix conflict2
harimkang Sep 26, 2024
368a91c
Fix unit-test
harimkang Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 12 additions & 27 deletions src/otx/algo/classification/backbones/timm.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,47 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""EfficientNetV2 model.
"""Timm Backbone Class for OTX classification.

Original papers:
- 'EfficientNetV2: Smaller Models and Faster Training,' https://arxiv.org/abs/2104.00298,
- 'Adversarial Examples Improve Image Recognition,' https://arxiv.org/abs/1911.09665.
"""
from __future__ import annotations

from typing import Literal

import timm
import torch
from torch import nn

TimmModelType = Literal[
"mobilenetv3_large_100_miil_in21k",
"mobilenetv3_large_100_miil",
"tresnet_m",
"tf_efficientnetv2_s.in21k",
"tf_efficientnetv2_s.in21ft1k",
"tf_efficientnetv2_m.in21k",
"tf_efficientnetv2_m.in21ft1k",
"tf_efficientnetv2_b0",
]


class TimmBackbone(nn.Module):
"""Timm backbone model."""
"""Timm backbone model.

Args:
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
pretrained (bool, optional): Whether to load pretrained weights. Defaults to False.
"""

def __init__(
self,
backbone: TimmModelType,
model_name: str,
pretrained: bool = False,
pooling_type: str = "avg",
**kwargs,
):
super().__init__(**kwargs)
self.backbone = backbone
self.model_name = model_name
self.pretrained: bool | dict = pretrained
self.is_mobilenet = backbone.startswith("mobilenet")
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved

self.model = timm.create_model(
self.backbone,
self.model_name,
pretrained=pretrained,
num_classes=1000,
)

self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx.
self.num_head_features = self.model.num_features
self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features
self.pooling_type = pooling_type
self.num_features = self.model.num_features

def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:
"""Forward."""
Expand All @@ -60,11 +50,6 @@ def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:

def extract_features(self, x: torch.Tensor) -> torch.Tensor:
"""Extract features."""
if self.is_mobilenet:
x = self.model.conv_stem(x)
x = self.model.bn1(x)
x = self.model.act1(x)
return self.model.blocks(x)
return self.model.forward_features(x)

def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]:
Expand Down
104 changes: 90 additions & 14 deletions src/otx/algo/classification/timm_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""EfficientNetV2 model implementation."""
"""TIMM wrapper model class for OTX."""

from __future__ import annotations

Expand All @@ -12,7 +12,7 @@
import torch
from torch import nn

from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.backbones.timm import TimmBackbone
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
Expand Down Expand Up @@ -50,12 +50,38 @@


class TimmModelForMulticlassCls(OTXMulticlassClsModel):
"""TimmModel for multi-class classification task."""
"""TimmModel for multi-class classification task.

Args:
label_info (LabelInfoTypes): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to MultiClassClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.
train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The training type.

Example:
1. API
>>> model = TimmModelForMulticlassCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<Number-of-classes>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

def __init__(
self,
label_info: LabelInfoTypes,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
Expand All @@ -64,7 +90,7 @@ def __init__(
torch_compile: bool = False,
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand Down Expand Up @@ -92,7 +118,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
neck = GlobalAveragePooling(dim=2)
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
return SemiSLClassifier(
Expand Down Expand Up @@ -142,20 +168,45 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t


class TimmModelForMultilabelCls(OTXMultilabelClsModel):
"""TimmModel for multi-label classification task."""
"""TimmModel for multi-label classification task.

Args:
label_info (LabelInfoTypes): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to MultiLabelClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.

Example:
1. API
>>> model = TimmModelForMultilabelCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<Number-of-classes>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

def __init__(
self,
label_info: LabelInfoTypes,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand All @@ -182,7 +233,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
return ImageClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
Expand Down Expand Up @@ -222,22 +273,47 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t


class TimmModelForHLabelCls(OTXHlabelClsModel):
"""EfficientNetV2 Model for hierarchical label classification task."""
"""Timm Model for hierarchical label classification task.

Args:
label_info (HLabelInfo): The label information for the classification task.
model_name (str): The name of the model.
You can find available models at timm.list_models() or timm.list_pretrained().
input_size (tuple[int, int], optional): Model input size in the order of height and width.
Defaults to (224, 224).
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
metric (MetricCallable, optional): The metric callable for evaluating the model.
Defaults to HLabelClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using torch.compile. Defaults to False.

Example:
1. API
>>> model = TimmModelForHLabelCls(
... model_name="tf_efficientnetv2_s.in21k",
... label_info=<h-label-info>,
... )
2. CLI
>>> otx train \
... --model otx.algo.classification.timm_model.TimmModelForHLabelCls \
... --model.model_name tf_efficientnetv2_s.in21k
"""

label_info: HLabelInfo

def __init__(
self,
label_info: HLabelInfo,
backbone: TimmModelType,
model_name: str,
input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.backbone = backbone
self.model_name = model_name
self.pretrained = pretrained

super().__init__(
Expand Down Expand Up @@ -267,7 +343,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, head_config: dict) -> nn.Module:
backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained)
backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained)
copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model:
class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls
init_args:
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k
train_type: SEMI_SUPERVISED

optimizer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ model:
class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls
init_args:
label_info: 1000
backbone: tf_efficientnetv2_s.in21k
model_name: tf_efficientnetv2_s.in21k

optimizer:
class_path: torch.optim.SGD
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/algo/classification/backbones/test_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

class TestOTXEfficientNetV2:
def test_forward(self):
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k")
assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8])

def test_get_config_optim(self):
model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k")
model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k")
assert model.get_config_optim([0.01])[0]["lr"] == 0.01
assert model.get_config_optim(0.01)[0]["lr"] == 0.01

Expand All @@ -24,5 +24,5 @@ def test_check_pretrained_weight_download(self):
if target.exists():
shutil.rmtree(target)
assert not target.exists()
TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True)
TimmBackbone(model_name="tf_efficientnetv2_s.in21k", pretrained=True)
assert target.exists()
6 changes: 3 additions & 3 deletions tests/unit/algo/classification/test_timm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
def fxt_multi_class_cls_model():
return TimmModelForMulticlassCls(
label_info=10,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_
def fxt_multi_label_cls_model():
return TimmModelForMultilabelCls(
label_info=10,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down Expand Up @@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_
def fxt_h_label_cls_model(fxt_hlabel_cifar):
return TimmModelForHLabelCls(
label_info=fxt_hlabel_cifar,
backbone="tf_efficientnetv2_s.in21k",
model_name="tf_efficientnetv2_s.in21k",
)


Expand Down
Loading