
Commit

Merge pull request #38 from ajboyd2/patch-1
Corrected SAHP.
iLampard authored Sep 10, 2024
2 parents 940b5a6 + c5d1a2a commit a2e254d
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions easy_tpp/model/torch_model/torch_sahp.py
@@ -51,35 +51,36 @@ def __init__(self, model_config):
if self.use_norm:
self.norm = nn.LayerNorm(self.d_model)

# Equation (12): mu
self.mu = torch.empty([self.d_model, self.num_event_types], device=self.device)
# Equation (13): eta
self.eta = torch.empty([self.d_model, self.num_event_types], device=self.device)
# Equation (14): gamma
self.gamma = torch.empty([self.d_model, self.num_event_types], device=self.device)

nn.init.xavier_normal_(self.mu)
nn.init.xavier_normal_(self.eta)
nn.init.xavier_normal_(self.gamma)

def state_decay(self, encode_state, mu, eta, gamma, duration_t):
# Equation (12): mu = GELU(h*W_mu)
self.mu = nn.Sequential(
nn.Linear(self.d_model, self.num_event_types, bias=False),
nn.GELU(),
)
# Equation (13): eta = GELU(h*W_eta)
self.eta = nn.Sequential(
nn.Linear(self.d_model, self.num_event_types, bias=False),
nn.GELU(),
)
# Equation (14): gamma = Softplus(h*W_gamma)
self.gamma = nn.Sequential(
nn.Linear(self.d_model, self.num_event_types, bias=False),
nn.Softplus(),
)

def state_decay(self, encode_state, duration_t):
"""Equation (15), which computes the pre-intensity states
Args:
encode_state (tensor): [batch_size, seq_len, hidden_size].
mu (tensor): [batch_size, seq_len, hidden_size].
eta (tensor): [batch_size, seq_len, hidden_size].
gamma (tensor): [batch_size, seq_len, hidden_size].
duration_t (tensor): [batch_size, seq_len, num_sample].
Returns:
tensor: hidden states at event times.
"""
mu, eta, gamma = self.mu(encode_state), self.eta(encode_state), self.gamma(encode_state)

# [batch_size, hidden_dim]
states = torch.matmul(encode_state, mu) + (
torch.matmul(encode_state, eta) - torch.matmul(encode_state, mu)) * torch.exp(
-torch.matmul(encode_state, gamma) * torch.clip(duration_t, max=10)) # a temp fix to avoid exploding the exp term
states = mu + (eta - mu) * torch.exp(-gamma * duration_t)
return states

def forward(self, time_seqs, time_delta_seqs, event_seqs, attention_mask):
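
In summary, the new heads above implement Equations (12)-(14) and the new state_decay implements Equation (15). Restating the in-code comments, with h the encoder hidden state, W_mu, W_eta, W_gamma the weights of the three bias-free linear layers, and t the time elapsed since the previous event:

    mu       = GELU(h W_mu)
    eta      = GELU(h W_eta)
    gamma    = Softplus(h W_gamma)
    state(t) = mu + (eta - mu) * exp(-gamma * t)
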
@@ -122,9 +123,6 @@ def loglike_loss(self, batch):
enc_out = self.forward(time_seqs[:, :-1], time_delta_seqs[:, 1:], type_seqs[:, :-1], attention_mask[:, 1:, :-1])

cell_t = self.state_decay(encode_state=enc_out,
mu=self.mu[None, ...],
eta=self.eta[None, ...],
gamma=self.gamma[None, ...],
duration_t=time_delta_seqs[:, 1:, None])

# [batch_size, seq_len, num_event_types]
@@ -166,9 +164,6 @@ def compute_states_at_sample_times(self,
"""

cell_states = self.state_decay(encode_state[:, :, None, :],
self.mu[None, None, ...],
self.eta[None, None, ...],
self.gamma[None, None, ...],
sample_dtimes[:, :, :, None])

return cell_states
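
A minimal standalone sketch of the corrected computation, for reference. The class name SAHPHead, the toy dimensions, and the random inputs are illustrative only and not part of easy_tpp; the real heads live inside the model class in torch_sahp.py, which also applies layer normalization, masking, and the intensity softplus.

    # Sketch of the corrected heads and state decay (Equations (12)-(15)).
    import torch
    import torch.nn as nn

    class SAHPHead(nn.Module):
        def __init__(self, d_model=32, num_event_types=5):
            super().__init__()
            # Equations (12)-(14): event-type-specific heads computed from the hidden state.
            self.mu = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.GELU())
            self.eta = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.GELU())
            self.gamma = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.Softplus())

        def state_decay(self, encode_state, duration_t):
            # Equation (15): the state starts at eta right after an event and
            # decays toward mu at rate gamma as duration_t grows.
            mu, eta, gamma = self.mu(encode_state), self.eta(encode_state), self.gamma(encode_state)
            return mu + (eta - mu) * torch.exp(-gamma * duration_t)

    head = SAHPHead()
    enc_out = torch.randn(4, 10, 32)   # [batch_size, seq_len, d_model]
    dt = torch.rand(4, 10, 1)          # [batch_size, seq_len, 1], as in time_delta_seqs[:, 1:, None]
    print(head.state_decay(enc_out, dt).shape)  # torch.Size([4, 10, 5])

One likely motivation for the patch: the old torch.empty tensors were never wrapped in nn.Parameter, so they were not registered with the module, trained, or saved in checkpoints; defining mu, eta, and gamma as nn.Linear projections inside nn.Sequential makes them ordinary learnable sub-modules.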
