Merge pull request #48 from yuxinc17/patch-models
Fixed off-by-one and other issues affecting the models, log-likelihood computation, and predictions.
iLampard authored Oct 14, 2024
2 parents d4731dd + f55ed90 commit f1583af
Showing 15 changed files with 486 additions and 561 deletions.
6 changes: 2 additions & 4 deletions easy_tpp/config_factory/model_config.py
@@ -202,9 +202,9 @@ def __init__(self, **kwargs):
self.time_emb_size = kwargs.get('time_emb_size', 16)
self.num_layers = kwargs.get('num_layers', 2)
self.num_heads = kwargs.get('num_heads', 2)
- self.mc_num_sample_per_step = kwargs.get('mc_num_sample_per_step', 20)
self.sharing_param_layer = kwargs.get('sharing_param_layer', False)
- self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20)
+ self.use_mc_samples = kwargs.get('use_mc_samples', True)  # if using MC samples in computing log-likelihood
+ self.loss_integral_num_sample_per_step = kwargs.get('loss_integral_num_sample_per_step', 20)  # replaces mc_num_sample_per_step
self.dropout_rate = kwargs.get('dropout_rate', 0.0)
self.use_ln = kwargs.get('use_ln', False)
self.thinning = ThinningConfig.parse_from_yaml_config(kwargs.get('thinning'))
@@ -227,7 +227,6 @@ def get_yaml_config(self):
'hidden_size': self.hidden_size,
'time_emb_size': self.time_emb_size,
'num_layers': self.num_layers,
- 'mc_num_sample_per_step': self.mc_num_sample_per_step,
'sharing_param_layer': self.sharing_param_layer,
'loss_integral_num_sample_per_step': self.loss_integral_num_sample_per_step,
'dropout_rate': self.dropout_rate,
@@ -265,7 +264,6 @@ def copy(self):
hidden_size=self.hidden_size,
time_emb_size=self.time_emb_size,
num_layers=self.num_layers,
- mc_num_sample_per_step=self.mc_num_sample_per_step,
sharing_param_layer=self.sharing_param_layer,
loss_integral_num_sample_per_step=self.loss_integral_num_sample_per_step,
dropout_rate=self.dropout_rate,
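
The config change removes the redundant mc_num_sample_per_step in favour of loss_integral_num_sample_per_step and adds a use_mc_samples switch that the base model reads when computing the log-likelihood (see torch_basemodel.py below). A minimal sketch of the resulting kwargs surface, with hypothetical values and assuming the patched class is exposed as ModelConfig in easy_tpp.config_factory:

from easy_tpp.config_factory import ModelConfig

model_config = ModelConfig(
    hidden_size=32,
    time_emb_size=16,
    use_mc_samples=True,                   # True: Monte Carlo mean; False: trapezoid rule
    loss_integral_num_sample_per_step=20,  # replaces the removed mc_num_sample_per_step
)

Both fields default as shown, so existing configs stay valid after the rename.
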
29 changes: 14 additions & 15 deletions easy_tpp/model/torch_model/torch_attnhp.py
@@ -3,7 +3,7 @@
import torch
from torch import nn

- from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention
+ from easy_tpp.model.torch_model.torch_baselayer import EncoderLayer, MultiHeadAttention, ScaledSoftplus
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel


@@ -52,7 +52,7 @@ def __init__(self, model_config):
if self.use_norm:
self.norm = nn.LayerNorm(self.d_model)
self.inten_linear = nn.Linear(self.d_model * self.n_head, self.num_event_types)
- self.softplus = nn.Softplus()
+ self.softplus = ScaledSoftplus(self.num_event_types)  # learnable mark-specific beta
self.layer_event_emb = nn.Linear(self.d_model + self.d_time, self.d_model)
self.layer_intensity = nn.Sequential(self.inten_linear, self.softplus)
self.eps = torch.finfo(torch.float32).eps
@@ -151,7 +151,7 @@ def make_layer_mask(self, attention_mask):
a diagonal matrix, [batch_size, seq_len, seq_len]
"""
# [batch_size, seq_len, seq_len]
- layer_mask = (torch.eye(attention_mask.size(1)) < 1).unsqueeze(0).expand_as(attention_mask).to(attention_mask.device)
+ layer_mask = (torch.eye(attention_mask.size(1), device=self.device) < 1).unsqueeze(0).expand_as(attention_mask)
return layer_mask

def make_combined_att_mask(self, attention_mask, layer_mask):
@@ -205,11 +205,11 @@ def loglike_loss(self, batch):
Returns:
list: loglike loss, num events.
"""
- time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask, type_mask = batch
+ time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, attention_mask = batch
# 1. compute event-loglik
# the prediction of last event has no label, so we proceed to the last but one
# att mask => diag is False, not mask.
- enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, 1:, :-1], time_seqs[:, 1:])
+ enc_out = self.forward(time_seqs[:, :-1], type_seqs[:, :-1], attention_mask[:, :-1, :-1], time_seqs[:, 1:])
# [batch_size, seq_len, num_event_types]
lambda_at_event = self.layer_intensity(enc_out)

@@ -227,17 +227,16 @@
time_delta_seqs[:, :-1], # not used
type_seqs[:, :-1],
sample_times,
- attention_mask=attention_mask[:, 1:, :-1])
+ attention_mask=attention_mask[:, :-1, :-1])

event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
lambdas_loss_samples=lambda_t_sample,
time_delta_seq=time_delta_seqs[:, 1:],
seq_mask=batch_non_pad_mask[:, 1:],
- lambda_type_mask=type_mask[:, 1:])
+ type_seq=type_seqs[:, 1:])

- # return enc_inten to compute accuracy
# compute loss to minimize
loss = - (event_ll - non_event_ll).sum()

return loss, num_events

def compute_states_at_sample_times(self,
@@ -285,7 +284,7 @@ def compute_states_at_sample_times(self,
encoder_output = encoder_output.permute((1, 2, 0, 3))
return encoder_output

- def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs):
+ def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
"""Compute the intensity at sampled times.
Args:
@@ -302,17 +301,17 @@

if attention_mask is None:
batch_size, seq_len = time_seqs.size()
- attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
- attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool).to(time_seqs.device)
+ attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0).to(type_seqs.device)
+ attention_mask = attention_mask.expand(batch_size, -1, -1).to(torch.bool)

- if sample_times.size()[1] < time_seqs.size()[1]:
+ if sample_dtimes.size()[1] < time_seqs.size()[1]:
# we pass sample_dtimes for last time step here
# we do a temp solution
# [batch_size, seq_len, num_samples]
- sample_times = time_seqs[:, :, None] + torch.tile(sample_times, [1, time_seqs.size()[1], 1])
+ sample_dtimes = time_seqs[:, :, None] + torch.tile(sample_dtimes, [1, time_seqs.size()[1], 1])

# [batch_size, seq_len, num_samples, hidden_size]
- encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_times)
+ encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_dtimes)

if compute_last_step_only:
lambdas = self.layer_intensity(encoder_output[:, -1:, :, :])
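
The mask-slice change in loglike_loss above is the core off-by-one fix: with attention_mask[:, 1:, :-1], the query for the intensity at the next event time time_seqs[:, 1:] used the mask row of that next event, so it could attend the very event whose likelihood it was scoring. A standalone illustration of the two slices, in plain PyTorch (True means masked, matching the repo convention that the diagonal is False, i.e. not masked):

import torch

seq_len = 4
# Strictly causal mask: entry (i, j) is True (masked) when j > i.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(torch.bool)

old = mask[1:, :-1]   # rows shifted one step ahead of the keys time_seqs[:, :-1]
new = mask[:-1, :-1]  # rows aligned with the keys

print((~old[0]).nonzero().flatten())  # tensor([0, 1]): the first query already sees event 1, its own label
print((~new[0]).nonzero().flatten())  # tensor([0]):    the first query sees only event 0

The sample_times to sample_dtimes rename and the device fixes leave semantics unchanged; the mask slice is what alters the computed log-likelihood.
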
22 changes: 22 additions & 0 deletions easy_tpp/model/torch_model/torch_baselayer.py
@@ -16,6 +16,28 @@ def attention(query, key, value, mask=None, dropout=None):
return torch.matmul(p_attn, value), p_attn


+ class ScaledSoftplus(nn.Module):
+     '''
+     Use different beta for mark-specific intensities
+     '''
+     def __init__(self, num_marks, threshold=20.):
+         super(ScaledSoftplus, self).__init__()
+         self.threshold = threshold
+         self.log_beta = nn.Parameter(torch.zeros(num_marks), requires_grad=True)  # [num_marks]
+ 
+     def forward(self, x):
+         '''
+         :param x: [..., num_marks]
+         '''
+         beta = self.log_beta.exp()
+         beta_x = beta * x
+         return torch.where(
+             beta_x <= self.threshold,
+             torch.log1p(beta_x.clamp(max=math.log(1e5)).exp()) / beta,
+             x,  # if above threshold, the transform is effectively linear
+         )


class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_input, d_model, dropout=0.1, output_linear=False):
        super(MultiHeadAttention, self).__init__()
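
ScaledSoftplus generalises softplus(x) = log(1 + exp(beta * x)) / beta with one learnable beta per mark, initialised at log_beta = 0, i.e. beta = 1, where it coincides with the standard softplus (the class body relies on the module's existing math import). A quick standalone sanity check:

import torch
from torch.nn import functional as F
from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus

sp = ScaledSoftplus(num_marks=3)  # log_beta starts at zeros, so beta = 1
x = torch.randn(5, 3)             # [..., num_marks]

# At beta = 1 the output matches the standard softplus; the internal clamp
# only bites for inputs above log(1e5) (about 11.5), which randn will not hit here.
assert torch.allclose(sp(x), F.softplus(x), atol=1e-5)

print(sp.log_beta.shape)  # torch.Size([3]): one curvature parameter per event type

Since AttNHP's intensity head is nn.Sequential(inten_linear, softplus), each mark's intensity now gets its own learnable softness, which a single shared nn.Softplus() could not express.
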
82 changes: 45 additions & 37 deletions easy_tpp/model/torch_model/torch_basemodel.py
@@ -2,6 +2,7 @@

import torch
from torch import nn
+ from torch.nn import functional as F

from easy_tpp.model.torch_model.torch_thinning import EventSampler
from easy_tpp.utils import set_device
@@ -29,6 +30,7 @@ def __init__(self, model_config):
self.gen_config = model_config.thinning
self.event_sampler = None
self.device = set_device(model_config.gpu)
+ self.use_mc_samples = model_config.use_mc_samples

self.to(self.device)

@@ -81,8 +83,7 @@ def get_logits_at_last_step(logits, batch_non_pad_mask, sample_len=None):
last_logits = torch.gather(logits, dim=1, index=select_index).squeeze(1)
return last_logits

- def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask,
-                           lambda_type_mask):
+ def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask, type_seq):
"""Compute the loglikelihood of the event sequence based on Equation (8) of NHP paper.
Args:
@@ -92,38 +93,35 @@ def compute_loglikelihood(self, time_delta_seq, lambda_at_event, lambdas_loss_samples, seq_mask,
lambdas_loss_samples (tensor): [batch_size, seq_len, num_sample, num_event_types],
intensity at sampling times.
seq_mask (tensor): [batch_size, seq_len], sequence mask vector to mask the padded events.
- lambda_type_mask (tensor): [batch_size, seq_len, num_event_types], type mask matrix to mask the
-     padded event types.
+ type_seq (tensor): [batch_size, seq_len], sequence of mark ids, with padded events having a mark of self.pad_token_id
Returns:
tuple: event loglike, non-event loglike, intensity at event with padding events masked
"""

- # Sum of lambda over every type and every event point
- # [batch_size, seq_len]
- event_lambdas = torch.sum(lambda_at_event * lambda_type_mask, dim=-1) + self.eps
- 
- # mask the pad event
- event_lambdas = event_lambdas.masked_fill_(~seq_mask, 1.0)
- 
- # [batch_size, seq_len]
- event_ll = torch.log(event_lambdas)
+ # First, add an epsilon to every marked intensity for stability
+ lambda_at_event = lambda_at_event + self.eps
+ lambdas_loss_samples = lambdas_loss_samples + self.eps

# Compute the big lambda integral in equation (8) of NHP paper
# 1 - take num_mc_sample rand points in each event interval
# 2 - compute its lambda value for every sample point
# 3 - take average of these sample points
# 4 - times the interval length
+ log_marked_event_lambdas = lambda_at_event.log()
+ total_sampled_lambdas = lambdas_loss_samples.sum(dim=-1)

- # [batch_size, seq_len, n_loss_sample]
- lambdas_total_samples = lambdas_loss_samples.sum(dim=-1)
+ # Compute event LL - [batch_size, seq_len]
+ event_ll = -F.nll_loss(
+     log_marked_event_lambdas.permute(0, 2, 1),  # the mark dimension needs to come second, not third, to match nll_loss specs
+     target=type_seq,
+     ignore_index=self.pad_token_id,  # padded events carry pad_token_id as their mark
+     reduction='none',  # does not aggregate, and replaces what would have been the log(marked intensity) with 0
+ )

- # interval_integral - [batch_size, seq_len]
+ # Compute non-event LL [batch_size, seq_len]
# interval_integral = length_interval * average of sampled lambda(t)
- non_event_ll = lambdas_total_samples.mean(dim=-1) * time_delta_seq * seq_mask
+ if self.use_mc_samples:
+     non_event_ll = total_sampled_lambdas.mean(dim=-1) * time_delta_seq * seq_mask
+ else:  # use trapezoid rule
+     non_event_ll = 0.5 * (total_sampled_lambdas[..., 1:] + total_sampled_lambdas[..., :-1]).mean(dim=-1) * time_delta_seq * seq_mask

num_events = torch.masked_select(event_ll, event_ll.ne(0.0)).size()[0]

return event_ll, non_event_ll, num_events

def make_dtime_loss_samples(self, time_delta_seq):
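
Both branches of the non-event term estimate the same compensator integral of the total intensity over each inter-event interval, from the same evenly spaced samples; the trapezoid branch simply averages adjacent samples before scaling by the interval length. The event term, meanwhile, now gathers log-intensities via F.nll_loss with ignore_index, zeroing padded positions just as the old masked_fill-then-log did. A toy single-interval comparison of the two estimators (a standalone sketch, not repo code):

import math
import torch

dt = 2.0
t = torch.linspace(0.0, dt, steps=20)  # evenly spaced sample points in the interval
sampled_lambdas = torch.exp(-t)        # toy total intensity lambda(t) = exp(-t)

mc = sampled_lambdas.mean() * dt       # use_mc_samples=True: plain mean times length
trap = 0.5 * (sampled_lambdas[1:] + sampled_lambdas[:-1]).mean() * dt  # use_mc_samples=False

exact = 1.0 - math.exp(-dt)            # closed form for the toy intensity
print(mc.item(), trap.item(), exact)   # approx 0.8789, 0.8654, 0.8647

For a smooth intensity the trapezoid estimate removes most of the bias of the plain mean at no extra sampling cost, which is what the new use_mc_samples flag exposes.
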
@@ -160,37 +158,47 @@ def predict_one_step_at_every_event(self, batch):
Returns:
tuple: tensors of dtime and type prediction, [batch_size, seq_len].
"""
- time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _, type_mask = batch
+ time_seq, time_delta_seq, event_seq, batch_non_pad_mask, _ = batch

# remove the last event, as the prediction based on the last event has no label
- # time_delta_seq should start from 1, because the first one is zero
- time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, 1:], event_seq[:, :-1]
+ # note: the first dts is 0
+ # [batch_size, seq_len]
+ time_seq, time_delta_seq, event_seq = time_seq[:, :-1], time_delta_seq[:, :-1], event_seq[:, :-1]

# [batch_size, seq_len]
- dtime_boundary = time_delta_seq + self.event_sampler.dtime_max
+ dtime_boundary = torch.max(time_delta_seq * self.event_sampler.dtime_max,
+                            time_delta_seq + self.event_sampler.dtime_max)

# [batch_size, seq_len, num_sample]
accepted_dtimes, weights = self.event_sampler.draw_next_time_one_step(time_seq,
time_delta_seq,
event_seq,
dtime_boundary,
- self.compute_intensities_at_sample_times)
+ self.compute_intensities_at_sample_times,
+ compute_last_step_only=False)  # make it explicit

- # [batch_size, seq_len]
- dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1)
- 
- # [batch_size, seq_len, 1, event_num]
+ # Condition the mark prediction on each accepted time, not on the expected event time.
+ # 1. Use all accepted_dtimes to get intensities.
+ # [batch_size, seq_len, num_sample, num_marks]
intensities_at_times = self.compute_intensities_at_sample_times(time_seq,
time_delta_seq,
event_seq,
- dtimes_pred[:, :, None],
- max_steps=event_seq.size()[1])
+ accepted_dtimes)

- # [batch_size, seq_len, event_num]
- intensities_at_times = intensities_at_times.squeeze(dim=-2)
+ # 2. Normalize the intensities over the last dim, then compute the weighted sum over the `num_sample` dimension.
+ #    Each slice along the last dimension is a categorical distribution over all marks.
+ # [batch_size, seq_len, num_sample, num_marks]
+ intensities_normalized = intensities_at_times / intensities_at_times.sum(dim=-1, keepdim=True)

- types_pred = torch.argmax(intensities_at_times, dim=-1)
+ # 3. Compute the weighted sum of distributions, then take the argmax.
+ # [batch_size, seq_len, num_marks]
+ intensities_weighted = torch.einsum('...s,...sm->...m', weights, intensities_normalized)
+ 
+ # [batch_size, seq_len]
+ types_pred = torch.argmax(intensities_weighted, dim=-1)
+ 
+ # [batch_size, seq_len]
+ dtimes_pred = torch.sum(accepted_dtimes * weights, dim=-1)  # compute the expected next event time
return dtimes_pred, types_pred

def predict_multi_step_since_last_event(self, batch, forward=False):
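
The rewritten predictor marginalises the mark distribution over the sampled candidate times instead of conditioning on their expectation: each accepted time t_s carries a thinning weight w_s, the intensities at t_s are normalised into a categorical over marks, and the einsum forms the mixture sum_s w_s * p(mark | t_s). A toy shape check (a standalone sketch, not repo code):

import torch

batch, seq, num_sample, num_marks = 2, 3, 4, 5
weights = torch.softmax(torch.randn(batch, seq, num_sample), dim=-1)  # sums to 1 over samples
intensities = torch.rand(batch, seq, num_sample, num_marks) + 1e-6    # positive intensities lambda_k(t_s)

p_mark_given_t = intensities / intensities.sum(dim=-1, keepdim=True)  # categorical per sampled time
p_mark = torch.einsum('...s,...sm->...m', weights, p_mark_given_t)    # mixture over sampled times

assert torch.allclose(p_mark.sum(dim=-1), torch.ones(batch, seq), atol=1e-5)
types_pred = p_mark.argmax(dim=-1)  # [batch_size, seq_len]

Because p(mark | t) can shift with t, averaging the per-time distributions before the argmax propagates the uncertainty over the next event time into the mark prediction, whereas evaluating the intensity once at the expected time, as the old dtimes_pred[:, :, None] call did, does not. The widened dtime_boundary plays a supporting role by letting the thinning sampler cover long intervals.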