diff --git a/easy_tpp/model/torch_model/torch_sahp.py b/easy_tpp/model/torch_model/torch_sahp.py
index 637e2e0..ec02024 100644
--- a/easy_tpp/model/torch_model/torch_sahp.py
+++ b/easy_tpp/model/torch_model/torch_sahp.py
@@ -51,35 +51,36 @@ def __init__(self, model_config):
         if self.use_norm:
             self.norm = nn.LayerNorm(self.d_model)
 
-        # Equation (12): mu
-        self.mu = torch.empty([self.d_model, self.num_event_types], device=self.device)
-        # Equation (13): eta
-        self.eta = torch.empty([self.d_model, self.num_event_types], device=self.device)
-        # Equation (14): gamma
-        self.gamma = torch.empty([self.d_model, self.num_event_types], device=self.device)
-
-        nn.init.xavier_normal_(self.mu)
-        nn.init.xavier_normal_(self.eta)
-        nn.init.xavier_normal_(self.gamma)
-
-    def state_decay(self, encode_state, mu, eta, gamma, duration_t):
+        # Equation (12): mu = GELU(h*W_mu)
+        self.mu = nn.Sequential(
+            nn.Linear(self.d_model, self.num_event_types, bias=False),
+            nn.GELU(),
+        )
+        # Equation (13): eta = GELU(h*W_eta)
+        self.eta = nn.Sequential(
+            nn.Linear(self.d_model, self.num_event_types, bias=False),
+            nn.GELU(),
+        )
+        # Equation (14): gamma = Softplus(h*W_gamma)
+        self.gamma = nn.Sequential(
+            nn.Linear(self.d_model, self.num_event_types, bias=False),
+            nn.Softplus(),
+        )
+
+    def state_decay(self, encode_state, duration_t):
         """Equation (15), which computes the pre-intensity states
 
         Args:
             encode_state (tensor): [batch_size, seq_len, hidden_size].
-            mu (tensor): [batch_size, seq_len, hidden_size].
-            eta (tensor): [batch_size, seq_len, hidden_size].
-            gamma (tensor): [batch_size, seq_len, hidden_size].
             duration_t (tensor): [batch_size, seq_len, num_sample].
 
         Returns:
             tensor: hidden states at event times.
         """
 
+        mu, eta, gamma = self.mu(encode_state), self.eta(encode_state), self.gamma(encode_state)
         # [batch_size, hidden_dim]
-        states = torch.matmul(encode_state, mu) + (
-                torch.matmul(encode_state, eta) - torch.matmul(encode_state, mu)) * torch.exp(
-            -torch.matmul(encode_state, gamma) * torch.clip(duration_t, max=10))  # a temp fix to avoid exploding the exp term
+        states = mu + (eta - mu) * torch.exp(-gamma * duration_t)
         return states
 
     def forward(self, time_seqs, time_delta_seqs, event_seqs, attention_mask):
@@ -122,9 +123,6 @@ def loglike_loss(self, batch):
         enc_out = self.forward(time_seqs[:, :-1], time_delta_seqs[:, 1:], type_seqs[:, :-1], attention_mask[:, 1:, :-1])
 
         cell_t = self.state_decay(encode_state=enc_out,
-                                  mu=self.mu[None, ...],
-                                  eta=self.eta[None, ...],
-                                  gamma=self.gamma[None, ...],
                                   duration_t=time_delta_seqs[:, 1:, None])
 
         # [batch_size, seq_len, num_event_types]
@@ -166,9 +164,6 @@ def compute_states_at_sample_times(self,
 
         """
         cell_states = self.state_decay(encode_state[:, :, None, :],
-                                       self.mu[None, None, ...],
-                                       self.eta[None, None, ...],
-                                       self.gamma[None, None, ...],
                                        sample_dtimes[:, :, :, None])
 
         return cell_states
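
Note (illustrative, not part of the patch): a minimal standalone sketch of the new parameterization for reviewers. The sizes and tensors below are made up for the example; only the head definitions and the equation-(15) line mirror the patched code:

    import torch
    import torch.nn as nn

    d_model, num_event_types = 32, 5  # illustrative sizes
    # Equations (12)-(14): per-event-type decay parameters as learned heads
    mu = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.GELU())
    eta = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.GELU())
    gamma = nn.Sequential(nn.Linear(d_model, num_event_types, bias=False), nn.Softplus())

    h = torch.randn(4, 10, d_model)  # [batch_size, seq_len, d_model] encoder states
    dt = torch.rand(4, 10, 1)        # [batch_size, seq_len, 1] elapsed times

    # Equation (15): decay from eta toward the asymptote mu at rate gamma.
    # Softplus keeps gamma positive, so exp(-gamma * dt) <= 1 for dt >= 0
    # and the old clip(duration_t, max=10) guard becomes unnecessary.
    states = mu(h) + (eta(h) - mu(h)) * torch.exp(-gamma(h) * dt)
    print(states.shape)  # torch.Size([4, 10, 5])

A likely side benefit of the change: plain tensors created with torch.empty and assigned to self are not registered by nn.Module, so the old mu/eta/gamma never appeared in model.parameters() and were not trained; the nn.Sequential heads register their Linear weights as proper parameters.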