Merge pull request #27 from zefang-liu/fix/thp
Fix the temporal encoding for THP (PyTorch only)
iLampard authored May 27, 2024
2 parents 5d46df1 + 57ff7c6 · commit 53b7b7f
Showing 1 changed file with 7 additions and 15 deletions.
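
For context, the fix concerns Equation (2) of the THP paper (Zuo et al., ICML 2020), which defines the temporal encoding of an event timestamp $t_j$ in a $d$-dimensional model as

$$[\mathbf{z}(t_j)]_i = \begin{cases} \sin\left(t_j / 10000^{i/d}\right) & i \text{ even}, \\ \cos\left(t_j / 10000^{(i-1)/d}\right) & i \text{ odd}, \end{cases}$$

i.e. a sinusoid of the actual event time. The removed code precomputed a Vaswani-style positional-encoding table indexed by integer sequence position, ignoring the timestamps entirely; the fixed code stores only the per-dimension frequencies (`div_term`) and applies them to the observed event times in `forward`.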
easy_tpp/model/torch_model/torch_baselayer.py

@@ -98,18 +98,9 @@ class TimePositionalEncoding(nn.Module):
 
     def __init__(self, d_model, max_len=5000, device='cpu'):
         super().__init__()
-
-        pe = torch.zeros(max_len, d_model, device=device).float()
-        position = torch.arange(0, max_len, device=device).float().unsqueeze(1)
-        div_term = (torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)).exp()
-
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-
-        # [1, max_len, d_model]
-        pe = pe.unsqueeze(0)
-
-        self.register_buffer('pe', pe)
+        i = torch.arange(0, d_model, 1, device=device)
+        div_term = (2 * (i // 2).float() * -(math.log(10000.0) / d_model)).exp()
+        self.register_buffer('div_term', div_term)
 
     def forward(self, x):
         """Compute time positional encoding defined in Equation (2) in THP model.
@@ -121,9 +112,10 @@ def forward(self, x):
         Returns:
             temporal encoding vector, [batch_size, seq_len, model_dim]
         """
-        length = x.size(1)
-
-        return self.pe[:, :length]
+        result = x.unsqueeze(-1) * self.div_term
+        result[:, :, 0::2] = torch.sin(result[:, :, 0::2])
+        result[:, :, 1::2] = torch.cos(result[:, :, 1::2])
+        return result
 
 
 class TimeShiftedPositionalEncoding(nn.Module):
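
A minimal usage sketch of the fixed encoding (assuming `easy_tpp` is installed and the class is importable from the file path shown in the diff; the timestamps below are illustrative, not from the commit):

```python
import torch
from easy_tpp.model.torch_model.torch_baselayer import TimePositionalEncoding

# Irregularly spaced event times: the fixed encoding depends on these values,
# whereas the old implementation depended only on the integer positions 0..3.
event_times = torch.tensor([[0.0, 0.7, 2.3, 5.1]])  # [batch_size=1, seq_len=4]

enc = TimePositionalEncoding(d_model=8)
out = enc(event_times)
print(out.shape)  # torch.Size([1, 4, 8])

# Rescaling the timestamps now changes the encoding; with the old
# index-based table it would not have.
print(torch.allclose(out, enc(2 * event_times)))  # False
```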
