diff --git a/pyknos/mdn/mdn.py b/pyknos/mdn/mdn.py
index ad15bad..bdf8294 100644
--- a/pyknos/mdn/mdn.py
+++ b/pyknos/mdn/mdn.py
@@ -30,9 +30,9 @@ def __init__(
         self,
         features: int,
         context_features: int,
-        hidden_net: nn.Module,
+        hidden_net: Optional[nn.Module],
         num_components: int,
-        hidden_features: Optional[int],
+        hidden_features: Optional[int] = None,
         custom_initialization: bool = False,
         embedding_net: Optional[nn.Module] = None,
     ):
@@ -50,6 +50,8 @@ def __init__(
 
         """
         # Infer hidden_features from hidden_net if not provided.
+        if hidden_net is None:
+            hidden_net = nn.Identity()
         try:
             inferred_hidden_features: int = hidden_net(
                 torch.randn(1, context_features)
@@ -61,9 +63,13 @@ def __init__(
             ) from err
         if hidden_features is not None:
-            msg = """'hidden_features' parameter is deprecated and will be removed in a
-            future version."""
-            warnings.warn(msg, DeprecationWarning, stacklevel=2)
+            warnings.warn(
+                "'hidden_features' parameter is deprecated and will be removed "
+                "in a future version. Pass a hidden_net instead and the resulting "
+                "hidden_features will be inferred from it.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
             assert hidden_features == inferred_hidden_features, (
                 f"hidden_features={hidden_features} does not match inferred value "
                 f"{inferred_hidden_features} from hidden_net."
             )
@@ -116,13 +122,17 @@ def get_mixture_components(
         """Return logits, means, precisions and two additional useful quantities.
 
         Args:
-            context: Input to the MDN, leading dimension is batch dimension.
+            context: Input to the MDN, leading dimension is batch
+                (batch_size, context_features).
 
         Returns:
-            A tuple with logits (num_components), means (num_components x output_dim),
-            precisions (num_components, output_dim, output_dim), sum log diag of
-            precision factors (1), precision factors (upper triangular precision factor
-            A such that SIGMA^-1 = A^T A.) All batched.
+            A tuple with
+                logits (batch, num_components)
+                means (batch, num_components, output_dim),
+                precisions (batch, num_components, output_dim, output_dim),
+                sum log diag of precision factors (batch, 1),
+                precision factors (upper triangular precision factor A such that
+                    SIGMA^-1 = A^T A.) (batch, num_components, output_dim, output_dim)
         """
 
         h = self._hidden_net(context)
@@ -179,11 +189,11 @@ def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor:
         outputs of a neural network.
 
         Args:
-            inputs: Input variable, leading dim interpreted as batch dimension.
-            context: Conditioning variable, leading dim interpreted as batch dimension.
+            inputs: Input variable, (batch, dim_input).
+            context: Conditioning variable, (batch, dim_context).
 
         Returns:
-            Log probability of inputs given context under a MoG model.
+            Log probability of inputs given context under a MoG model, (batch).
         """
 
         logits, means, precisions, sumlogdiag, _ = self.get_mixture_components(context)
@@ -205,7 +215,7 @@ def log_prob_mog(
         parameters are already known.
 
         Args:
-            inputs: Location at which to evaluate the MoG.
+            inputs: Location at which to evaluate the MoG (batch_size, dim_input).
             logits: Log-weights of each component of the MoG.
                 Shape: (batch_size, num_components).
             means: Means of each MoG, shape (batch_size, num_components, parameter_dim).
@@ -219,6 +229,7 @@ def log_prob_mog(
         """
         batch_size, n_mixtures, output_dim = means.size()
         inputs = inputs.view(-1, 1, output_dim)
+        assert inputs.size(0) == batch_size, "Batch size of inputs does not match."
 
         # Split up evaluation into parts.
         a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
@@ -243,10 +254,11 @@ def sample(self, num_samples: int, context: Tensor) -> Tensor:
 
         Args:
             num_samples: Number of samples to generate.
-            context: Conditioning variable, leading dimension is batch dimension.
+            context: Conditioning variable, (batch_size, context_dim).
 
         Returns:
-            Generated samples: (num_samples, output_dim) with leading batch dimension.
+            Generated samples: (batch_size, num_samples, output_dim) with leading batch
+                dimension.
         """
 
         # Get necessary quantities.
@@ -274,7 +286,7 @@ def sample_mog(
             (batch_size, num_components, parameter_dim, parameter_dim).
 
         Returns:
-            Tensor: Samples from the MoG.
+            Tensor: Samples from the MoG (batch_size, num_samples, output_dim).
         """
 
         batch_size, _, output_dim = means.shape
diff --git a/tests/mdn_test.py b/tests/mdn_test.py
index 09d827c..b50b425 100644
--- a/tests/mdn_test.py
+++ b/tests/mdn_test.py
@@ -27,7 +27,6 @@ def get_mdn(
     return MultivariateGaussianMDN(
         features=features,
         context_features=context_features,
-        hidden_features=hidden_features,
         hidden_net=nn.Sequential(
             nn.Linear(context_features, hidden_features),
             nn.ReLU(),
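
For reference, a minimal usage sketch of the updated constructor: hidden_features is no longer passed explicitly and is instead inferred from hidden_net. The dimensions, the hidden_net architecture, and the choice of num_components below are illustrative only, not taken from the diff.

    # Sketch of the new API; dims and architecture are made up for illustration.
    import torch
    from torch import nn
    from pyknos.mdn.mdn import MultivariateGaussianMDN

    features, context_features, hidden_features = 2, 3, 16

    mdn = MultivariateGaussianMDN(
        features=features,
        context_features=context_features,
        # hidden_features is deprecated; it is inferred from hidden_net's output.
        hidden_net=nn.Sequential(
            nn.Linear(context_features, hidden_features),
            nn.ReLU(),
        ),
        num_components=5,
    )

    context = torch.randn(10, context_features)  # (batch_size, context_features)
    inputs = torch.randn(10, features)           # (batch_size, features)
    log_probs = mdn.log_prob(inputs, context)    # (batch_size,)
    samples = mdn.sample(100, context)           # (batch_size, num_samples, features)

Per the diff, passing hidden_net=None is also accepted and falls back to nn.Identity(), in which case the context is fed to the mixture heads unchanged.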