#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Peng Xiang
import torch
from torch import nn, einsum

from models.utils import MLP_Res, grouping_operation, query_knn


class SkipTransformer(nn.Module):
    def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
        super(SkipTransformer, self).__init__()
        # Fuse the two input feature streams into the value / identity path
        self.mlp_v = MLP_Res(in_dim=in_channel * 2, hidden_dim=in_channel, out_dim=in_channel)
        self.n_knn = n_knn

        # Point-wise projections into the shared attention dimension
        self.conv_key = nn.Conv1d(in_channel, dim, 1)
        self.conv_query = nn.Conv1d(in_channel, dim, 1)
        self.conv_value = nn.Conv1d(in_channel, dim, 1)

        # Positional encoding of relative coordinates: (B, 3, N, n_knn) -> (B, dim, N, n_knn)
        self.pos_mlp = nn.Sequential(
            nn.Conv2d(3, pos_hidden_dim, 1),
            nn.BatchNorm2d(pos_hidden_dim),
            nn.ReLU(),
            nn.Conv2d(pos_hidden_dim, dim, 1)
        )

        # MLP that turns (query - key + positional embedding) into per-channel attention logits
        self.attn_mlp = nn.Sequential(
            nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
            nn.BatchNorm2d(dim * attn_hidden_multiplier),
            nn.ReLU(),
            nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
        )

        # Projection back to the input channel count for the residual connection
        self.conv_end = nn.Conv1d(dim, in_channel, 1)
    def forward(self, pos, key, query, include_self=True):
        """
        Args:
            pos: (B, 3, N)
            key: (B, in_channel, N)
            query: (B, in_channel, N)
            include_self: boolean

        Returns:
            Tensor: (B, in_channel, N), shape context feature
        """
        # Fuse key and query features; keep the result as the residual identity
        value = self.mlp_v(torch.cat([key, query], 1))
        identity = value

        key = self.conv_key(key)
        query = self.conv_query(query)
        value = self.conv_value(value)
        b, dim, n = value.shape

        # k nearest neighbours in coordinate space; query_knn expects (B, N, 3)
        pos_flipped = pos.permute(0, 2, 1).contiguous()
        idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self)

        # Group keys by neighbourhood and form relative query-key features
        key = grouping_operation(key, idx_knn)  # (b, dim, n, n_knn)
        qk_rel = query.reshape((b, -1, n, 1)) - key

        # Relative coordinates of each point to its neighbours, embedded into dim channels
        pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn)  # (b, 3, n, n_knn)
        pos_embedding = self.pos_mlp(pos_rel)

        # Vector attention: per-channel weights over the n_knn neighbours
        attention = self.attn_mlp(qk_rel + pos_embedding)  # (b, dim, n, n_knn)
        attention = torch.softmax(attention, -1)

        # Broadcast each point's value over its neighbours and add the positional embedding
        value = value.reshape((b, -1, n, 1)) + pos_embedding  # (b, dim, n, n_knn)

        # Attention-weighted aggregation over neighbours, then project back and add the skip
        agg = einsum('b c i j, b c i j -> b c i', attention, value)  # (b, dim, n)
        y = self.conv_end(agg)
        return y + identity
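

# A minimal usage sketch (not part of the original file): it assumes the
# models.utils helpers imported above are available with the shapes documented
# in forward() — note that query_knn / grouping_operation may require a CUDA
# build in this repo. The sizes and random inputs below are illustrative only.
if __name__ == '__main__':
    B, C, N = 2, 128, 512  # batch size, feature channels, number of points

    layer = SkipTransformer(in_channel=C, dim=256, n_knn=16)

    pos = torch.rand(B, 3, N)    # point coordinates
    key = torch.rand(B, C, N)    # e.g. features from the previous decoding step
    query = torch.rand(B, C, N)  # e.g. features from the current decoding step

    out = layer(pos, key, query)
    print(out.shape)  # expected: torch.Size([2, 128, 512])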