Here we show how to develop a new classifier, using the following example.
Create a new file mmfewshot/classification/models/classifiers/my_classifier.py:
from mmcls.models.builder import CLASSIFIERS
from .base import BaseFewShotClassifier
@CLASSIFIERS.register_module()
class MyClassifier(BaseFewShotClassifier):
    """Template for a custom few-shot classifier.

    ``forward`` routes each call to a mode-specific method: training data
    (``forward_train``), meta-testing support data (``forward_support``),
    meta-testing query data (``forward_query``) or feature extraction
    (``extract_feat``, presumably provided by the base class). The
    ``before_*`` hooks are where state is prepared for meta testing.

    Args:
        arg1: Placeholder constructor argument.
        arg2: Placeholder constructor argument.
        **kwargs: Forwarded unchanged to the base-class initializer.
    """

    def __init__(self, arg1, arg2, **kwargs):
        # Run the base-class initializer before anything is assigned to
        # ``self``. NOTE(review): the remaining config entries (presumably
        # backbone, head, ...) are forwarded as-is; confirm against
        # BaseFewShotClassifier.__init__.
        super().__init__(**kwargs)
        # build any extra components from arg1 / arg2 here

    # customize input for different mode;
    # the input should keep consistent with the dataset
    def forward(self, img, mode='train', **kwargs):
        """Dispatch to the forward function selected by ``mode``.

        Args:
            img: Input images.
            mode (str): One of 'train', 'query', 'support' or
                'extract_feat'. Defaults to 'train'.
            **kwargs: Other inputs from the dataset, passed through to the
                selected function (not used for 'extract_feat').

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'train':
            return self.forward_train(img=img, **kwargs)
        elif mode == 'query':
            return self.forward_query(img=img, **kwargs)
        elif mode == 'support':
            return self.forward_support(img=img, **kwargs)
        elif mode == 'extract_feat':
            assert img is not None
            return self.extract_feat(img=img)
        else:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of 'train', "
                "'query', 'support' or 'extract_feat'.")

    # customize forward function for training data
    def forward_train(self, img, gt_label, **kwargs):
        """Forward function for training data."""
        pass

    # customize forward function for meta testing support data
    def forward_support(self, img, gt_label, **kwargs):
        """Forward function for meta-testing support data."""
        pass

    # customize forward function for meta testing query data
    def forward_query(self, img, **kwargs):
        """Forward function for meta-testing query data.

        Accepts ``**kwargs`` because ``forward`` passes every extra input
        through to this method, as it does for the other modes.
        """
        pass

    # prepare meta testing
    def before_meta_test(self, meta_test_cfg, **kwargs):
        """Prepare the model for meta testing using ``meta_test_cfg``."""
        pass

    # prepare forward meta testing support images
    def before_forward_support(self, **kwargs):
        """Hook run before forwarding meta-testing support images."""
        pass

    # prepare forward meta testing query images
    def before_forward_query(self, **kwargs):
        """Hook run before forwarding meta-testing query images."""
        pass
You can either add the following line to mmfewshot/classification/models/classifiers/__init__.py
from .my_classifier import MyClassifier
or alternatively add
# Import the module that registers MyClassifier. The path must match the
# file location: mmfewshot/classification/models/classifiers/my_classifier.py
custom_imports = dict(
    imports=['mmfewshot.classification.models.classifiers.my_classifier'],
    allow_failed_imports=False)
to the config file to avoid modifying the original code. Then use the new classifier in the model config:
# Select the new classifier by its registered name.
model = dict(
    type='MyClassifier',
    # ... other model settings (e.g. backbone, head) ...
)
Here we show how to develop a new backbone, using the following example.
Create a new file mmfewshot/classification/models/backbones/mynet.py:
import torch.nn as nn
from mmcls.models.builder import BACKBONES
@BACKBONES.register_module()
class MyNet(nn.Module):
    """Template for a custom backbone.

    Args:
        arg1: Placeholder constructor argument, supplied from the
            ``backbone`` dict of the model config.
        arg2: Placeholder constructor argument.
    """

    def __init__(self, arg1, arg2):
        # nn.Module.__init__ must run before any layer or parameter is
        # assigned to ``self``; without it, attribute assignment of
        # sub-modules and calling the module both fail.
        super().__init__()
        # build layers from arg1 / arg2 here

    def forward(self, x):  # should return a tensor
        """Compute backbone features for ``x``; should return a tensor."""
        pass
You can either add the following line to mmfewshot/classification/models/backbones/__init__.py
from .mynet import MyNet
or alternatively add
# Register MyNet by importing its module when the config is loaded,
# instead of editing backbones/__init__.py; fail loudly if it is missing.
custom_imports = {
    'imports': ['mmfewshot.classification.models.backbones.mynet'],
    'allow_failed_imports': False,
}
to the config file to avoid modifying the original code. Then use the new backbone in the model config:
# Plug the new backbone into the model; replace xxx with real values.
model = dict(
    # ... other model settings ...
    backbone=dict(
        type='MyNet',
        arg1=xxx,
        arg2=xxx),
    # ... other model settings ...
)
Here we show how to develop a new head, using the following example.
Create a new file mmfewshot/classification/models/heads/myhead.py:
from mmcls.models.builder import HEADS
from .base_head import BaseFewShotHead
@HEADS.register_module()
class MyHead(BaseFewShotHead):
    """Template for a custom few-shot head.

    Args:
        arg1: Placeholder constructor argument.
        arg2: Placeholder constructor argument.
        **kwargs: Forwarded unchanged to the base-class initializer.
    """

    def __init__(self, arg1, arg2, **kwargs) -> None:
        # Initialize the base head before layers are assigned to ``self``.
        # NOTE(review): extra config entries are forwarded as-is; presumably
        # options such as the loss config — confirm against
        # BaseFewShotHead.__init__.
        super().__init__(**kwargs)
        # build layers from arg1 / arg2 here

    def forward_train(self, x, gt_label, **kwargs):
        """Forward function for training data.

        Args:
            x: Input features (presumably produced by the backbone —
                confirm).
            gt_label: Ground-truth labels.
        """
        pass

    def forward_support(self, x, gt_label, **kwargs):
        """Forward function for meta-testing support data."""
        pass

    def forward_query(self, x, **kwargs):
        """Forward function for meta-testing query data."""
        pass

    def before_forward_support(self) -> None:
        """Hook run before forwarding meta-testing support data."""
        pass

    def before_forward_query(self) -> None:
        """Hook run before forwarding meta-testing query data."""
        pass
You can either add the following line to mmfewshot/classification/models/heads/__init__.py
from .myhead import MyHead
or alternatively add
# Import the module that registers MyHead. The path must match the file
# location: mmfewshot/classification/models/heads/myhead.py
custom_imports = dict(
    imports=['mmfewshot.classification.models.heads.myhead'],
    allow_failed_imports=False)
to the config file to avoid modifying the original code. Then use the new head in the model config:
# Plug the new head into the model; replace xxx with real values.
model = dict(
    # ... other model settings ...
    head=dict(
        type='MyHead',
        arg1=xxx,
        arg2=xxx),
    # ... other model settings ...
)
To add a new loss function, users need to implement it in mmfewshot/classification/models/losses/my_loss.py.
The decorator weighted_loss enables the loss to be weighted for each element:
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    """Return the unreduced element-wise absolute error |pred - target|.

    ``pred`` and ``target`` must have the same size and be non-empty.
    The ``weighted_loss`` decorator wraps this function; callers pass
    ``weight``, ``reduction`` and ``avg_factor`` to the wrapped version.
    """
    assert pred.size() == target.size() and target.numel() > 0
    return (pred - target).abs()
@LOSSES.register_module()
class MyLoss(nn.Module):
    """Module wrapper that scales ``my_loss`` by a constant weight.

    Args:
        reduction (str): Reduction used when ``forward`` gets no
            ``reduction_override``. Defaults to 'mean'.
        loss_weight (float): Multiplier applied to the value returned by
            ``my_loss``. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted loss.

        Args:
            pred: Predictions; must have the same size as ``target``.
            target: Targets.
            weight: Optional element-wise weights, passed to ``my_loss``.
            avg_factor: Optional averaging factor, passed to ``my_loss``.
            reduction_override: None, 'none', 'mean' or 'sum'. When truthy
                it replaces ``self.reduction`` for this call only.

        Returns:
            ``self.loss_weight`` times the result of ``my_loss``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A falsy override (None) falls back to the configured reduction.
        reduction = reduction_override or self.reduction
        raw = my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return self.loss_weight * raw
Then add it to mmfewshot/classification/models/losses/__init__.py:
from .my_loss import MyLoss, my_loss
Alternatively, you can add
# Import (and thereby register) MyLoss when the config is loaded.
custom_imports = {
    'imports': ['mmfewshot.classification.models.losses.my_loss'],
}
to the config file to achieve the same goal.
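To use it, modify the corresponding loss_xxx field in the config. Since MyLoss is a regression loss, the example below sets it as the loss_bbox field of the head; use the name of the loss argument your head actually accepts: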
# Entry inside the head config. NOTE(review): the key must be the loss
# argument name of your head (it may be e.g. ``loss`` rather than
# ``loss_bbox``) — check the head's __init__ signature.
loss_bbox=dict(type='MyLoss', loss_weight=1.0)