⚡️Optimized Transformer
1) `head_token` is re-introduced
2) users can now control whether to use the `head_token` and whether to use the `final_attention` (a sketch of how these flags affect the output width follows below)
3) some default settings are optimized
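
The two new flags interact with the extractor's output width (see `out_dim` in `core.py` below). A minimal sketch of that relationship, using a hypothetical helper name and the default `latent_dim` of 32:

```python
def expected_out_dim(latent_dim: int, use_head_token: bool, use_final_attention: bool) -> int:
    # Mirrors `out_dim` in core.py: the width doubles only when both the head
    # token and the final attention are enabled, because the attention-pooled
    # vector is concatenated with the head token.
    if use_head_token and use_final_attention:
        return 2 * latent_dim
    return latent_dim


assert expected_out_dim(32, use_head_token=True, use_final_attention=False) == 32  # new defaults
assert expected_out_dim(32, use_head_token=True, use_final_attention=True) == 64
```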
carefree0910 committed Mar 23, 2021
1 parent 27f182a commit dc6abc4
Showing 2 changed files with 39 additions and 10 deletions.
8 changes: 5 additions & 3 deletions cflearn/modules/extractors/transformer/configs.py
@@ -11,13 +11,15 @@ def get_default(self) -> Dict[str, Any]:
"num_heads": 4,
"num_layers": 3,
"latent_dim": 32,
"dropout": 0.0,
"norm_type": "layer_norm",
"dropout": 0.1,
"norm_type": "batch_norm",
"attention_type": "decayed",
"encoder_type": "basic",
"input_linear_config": {},
"input_linear_config": {"bias": False},
"layer_config": {"latent_dim": 128},
"encoder_config": {},
"use_head_token": True,
"use_final_attention": False,
}


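For reference, a shallow-merge sketch of how a user override might sit on top of these defaults; `TRANSFORMER_DEFAULTS` and `merge_config` are illustrative names, not part of cflearn's API:

```python
from typing import Any, Dict

# Copy of the new defaults from the diff above (illustrative constant name).
TRANSFORMER_DEFAULTS: Dict[str, Any] = {
    "num_heads": 4,
    "num_layers": 3,
    "latent_dim": 32,
    "dropout": 0.1,
    "norm_type": "batch_norm",
    "attention_type": "decayed",
    "encoder_type": "basic",
    "input_linear_config": {"bias": False},
    "layer_config": {"latent_dim": 128},
    "encoder_config": {},
    "use_head_token": True,
    "use_final_attention": False,
}


def merge_config(overrides: Dict[str, Any]) -> Dict[str, Any]:
    # Shallow merge: user-provided keys win over the defaults.
    merged = dict(TRANSFORMER_DEFAULTS)
    merged.update(overrides)
    return merged


# e.g. keep the head token but also enable the final attention pooling
config = merge_config({"use_final_attention": True})
```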
41 changes: 34 additions & 7 deletions cflearn/modules/extractors/transformer/core.py
@@ -223,13 +223,20 @@ def __init__(
         input_linear_config: Dict[str, Any],
         layer_config: Dict[str, Any],
         encoder_config: Dict[str, Any],
+        use_head_token: bool,
+        use_final_attention: bool,
     ):
         super().__init__(in_flat_dim, dimensions)
         seq_len = dimensions.num_history
         # latent projection
         self.scaling = float(latent_dim) ** 0.5
         self.latent_dim = latent_dim
         self.input_linear = Linear(self.in_dim, latent_dim, **input_linear_config)
+        # head token
+        if not use_head_token:
+            self.head_token = None
+        else:
+            seq_len += 1
+            self.head_token = nn.Parameter(torch.randn(1, 1, latent_dim))
         # position encoding
         self.position_encoding = PositionalEncoding(latent_dim, seq_len, dropout)
         # transformer blocks
@@ -241,27 +248,47 @@ def __init__(
         layer = TransformerLayer(latent_dim, num_heads, **layer_config)
         encoder_base = TransformerEncoder.get(encoder_type)
         self.encoder = encoder_base(layer, num_layers, dimensions, **encoder_config)
-        self.final_attn_linear = nn.Linear(latent_dim, 1)
+        if not use_final_attention:
+            self.final_attn_linear = None
+        else:
+            self.final_attn_linear = nn.Linear(latent_dim, 1)

     @property
     def flatten_ts(self) -> bool:
         return False

     @property
     def out_dim(self) -> int:
-        return self.latent_dim
+        if self.head_token is None:
+            return self.latent_dim
+        if self.final_attn_linear is None:
+            return self.latent_dim
+        return 2 * self.latent_dim

     def _aggregate(self, net: Tensor) -> Tensor:
-        a_hat = self.final_attn_linear(net)
+        last_token = net[..., -1, :]
+        if self.final_attn_linear is None:
+            return last_token
+        if self.head_token is None:
+            no_head_token = net
+        else:
+            no_head_token = net[..., :-1, :]
+        a_hat = self.final_attn_linear(no_head_token)
         a_prob = F.softmax(a_hat, dim=1)
-        return torch.sum(a_prob * net, dim=1)
+        a = torch.sum(a_prob * no_head_token, dim=1)
+        return torch.cat([a, last_token], 1)

     def forward(self, net: Tensor) -> Tensor:
         # input -> latent
         net = self.input_linear(net)
         net = torch.tanh(net)
         net = net * self.scaling
+        # concat head token
+        if self.head_token is not None:
+            expanded_token = self.head_token.expand(net.shape[0], 1, self.latent_dim)
+            net = torch.cat([net, expanded_token], dim=1)
         # encode latent vector with transformer
         net = self.position_encoding(net)
         net = self.encoder(net, None)
         # aggregate
         return self._aggregate(net)
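
To see the head-token and final-attention pieces in isolation, here is a self-contained sketch of the same pattern; `HeadTokenAggregator` is an illustrative stand-in, not the cflearn module (which also involves `PositionalEncoding`, `TransformerLayer`, and the encoder shown above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HeadTokenAggregator(nn.Module):
    """Illustrative re-implementation of the head-token / final-attention logic."""

    def __init__(
        self,
        latent_dim: int,
        use_head_token: bool = True,
        use_final_attention: bool = False,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        # Learnable head token, appended to the end of the sequence (like BERT's [CLS]).
        self.head_token = nn.Parameter(torch.randn(1, 1, latent_dim)) if use_head_token else None
        # Optional single-unit attention head used to pool the remaining timesteps.
        self.final_attn_linear = nn.Linear(latent_dim, 1) if use_final_attention else None

    def append_head_token(self, net: torch.Tensor) -> torch.Tensor:
        # net: [batch, seq_len, latent_dim] -> [batch, seq_len + 1, latent_dim]
        if self.head_token is None:
            return net
        expanded = self.head_token.expand(net.shape[0], 1, self.latent_dim)
        return torch.cat([net, expanded], dim=1)

    def aggregate(self, net: torch.Tensor) -> torch.Tensor:
        # Mirrors `_aggregate` above: take the (encoded) head token, and optionally
        # concatenate it with an attention-weighted sum of the other timesteps.
        last_token = net[..., -1, :]
        if self.final_attn_linear is None:
            return last_token
        no_head = net if self.head_token is None else net[..., :-1, :]
        scores = self.final_attn_linear(no_head)      # [batch, seq_len, 1]
        probs = F.softmax(scores, dim=1)
        pooled = torch.sum(probs * no_head, dim=1)    # [batch, latent_dim]
        return torch.cat([pooled, last_token], dim=1)


# A quick shape check (the transformer encoder itself is elided here):
agg = HeadTokenAggregator(latent_dim=32, use_head_token=True, use_final_attention=True)
x = torch.randn(8, 10, 32)          # [batch, history, latent]
x = agg.append_head_token(x)        # -> [8, 11, 32]
out = agg.aggregate(x)              # -> [8, 64], i.e. 2 * latent_dim
```

With both flags on, the attention-pooled summary of the history is concatenated with the encoded head token, which is why `out_dim` reports `2 * latent_dim` in that case and `latent_dim` otherwise.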


