base.py
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from ...core import top_k_accuracy
from ..builder import build_loss


class AvgConsensus(nn.Module):
    """Average consensus module.

    Args:
        dim (int): The dimension along which the consensus (average) is
            applied. Default: 1.
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Defines the computation performed at every call."""
        # Average over the consensus dimension, keeping it for broadcasting.
        return x.mean(dim=self.dim, keepdim=True)
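
# Illustration (a sketch, not part of the original file): given segment-level
# class scores of shape (batch, num_segments, num_classes), ``AvgConsensus``
# with ``dim=1`` averages over the segment dimension and keeps it for
# broadcasting:
#
#     >>> consensus = AvgConsensus(dim=1)
#     >>> consensus(torch.rand(2, 8, 400)).shape
#     torch.Size([2, 1, 400])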


class BaseHead(nn.Module, metaclass=ABCMeta):
    """Base class for heads.

    All heads should subclass it, and all subclasses should overwrite:

    - ``init_weights``, initializing weights in some modules.
    - ``forward``, supporting forward for both training and testing.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in the input feature.
        loss_cls (dict): Config for building the loss. Its optional
            ``loss_factor`` entry is a scalar the loss is multiplied by.
            Default: dict(type='CrossEntropyLoss', loss_factor=1.0).
        multi_class (bool): Determines whether it is a multi-class
            recognition task. Default: False.
        label_smooth_eps (float): Epsilon used in label smoothing.
            Reference: https://arxiv.org/abs/1906.02629. Default: 0.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_cls=dict(type='CrossEntropyLoss', loss_factor=1.0),
                 multi_class=False,
                 label_smooth_eps=0.0):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.loss_cls = build_loss(loss_cls)
        self.multi_class = multi_class
        self.label_smooth_eps = label_smooth_eps

    @abstractmethod
    def init_weights(self):
        """Initiate the parameters, either from an existing checkpoint or
        from scratch."""

    @abstractmethod
    def forward(self, x):
        """Defines the computation performed at every call."""

    def loss(self, cls_score, labels):
        """Calculate the loss given output ``cls_score`` and target ``labels``.

        Args:
            cls_score (torch.Tensor): The output of the model.
            labels (torch.Tensor): The target output of the model.

        Returns:
            dict: A dict containing the field 'loss_cls' (mandatory)
                and 'top1_acc', 'top5_acc' (optional).
        """
        losses = dict()
        # Promote a 0-dim label tensor to a batch of size 1.
        if labels.shape == torch.Size([]):
            labels = labels.unsqueeze(0)

        if not self.multi_class:
            # Single-label case: report top-1/top-5 accuracy alongside the loss.
            top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(),
                                       labels.detach().cpu().numpy(), (1, 5))
            losses['top1_acc'] = torch.tensor(
                top_k_acc[0], device=cls_score.device)
            losses['top5_acc'] = torch.tensor(
                top_k_acc[1], device=cls_score.device)
        elif self.label_smooth_eps != 0:
            # Multi-class case: smooth the one-hot targets toward a uniform
            # distribution over the classes.
            labels = ((1 - self.label_smooth_eps) * labels +
                      self.label_smooth_eps / self.num_classes)

        losses['loss_cls'] = self.loss_cls(cls_score, labels)
        return losses
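

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original file. It assumes the
    # project's loss registry resolves 'CrossEntropyLoss' (the default config
    # above), and defines a hypothetical ``DummyHead`` purely to show how a
    # subclass fills in ``init_weights``/``forward`` and what ``loss`` returns.

    class DummyHead(BaseHead):
        """Hypothetical head: a single fully connected classifier."""

        def __init__(self, num_classes, in_channels):
            super().__init__(num_classes, in_channels)
            self.fc = nn.Linear(in_channels, num_classes)

        def init_weights(self):
            nn.init.normal_(self.fc.weight, 0, 0.01)
            nn.init.constant_(self.fc.bias, 0)

        def forward(self, x):
            return self.fc(x)

    head = DummyHead(num_classes=400, in_channels=2048)
    head.init_weights()
    cls_score = head(torch.rand(4, 2048))     # (batch, num_classes)
    labels = torch.randint(0, 400, (4,))      # single-label targets
    losses = head.loss(cls_score, labels)
    # Single-label case: expect 'loss_cls' plus 'top1_acc'/'top5_acc'.
    print(sorted(losses))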