-
Notifications
You must be signed in to change notification settings - Fork 0
/
SDMG_Model.py
32 lines (26 loc) · 984 Bytes
/
SDMG_Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import nn
from SDMG_Head import SDMGRHead
from SDMG_Neck import SdmgNeck
from SDMG_Backbone import UNet
class SDMG_R(nn.Module):
    """Chain a ``UNet`` backbone, an ``SdmgNeck`` and an ``SDMGRHead``.

    The class name suggests the SDMG-R key-information-extraction model, but
    the three submodules are project-local, so their semantics are not
    visible here.  NOTE(review): confirm against SDMG_Backbone / SDMG_Neck /
    SDMG_Head.

    Args:
        base_channels: forwarded as ``UNet(base_channels=...)``.  Defaults to
            16, the value that used to be hard-coded.
        num_chars: forwarded as ``SDMGRHead(num_chars=...)``.  Defaults to 92,
            the previously hard-coded value.
        num_classes: forwarded as ``SDMGRHead(num_classes=...)``.  Defaults to
            26, the previously hard-coded value.
    """

    def __init__(self, base_channels=16, num_chars=92, num_classes=26):
        super().__init__()
        self.backbone = UNet(base_channels=base_channels)
        self.neck = SdmgNeck()
        self.head = SDMGRHead(num_chars=num_chars, num_classes=num_classes)

    def _prepare(self, pic, relations, texts, gt_bboxes):
        """Split the batched per-sample inputs into Python lists.

        ``pic`` is passed through unchanged.  For every index ``i`` along the
        first dimension of ``texts``, the i-th slice of ``relations``,
        ``texts`` and ``gt_bboxes`` is collected into its own list.

        Args:
            pic: image batch, returned as-is.
            relations: indexed as ``relations[i, :, :, :]``, so it must have at
                least 4 dimensions with the batch on dim 0.
            texts: indexed as ``texts[i, :, :]``, so it must have at least 3
                dimensions.  Its first dimension sets the number of samples.
            gt_bboxes: indexed as ``gt_bboxes[i, :, :]``, so it must have at
                least 3 dimensions.

        Returns:
            Tuple ``(pic, rela, txt, bboxes)`` where the last three are lists
            with one entry per sample.
        """
        # The sample count comes from ``texts`` alone.  The other tensors
        # are indexed with the same positions.
        indices = range(len(texts))
        # The explicit full slices keep the original rank requirements: too
        # few dimensions still raises instead of silently succeeding.
        rela = [relations[i, :, :, :] for i in indices]
        txt = [texts[i, :, :] for i in indices]
        bboxes = [gt_bboxes[i, :, :] for i in indices]
        return pic, rela, txt, bboxes

    def forward(self, pic, relations, texts, gt_bboxes):
        """Run backbone -> neck -> head and return the head's output.

        The backbone sees the whole image batch.  The neck receives the
        backbone output plus the per-sample box list.  The head receives the
        per-sample relation and text lists plus the neck output.
        """
        img, rela, txt, bbox = self._prepare(pic, relations, texts, gt_bboxes)
        feats = self.backbone(img)
        feats = self.neck(feats, bbox)
        return self.head(rela, txt, feats)