Description
When turning on spectral normalization in the inference network, training proceeds successfully, but subsequent sampling via approximator.sample() crashes.
Specifically, in TwoMoons_StarterNotebook.ipynb, activating spectral normalization via:
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs={"widths": (256,) * 6, "dropout": 0.0, "residual": True, "spectral_normalization": True},
)
leads to the following ValueError when calling approximator.sample(conditions={"x": np.array([[0.0, 0.0]]).astype("float32")}, num_samples=100):
ValueError: Exception encountered when calling MLP.call().
Input 0 of layer "spectral_normalization" is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: (1, 100, 5)
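During training, the input to the MLP is 2D, whereas during sampling it is 3D: the shape (1, 100, 5) in the error suggests an extra axis of size num_samples=100 for the single condition passed in. Since the error stems from the MLP subnet, it arises for multiple architectures (I tested flow matching and coupling flows).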
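One possible workaround, sketched below in plain NumPy, is to flatten all leading axes into a single batch axis before calling the layer that insists on 2D input, and then restore the original leading shape afterwards. This only illustrates the reshape idea and is not how BayesFlow or Keras handle it internally: `call_with_2d_input` and `strict_2d_dense` are hypothetical names, and `strict_2d_dense` merely stands in for a layer that rejects anything other than ndim=2.

```python
import numpy as np


def call_with_2d_input(layer_fn, x):
    """Apply a (batch, features)-only function to an array with extra leading axes.

    Hypothetical helper: flattens all leading axes into one batch axis,
    calls layer_fn, then restores the leading axes on the output.
    """
    x = np.asarray(x)
    leading_shape = x.shape[:-1]           # e.g. (1, 100) during sampling
    flat = x.reshape(-1, x.shape[-1])      # e.g. (100, 5)
    out = layer_fn(flat)                   # layer only ever sees 2D input
    return out.reshape(*leading_shape, out.shape[-1])


def strict_2d_dense(x, weights=np.ones((5, 3), dtype="float32")):
    """Stand-in (assumption) for a layer that only accepts ndim=2 inputs."""
    if x.ndim != 2:
        raise ValueError(
            f"expected ndim=2, found ndim={x.ndim}. Full shape received: {x.shape}"
        )
    return x @ weights


# Same shape as in the error message: 1 condition, 100 samples, 5 features.
sampling_input = np.zeros((1, 100, 5), dtype="float32")
output = call_with_2d_input(strict_2d_dense, sampling_input)
print(output.shape)  # (1, 100, 3)
```

Calling `strict_2d_dense(sampling_input)` directly raises a ValueError analogous to the one above, while the wrapped call succeeds. Whether the proper fix belongs in the MLP subnet or in the sampling code path is a question for the maintainers.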