Updated thinning algo according to NHP's Algorithm 2 #36

Merged · 2 commits · Aug 11, 2024
Changes from all commits
87 changes: 40 additions & 47 deletions easy_tpp/model/torch_model/torch_thinning.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from easy_tpp.utils import logger


class EventSampler(nn.Module):
@@ -34,6 +35,11 @@ def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary,

def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn,
compute_last_step_only):
# logger.critical(f'time_seq: {time_seq}')
# logger.critical(f'time_delta_seq: {time_delta_seq}')
# logger.critical(f'event_seq: {event_seq}')
# logger.critical(f'intensity_fn: {intensity_fn}')
# logger.critical(f'compute_last_step_only: {compute_last_step_only}')
"""Compute the upper bound of intensity at each event timestamp.

Args:
@@ -54,10 +60,10 @@ def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, int
steps=self.num_samples_boundary,
device=self.device)[None, None, :]

# [batch_size, seq_len, num_sample]
# [batch_size, seq_len, num_samples_boundary]
dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled

# [batch_size, seq_len, num_sample, event_num]
# [batch_size, seq_len, num_samples_boundary, event_num]
intensities_for_bound = intensity_fn(time_seq,
time_delta_seq,
event_seq,
@@ -120,34 +126,46 @@ def sample_uniform_distribution(self, intensity_upper_bound):

return unif_numbers

def sample_accept(self, unif_numbers, sample_rate, total_intensities):
def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers):
"""Do the sample-accept process.

For each parallel draw, find its min criterion: if that < 1.0, the 1st (i.e. smallest) sampled time
with cri < 1.0 is accepted; if none is accepted, use boundary / maxsampletime for that draw
For the accumulated exp (delta) samples drawn for each event timestamp, find (from left to right) the first
one that makes the criterion < 1 and accept it as the sampled next-event time. If all exp samples are
rejected (criterion >= 1), the sampled next-event time is set to dtime_max.

Args:
unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number.
sample_rate (tensor): [batch_size, max_len], sample rate (intensity).
total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp]
exp_numbers (tensor): [batch_size, seq_len, num_sample, num_exp], sampled exp numbers (delta in Algorithm 2).

Returns:
list: two tensors,
criterion, [batch_size, max_len, num_sample, num_exp]
who_has_accepted_times, [batch_size, max_len, num_sample]
result (tensor): [batch_size, seq_len, num_sample], sampled next-event times.
"""

# [batch_size, max_len, num_sample, num_exp]
criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities
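# criterion < 1 iff unif * lambda_ub < lambda(t): each candidate time is accepted
# with probability lambda(t) / lambda_ub, as in the classic thinning algorithm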


# [batch_size, max_len, num_sample, num_exp]
masked_crit_less_than_1 = torch.where(criterion < 1, 1, 0)

# [batch_size, max_len, num_sample]
min_cri_each_draw, _ = criterion.min(dim=-1)

# find the draws in which no sample satisfies unif_numbers * sample_rate < intensity
non_accepted_filter = (1 - masked_crit_less_than_1).all(dim=3)
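# non_accepted_filter: [batch_size, max_len, num_sample] bool, True for draws where every exp sample was rejected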

# [batch_size, max_len, num_sample]
who_has_accepted_times = min_cri_each_draw < 1.0

return criterion, who_has_accepted_times
first_accepted_indexer = masked_crit_less_than_1.argmax(dim=3)
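# torch.argmax returns the first occurrence of the maximum, so this is the index of the
# first accepted sample; it is 0 for fully-rejected draws, which torch.where fixes up below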

# [batch_size, max_len, num_sample, 1]
# the indexer must be unsqueezed to 4D to match the number of dimensions of exp_numbers
result_non_accepted_unfiltered = torch.gather(exp_numbers, 3, first_accepted_indexer.unsqueeze(3))

# [batch_size, max_len, num_sample, 1]
result = torch.where(non_accepted_filter.unsqueeze(3), torch.tensor(self.dtime_max), result_non_accepted_unfiltered)

# [batch_size, max_len, num_sample]
result = result.squeeze(dim=-1)

return result

def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary,
intensity_fn, compute_last_step_only=False):
@@ -177,7 +195,8 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou
# we apply fast approximation, i.e., re-use exp sample times for computation
# [batch_size, seq_len, num_exp]
exp_numbers = self.sample_exp_distribution(intensity_upper_bound)

exp_numbers = torch.cumsum(exp_numbers, dim=-1)
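# exp_numbers now hold increasing candidate arrival times of a homogeneous Poisson
# process with rate intensity_upper_bound (the accumulated deltas of Algorithm 2)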

# 3. compute intensity at sampled times from exp distribution
# [batch_size, seq_len, num_exp, event_num]
intensities_at_sampled_times = intensity_fn(time_seq,
@@ -193,46 +212,20 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou
# add one dim of num_sample: re-use the intensity for samples for prediction
# [batch_size, seq_len, num_sample, num_exp]
total_intensities = torch.tile(total_intensities[:, :, None, :], [1, 1, self.num_sample, 1])

# [batch_size, seq_len, num_sample, num_exp]
exp_numbers = torch.tile(exp_numbers[:, :, None, :], [1, 1, self.num_sample, 1])

# 4. draw uniform distribution
# [batch_size, seq_len, num_sample, num_exp]
unif_numbers = self.sample_uniform_distribution(intensity_upper_bound)

# 5. find out accepted intensities
# criterion, [batch_size, max_len, num_sample, num_exp]
# who_has_accepted_times, [batch_size, max_len, num_sample]
criterion, who_has_accepted_times = self.sample_accept(unif_numbers, intensity_upper_bound,
total_intensities)

# 6. find out accepted dtimes
sampled_dtimes_accepted = exp_numbers.clone()

# for unaccepted, use boundary/maxsampletime for that draw
sampled_dtimes_accepted[criterion >= 1.0] = exp_numbers.max() + 1.0

accepted_times_each_draw, accepted_id_each_draw = sampled_dtimes_accepted.min(dim=-1)

# 7. fill out result
dtime_boundary_ = dtime_boundary[:, -1:] if compute_last_step_only else dtime_boundary

# [batch_size, seq_len, num_sample]
dtime_boundary_ = torch.tile(dtime_boundary_[..., None], [1, 1, self.num_sample])

# [batch_size, seq_len, num_sample]
res = torch.ones_like(dtime_boundary_) * dtime_boundary_
res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers)
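# [batch_size, seq_len, num_sample], sampled next-event dtimes (dtime_max where a draw accepted nothing)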

# [batch_size, seq_len, num_sample]
weights = torch.ones_like(dtime_boundary_)
weights /= weights.sum(dim=-1, keepdim=True)

res[who_has_accepted_times] = accepted_times_each_draw[who_has_accepted_times]
who_not_accept = ~who_has_accepted_times

who_reach_further = exp_numbers[..., -1] > dtime_boundary_

res[who_not_accept & who_reach_further] = exp_numbers[..., -1][who_not_accept & who_reach_further]

weights = torch.ones_like(res)/res.shape[2]

# add an upper bound here in case it explodes, e.g., in ODE models
return res.clamp(max=1e5), weights
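
For readers following the diff: below is a minimal, self-contained sketch of the thinning scheme the updated code implements, for a single scalar sequence. The toy intensity function, the bound lambda_ub, and the parameter values are illustrative assumptions for this sketch, not part of the EasyTPP API.

import torch

def intensity(t):
    # hypothetical toy intensity, bounded above by 1.5 for t >= 0
    return 0.5 + torch.exp(-t)

def draw_next_time(lambda_ub=1.5, num_exp=100, dtime_max=50.0):
    # 1. i.i.d. exponential gaps accumulated into candidate arrival times of a
    #    homogeneous Poisson process with rate lambda_ub (cf. the cumsum above)
    gaps = torch.distributions.Exponential(lambda_ub).sample((num_exp,))
    candidates = torch.cumsum(gaps, dim=0)
    # 2. thinning: accept candidate t with probability intensity(t) / lambda_ub
    unif = torch.rand(num_exp)
    accepted = unif * lambda_ub < intensity(candidates)
    if accepted.any():
        # candidates are increasing, so the smallest accepted one is the first
        return candidates[accepted].min()
    # every candidate was rejected: fall back to dtime_max, as in sample_accept
    return torch.tensor(dtime_max)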