Tutorial 3: Customize Models

Add a new classifier

Here we show how to develop a new classifier with the following example.

1. Define a new classifier

Create a new file mmfewshot/classification/models/classifiers/my_classifier.py.

from mmcls.models.builder import CLASSIFIERS
from .base import BaseFewShotClassifier

@CLASSIFIERS.register_module()
class MyClassifier(BaseFewShotClassifier):

    def __init__(self, arg1, arg2):
        pass

    # customize the input for different modes;
    # the input should be consistent with the dataset
    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img=img, **kwargs)
        elif mode == 'query':
            return self.forward_query(img=img, **kwargs)
        elif mode == 'support':
            return self.forward_support(img=img, **kwargs)
        elif mode == 'extract_feat':
            assert img is not None
            return self.extract_feat(img=img)
        else:
            raise ValueError(f'Invalid mode "{mode}".')

    # customize forward function for training data
    def forward_train(self, img, gt_label, **kwargs):
        pass

    # customize forward function for meta testing support data
    def forward_support(self, img, gt_label, **kwargs):
        pass

    # customize forward function for meta testing query data
    def forward_query(self, img, **kwargs):
        pass

    # prepare meta testing
    def before_meta_test(self, meta_test_cfg, **kwargs):
        pass

    # prepare forward of meta testing support images
    def before_forward_support(self, **kwargs):
        pass

    # prepare forward of meta testing query images
    def before_forward_query(self, **kwargs):
        pass
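
The skeleton above only declares the hooks. As a rough sketch (not part of the original skeleton), forward_train and forward_query in a classifier that composes a backbone and a head might look like the following; the use of self.head and extract_feat here is an assumption modeled on the classifiers shipped with mmfewshot:

    # minimal sketch, assuming the classifier builds `self.backbone` and `self.head`
    def forward_train(self, img, gt_label, **kwargs):
        feats = self.extract_feat(img)                      # backbone features
        losses = self.head.forward_train(feats, gt_label)   # head computes the losses
        return losses                                       # dict of losses for the runner

    def forward_query(self, img, **kwargs):
        feats = self.extract_feat(img)
        return self.head.forward_query(feats)               # predictions for meta testing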

2. Import the module

You can either add the following line to mmfewshot/classification/models/classifiers/__init__.py

from .my_classifier import MyClassifier

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.classification.models.classifiers.my_classifier'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.
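
To quickly check that the new classifier is actually registered, you can query the CLASSIFIERS registry. This is a sketch, assuming the import chain above is in place and that importing mmfewshot.classification.models triggers the __init__.py imports:

from mmcls.models.builder import CLASSIFIERS

import mmfewshot.classification.models  # noqa: F401, triggers the __init__.py imports

assert CLASSIFIERS.get('MyClassifier') is not None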

3. Use the classifier in your config file

model = dict(
    type='MyClassifier',
    ...
)
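
If MyClassifier composes a backbone and a head like the built-in few-shot classifiers, the config would typically also carry those sub-modules, for example (placeholder arguments, same style as elsewhere in this tutorial):

model = dict(
    type='MyClassifier',
    backbone=dict(type='MyNet', arg1=xxx, arg2=xxx),
    head=dict(type='MyHead', arg1=xxx, arg2=xxx))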

Add a new backbone

Here we show how to develop a new backbone with the following example.

1. Define a new backbone

Create a new file mmfewshot/classification/models/backbones/mynet.py.

import torch.nn as nn

from mmcls.models.builder import BACKBONES

@BACKBONES.register_module()
class MyNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tensor
        pass
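
For reference, a minimal concrete backbone could look like the following sketch; TinyNet, its layers and channel sizes are purely illustrative and not part of mmfewshot:

import torch.nn as nn

from mmcls.models.builder import BACKBONES

@BACKBONES.register_module()
class TinyNet(nn.Module):
    """Toy convolutional backbone; returns a (N, out_channels) tensor."""

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten())

    def forward(self, x):
        return self.layers(x)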

2. Import the module

You can either add the following line to mmfewshot/classification/models/backbones/__init__.py

from .mynet import MyNet

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.classification.models.backbones.mynet'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Use the backbone in your config file

model = dict(
    ...
    backbone=dict(
        type='MyNet',
        arg1=xxx,
        arg2=xxx),
    ...
)

Add new heads

Here we show how to develop a new head with the following example.

1. Define a new head

Create a new file mmfewshot/classification/models/heads/myhead.py.

from mmcls.models.builder import HEADS
from .base_head import BaseFewShotHead

@HEADS.register_module()
class MyHead(BaseFewShotHead):

    def __init__(self, arg1, arg2) -> None:
        pass

    def forward_train(self, x, gt_label, **kwargs):
        pass

    def forward_support(self, x, gt_label, **kwargs):
        pass

    def forward_query(self, x, **kwargs):
        pass

    def before_forward_support(self) -> None:
        pass

    def before_forward_query(self) -> None:
        pass
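
To make the role of the meta-testing hooks concrete, here is a rough prototype-style sketch (illustrative only; ProtoSketchHead is a hypothetical name, the base-class arguments are simply passed through, and the logic is simplified): before_forward_support resets a cache, forward_support accumulates support features, before_forward_query builds class prototypes, and forward_query scores queries against them.

import torch
import torch.nn.functional as F

from mmcls.models.builder import HEADS
from .base_head import BaseFewShotHead

@HEADS.register_module()
class ProtoSketchHead(BaseFewShotHead):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.support_feats, self.support_labels = [], []
        self.prototypes = None

    def forward_train(self, x, gt_label, **kwargs):
        ...  # episodic training losses would be computed here

    def before_forward_support(self):
        # reset the cache before each meta-testing task
        self.support_feats, self.support_labels = [], []
        self.prototypes = None

    def forward_support(self, x, gt_label, **kwargs):
        # accumulate support features and labels
        self.support_feats.append(x)
        self.support_labels.append(gt_label)

    def before_forward_query(self):
        # average the cached support features into one prototype per class
        feats = torch.cat(self.support_feats)
        labels = torch.cat(self.support_labels)
        self.prototypes = torch.stack(
            [feats[labels == c].mean(0) for c in labels.unique()])

    def forward_query(self, x, **kwargs):
        # score queries by cosine similarity to the prototypes
        scores = F.normalize(x, dim=-1) @ F.normalize(self.prototypes, dim=-1).t()
        return list(scores.softmax(dim=-1).detach().cpu().numpy())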

2. Import the module

You can either add the following line to mmfewshot/classification/models/heads/__init__.py

from .myhead import MyHead

or alternatively add

custom_imports = dict(
    imports=['mmfewshot.classification.models.heads.myhead'],
    allow_failed_imports=False)

to the config file to avoid modifying the original code.

3. Use the head in your config file

model = dict(
    ...
    head=dict(
        type='MyHead',
        arg1=xxx,
        arg2=xxx),
    ...
)

Add new loss

To add a new loss function, users need to implement it in mmfewshot/classification/models/losses/my_loss.py. The decorator weighted_loss enables the loss to be weighted for each element.

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

Then the users need to add it in the mmfewshot/classification/models/losses/__init__.py.

from .my_loss import MyLoss, my_loss

Alternatively, you can add

custom_imports = dict(
    imports=['mmfewshot.classification.models.losses.my_loss'])

to the config file and achieve the same goal.

To use it, modify the loss_xxx field. Since MyLoss is for regression, you need to modify the loss_bbox field in the head.

loss_bbox=dict(type='MyLoss', loss_weight=1.0)
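
As a quick sanity check (a sketch, assuming weighted_loss behaves like the helper of the same name in mmcls/mmdet), the loss can be exercised directly on random tensors:

import torch

criterion = MyLoss(reduction='mean', loss_weight=1.0)
pred, target = torch.rand(4, 10), torch.rand(4, 10)

print(criterion(pred, target))                                    # scalar mean L1-style loss
print(criterion(pred, target, reduction_override='none').shape)   # torch.Size([4, 10])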