Parallelized DiscreteHMM.sample() (#3053)
ordabayevy authored Mar 22, 2022
1 parent 8ac6c46 commit 8bc4cd1
Showing 1 changed file with 84 additions and 11 deletions.
95 changes: 84 additions & 11 deletions pyro/distributions/hmm.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import torch
+import torch.nn.functional as F

from pyro.ops.gamma_gaussian import (
    GammaGaussian,
@@ -82,6 +83,82 @@ def _sequential_logmatmulexp(logits):
    return logits.squeeze(-3)


+def _markov_index(x, y):
+    """
+    Join ends of two Markov paths.
+    """
+    # Continue each path in ``y`` from the end state of the corresponding
+    # path in ``x``, then concatenate the two segments along time.
+    y = Vindex(y.unsqueeze(-2))[..., x[..., -1:, :]]
+    return torch.cat([x, y], -2)
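
To see what the join computes, here is a minimal trace with made-up values (an illustrative sketch, not part of the commit)::

    import torch

    from pyro.distributions.hmm import _markov_index

    # x[t, s] = state at step t of the segment that starts in state s.
    x = torch.tensor([[1, 0],
                      [0, 1]])  # steps 0-1
    y = torch.tensor([[1, 1],
                      [0, 1]])  # steps 2-3

    # Each column of y is re-indexed by the end state of the matching
    # column of x, then the segments are concatenated along time.
    joined = _markov_index(x, y)
    assert joined.tolist() == [[1, 0], [0, 1], [1, 1], [0, 1]]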


+def _sequential_index(samples):
+    """
+    For a tensor ``samples`` whose time dimension is -2 and whose state
+    dimension is -1, compute Markov paths by sequential indexing.
+    For example, for ``samples`` with 3 states and time duration 5::
+
+        tensor([[0, 1, 1],
+                [1, 0, 2],
+                [2, 1, 0],
+                [0, 2, 1],
+                [1, 1, 0]])
+
+    computed paths are::
+
+        tensor([[0, 1, 1],
+                [1, 0, 0],
+                [1, 2, 2],
+                [2, 1, 1],
+                [0, 1, 1]])
+
+    # path for the 0th state
+    #
+    # 0 1 1
+    # |
+    # 1 0 2
+    #  \
+    # 2 1 0
+    #   |
+    # 0 2 1
+    #    \
+    # 1 1 0
+    #
+    # paths for the 1st and 2nd states
+    #
+    # 0 1 1
+    #   |/
+    # 1 0 2
+    #  /
+    # 2 1 0
+    #  \
+    #   \
+    # 0 2 1
+    #    /
+    # 1 1 0
+    """
+    # New Markov time dimension at -2: each slot along dim -3 now holds a
+    # path segment, initially of length 1.
+    samples = samples.unsqueeze(-2)
+    batch_shape = samples.shape[:-3]
+    state_dim = samples.size(-1)
+    duration = samples.size(-3)
+    while samples.size(-3) > 1:
+        time = samples.size(-3)
+        even_time = time // 2 * 2
+        even_part = samples[..., :even_time, :, :]
+        x_y = even_part.reshape(batch_shape + (even_time // 2, 2, -1, state_dim))
+        x, y = x_y.unbind(-3)
+        # Join adjacent segment pairs, halving the number of segments.
+        contracted = _markov_index(x, y)
+        if time > even_time:
+            # Odd leftover segment: F.pad pads trailing dims first, so
+            # pad=(0, 0, 0, K) grows the segment dimension (-2) by K junk
+            # zeros; the final slice below discards them.
+            padded = F.pad(
+                input=samples[..., -1:, :, :],
+                pad=(0, 0, 0, contracted.size(-2) // 2),
+            )
+            contracted = torch.cat((contracted, padded), dim=-3)
+        samples = contracted
+    return samples.squeeze(-3)[..., :duration, :]
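
The contraction can be checked against a plain sequential loop on the docstring's example (an illustrative snippet, assuming the module path ``pyro.distributions.hmm``)::

    import torch

    from pyro.distributions.hmm import _sequential_index

    samples = torch.tensor([[0, 1, 1],
                            [1, 0, 2],
                            [2, 1, 0],
                            [0, 2, 1],
                            [1, 1, 0]])

    # Sequential reference: follow each of the 3 start states step by step.
    state = torch.arange(3)
    expected = []
    for t in range(5):
        state = samples[t][state]
        expected.append(state)

    assert _sequential_index(samples).tolist() == torch.stack(expected).tolist()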


def _sequential_gaussian_tensordot(gaussian):
"""
Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::
@@ -276,7 +353,7 @@ class DiscreteHMM(HiddenMarkovModel):
    distribution.

    This uses [1] to parallelize over time, achieving O(log(time)) parallel
-    complexity for computing :meth:`log_prob` and :meth:`filter`.
+    complexity for computing :meth:`log_prob`, :meth:`filter`, and :meth:`sample`.

    The event_shape of this distribution includes time on the left::
@@ -292,10 +369,6 @@ class DiscreteHMM(HiddenMarkovModel):
        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

-    The :meth:`sample` method is sequential (not parallelized), slow, and memory
-    inefficient. It is intended for data generation only and is not recommended
-    during inference.

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
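
For concreteness, a homogeneous model can be built as follows (an illustrative sketch with made-up parameters, assuming the ``duration`` argument of the current API)::

    import torch
    import pyro.distributions as dist

    num_states = 3
    init_logits = torch.zeros(num_states)
    trans_logits = torch.randn(num_states, num_states)
    obs_dist = dist.Normal(torch.randn(num_states), torch.ones(num_states))

    hmm = dist.DiscreteHMM(init_logits, trans_logits, obs_dist, duration=10)
    print(hmm.event_shape)  # torch.Size([10]): time on the left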
@@ -441,13 +514,13 @@ def sample(self, sample_shape=torch.Size()):
        x = Categorical(logits=init_logits).sample()

        # Sample hidden states over time.
-        trans_shape = self.batch_shape + (self.duration, S, S)
+        trans_shape = (
+            torch.Size(sample_shape) + self.batch_shape + (self.duration, S, S)
+        )
        trans_logits = self.transition_logits.expand(trans_shape)
-        xs = []
-        for t in range(self.duration):
-            x = Categorical(logits=Vindex(trans_logits)[..., t, x, :]).sample()
-            xs.append(x)
-        x = torch.stack(xs, dim=-1)
+        # Draw all duration x state transitions at once, compose them into
+        # Markov paths in O(log(duration)) parallel steps, then select the
+        # path starting at the sampled initial state.
+        xs = Categorical(logits=trans_logits).sample()
+        xs = _sequential_index(xs)
+        x = Vindex(xs)[..., :, x]

        # Sample observations conditioned on hidden states.
        # Note the simple sample-then-slice approach here generalizes to all
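
End to end, the new code path means a batch of draws no longer loops over time in Python (illustrative usage, reusing the hypothetical model shapes above)::

    import torch
    import pyro.distributions as dist

    hmm = dist.DiscreteHMM(
        torch.zeros(3),                              # initial logits
        torch.randn(3, 3),                           # homogeneous transitions
        dist.Normal(torch.randn(3), torch.ones(3)),  # per-state observations
        duration=100,
    )
    y = hmm.sample(torch.Size([8]))
    print(y.shape)  # torch.Size([8, 100]): sample_shape + event_shape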
