Commit 6312b75

update docs, fix hidden features handling

janfb committed Aug 20, 2024
1 parent 2cece78 commit 6312b75

Showing 2 changed files with 29 additions and 18 deletions.
46 changes: 29 additions & 17 deletions pyknos/mdn/mdn.py
@@ -30,9 +30,9 @@ def __init__(
         self,
         features: int,
         context_features: int,
-        hidden_net: nn.Module,
+        hidden_net: Optional[nn.Module],
         num_components: int,
-        hidden_features: Optional[int],
+        hidden_features: Optional[int] = None,
         custom_initialization: bool = False,
         embedding_net: Optional[nn.Module] = None,
     ):
@@ -50,6 +50,8 @@ def __init__(
"""

# Infer hidden_features from hidden_net if not provided.
if hidden_net is None:
hidden_net = nn.Identity()
try:
inferred_hidden_features: int = hidden_net(
torch.randn(1, context_features)
@@ -61,9 +63,13 @@
             ) from err

         if hidden_features is not None:
-            msg = """'hidden_features' parameter is deprecated and will be removed in a
-            future version."""
-            warnings.warn(msg, DeprecationWarning, stacklevel=2)
+            warnings.warn(
+                "'hidden_features' parameter is deprecated and will be removed "
+                "in a future version. Pass a hidden_net instead and the resulting "
+                "hidden_features will be inferred from it.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
             assert hidden_features == inferred_hidden_features, (
                 f"hidden_features={hidden_features} does not match inferred value "
                 f"{inferred_hidden_features} from hidden_net."
@@ -116,13 +122,17 @@ def get_mixture_components(
"""Return logits, means, precisions and two additional useful quantities.
Args:
context: Input to the MDN, leading dimension is batch dimension.
context: Input to the MDN, leading dimension is batch
(batch_size, context_features).
Returns:
A tuple with logits (num_components), means (num_components x output_dim),
precisions (num_components, output_dim, output_dim), sum log diag of
precision factors (1), precision factors (upper triangular precision factor
A such that SIGMA^-1 = A^T A.) All batched.
A tuple with
logits (batch, num_components)
means (batch, num_components, output_dim),
precisions (batch, num_components, output_dim, output_dim),
sum log diag of precision factors (batch, 1),
precision factors (upper triangular precision factor A such that
SIGMA^-1 = A^T A.) (batch, num_components, output_dim, output_dim)
"""

h = self._hidden_net(context)
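Continuing the construction sketch above, a quick check of get_mixture_components against the shapes documented in this hunk (mdn is the instance built after the __init__ hunks; shapes are quoted from the docstring):

import torch

batch_size = 4
context = torch.randn(batch_size, 5)  # (batch_size, context_features)
logits, means, precisions, sumlogdiag, precision_factors = mdn.get_mixture_components(context)
# Per the docstring above: logits (batch, num_components),
# means (batch, num_components, output_dim),
# precisions (batch, num_components, output_dim, output_dim),
# precision_factors are the upper-triangular factors A with SIGMA^-1 = A^T A.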
@@ -179,11 +189,11 @@ def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor:
         outputs of a neural network.

         Args:
-            inputs: Input variable, leading dim interpreted as batch dimension.
-            context: Conditioning variable, leading dim interpreted as batch dimension.
+            inputs: Input variable, (batch, dim_input).
+            context: Conditioning variable, (batch, dim_context).

         Returns:
-            Log probability of inputs given context under a MoG model.
+            Log probability of inputs given context under a MoG model, (batch).
         """

         logits, means, precisions, sumlogdiag, _ = self.get_mixture_components(context)
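Continuing the same sketch, log_prob consumes batched inputs and context and, per the updated docstring, returns one log-probability per batch element:

inputs = torch.randn(batch_size, 2)        # (batch, dim_input), here features=2
log_probs = mdn.log_prob(inputs, context)  # context from the sketch above
assert log_probs.shape == (batch_size,)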
@@ -205,7 +215,7 @@ def log_prob_mog(
         parameters are already known.

         Args:
-            inputs: Location at which to evaluate the MoG.
+            inputs: Location at which to evaluate the MoG (batch_size, dim_input).
             logits: Log-weights of each component of the MoG. Shape: (batch_size,
                 num_components).
             means: Means of each MoG, shape (batch_size, num_components, parameter_dim).
@@ -219,6 +229,7 @@
"""
batch_size, n_mixtures, output_dim = means.size()
inputs = inputs.view(-1, 1, output_dim)
assert inputs.size(0) == batch_size, "Batch size of inputs does not match."

# Split up evaluation into parts.
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
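For reference, the quantity log_prob_mog evaluates is the standard mixture-of-Gaussians log-density, log p(x) = logsumexp_k(log softmax(logits)_k + log N(x; mu_k, SIGMA_k)). The snippet below is a self-contained equivalent built on torch.distributions, not the implementation in this file:

import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

def mog_log_prob_reference(inputs, logits, means, precisions):
    """Reference MoG log-prob from logits, means and precision matrices."""
    mixture = Categorical(logits=logits)
    components = MultivariateNormal(loc=means, precision_matrix=precisions)
    return MixtureSameFamily(mixture, components).log_prob(inputs)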
@@ -243,10 +254,11 @@ def sample(self, num_samples: int, context: Tensor) -> Tensor:

         Args:
             num_samples: Number of samples to generate.
-            context: Conditioning variable, leading dimension is batch dimension.
+            context: Conditioning variable, (batch_size, context_dim).

         Returns:
-            Generated samples: (num_samples, output_dim) with leading batch dimension.
+            Generated samples: (batch_size, num_samples, output_dim) with leading batch
+                dimension.
         """

         # Get necessary quantities.
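Continuing the earlier sketch, sample now documents a batch-leading return shape:

samples = mdn.sample(num_samples=100, context=context)
assert samples.shape == (batch_size, 100, 2)  # (batch_size, num_samples, output_dim)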
@@ -274,7 +286,7 @@ def sample_mog(
                 (batch_size, num_components, parameter_dim, parameter_dim).

         Returns:
-            Tensor: Samples from the MoG.
+            Tensor: Samples from the MoG (batch_size, num_samples, output_dim).
         """

         batch_size, _, output_dim = means.shape
1 change: 0 additions & 1 deletion tests/mdn_test.py
@@ -27,7 +27,6 @@ def get_mdn(
     return MultivariateGaussianMDN(
         features=features,
         context_features=context_features,
-        hidden_features=hidden_features,
         hidden_net=nn.Sequential(
             nn.Linear(context_features, hidden_features),
             nn.ReLU(),
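With the deprecated argument removed here, hidden_features in this test helper only sizes the nn.Sequential hidden net; the MDN infers it from that net. A hypothetical call (the helper's full signature is not shown in this diff):

mdn = get_mdn(features=2, context_features=3, hidden_features=8)  # hypothetical arguments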
