
Error when using spectral normalization in the inference networks' MLP subnet #276

Open
@elseml

With spectral normalization enabled in the inference network, training completes successfully, but subsequent sampling via approximator.sample() crashes.
Specifically, in TwoMoons_StarterNotebook.ipynb, activating spectral normalization via:

inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs={"widths": (256,) * 6, "dropout": 0.0, "residual": True, "spectral_normalization": True},
)

leads to the following ValueError when calling approximator.sample(conditions={"x": np.array([[0.0, 0.0]]).astype("float32")}, num_samples=100):

ValueError: Exception encountered when calling MLP.call().

Input 0 of layer "spectral_normalization" is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: (1, 100, 5)

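During training, the input to the MLP is 2D, whereas during sampling it is 3D: the shape (1, 100, 5) in the error appears to carry a batch axis of size 1 and a sample axis of size 100 (matching num_samples=100) in front of the feature axis. Since the error stems from the MLP subnet, it arises for multiple architectures (I tested flow matching and coupling flows).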
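To illustrate the shape mismatch without depending on BayesFlow or Keras, here is a minimal NumPy sketch. The function `dense_2d_only` is a hypothetical stand-in for a layer that only accepts rank-2 input (it is not the real spectral normalization wrapper), and `apply_on_last_axis` shows one possible workaround, assuming the layer acts only on the last axis: flatten the leading axes, apply the layer, and restore the original shape. The training batch size of 64 and the kernel width of 256 are illustrative assumptions. This is not necessarily how the library should fix the problem.

```python
import numpy as np


def dense_2d_only(x, kernel):
    """Hypothetical stand-in for a layer that rejects anything but rank-2 input."""
    if x.ndim != 2:
        raise ValueError(
            f"expected ndim=2, found ndim={x.ndim}. Full shape received: {x.shape}"
        )
    return x @ kernel


def apply_on_last_axis(layer_fn, x, kernel):
    """Flatten all leading axes, apply the rank-2 layer, then restore them."""
    lead_shape = x.shape[:-1]
    flat = x.reshape(-1, x.shape[-1])  # (prod(lead_shape), features)
    out = layer_fn(flat, kernel)       # (prod(lead_shape), units)
    return out.reshape(*lead_shape, out.shape[-1])


rng = np.random.default_rng(0)
kernel = rng.normal(size=(5, 256)).astype("float32")

# Training-like input: (batch, features). The batch size is an assumption.
train_input = rng.normal(size=(64, 5)).astype("float32")
train_out = dense_2d_only(train_input, kernel)

# Sampling-like input: (batch, num_samples, features), as in the error message.
sample_input = rng.normal(size=(1, 100, 5)).astype("float32")
try:
    dense_2d_only(sample_input, kernel)
    raised = False
except ValueError:
    raised = True  # the rank-3 input is rejected, mirroring the reported error

# Workaround sketch: the same layer works once the leading axes are flattened.
sample_out = apply_on_last_axis(dense_2d_only, sample_input, kernel)
```

The reshape round trip leaves the per-sample computation unchanged, so the 3D result matches applying the layer to each (num_samples, features) slice separately.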
