Error when using spectral normalization in the inference networks' MLP subnet #276

Open
elseml opened this issue Dec 10, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@elseml
Member

elseml commented Dec 10, 2024

With spectral normalization enabled in the inference network, training completes successfully, but subsequent sampling via approximator.sample() crashes.
Specifically, in TwoMoons_StarterNotebook.ipynb, enabling spectral normalization via:

import bayesflow as bf

inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs={"widths": (256,) * 6, "dropout": 0.0, "residual": True, "spectral_normalization": True},
)

leads to the following ValueError when calling approximator.sample(conditions={"x": np.array([[0.0, 0.0]]).astype("float32")}, num_samples=100):

ValueError: Exception encountered when calling MLP.call().

Input 0 of layer "spectral_normalization" is incompatible with the layer: expected ndim=2, found ndim=3. Full shape received: (1, 100, 5)

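During training, the input to the MLP is 2D, with shape (batch_size, features), whereas during sampling it is 3D, with shape (batch_size, num_samples, features), here (1, 100, 5). Since the error stems from the MLP subnet, it is not specific to one architecture (I tested flow matching and coupling flows, and both fail).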
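To illustrate the shape mismatch without depending on BayesFlow or Keras, here is a minimal NumPy sketch. `strict_2d_layer` is a hypothetical stand-in for a layer that only accepts rank-2 input, and `apply_over_leading_dims` is one possible workaround (an assumption, not the library's actual fix): flatten all leading axes into a single batch axis, apply the layer, then restore the original leading shape.

```python
import numpy as np

def strict_2d_layer(x, kernel):
    """Hypothetical stand-in for a layer whose input spec only accepts rank-2 input."""
    if x.ndim != 2:
        raise ValueError(
            f"expected ndim=2, found ndim={x.ndim}. Full shape received: {x.shape}"
        )
    return x @ kernel

def apply_over_leading_dims(layer_fn, x, kernel):
    """Flatten all leading axes into one batch axis, apply the layer, restore the shape."""
    leading_shape = x.shape[:-1]
    flat = x.reshape(-1, x.shape[-1])      # (batch * num_samples, features)
    out = layer_fn(flat, kernel)           # (batch * num_samples, units)
    return out.reshape(*leading_shape, out.shape[-1])

rng = np.random.default_rng(0)
kernel = rng.normal(size=(5, 256)).astype("float32")

train_input = rng.normal(size=(32, 5)).astype("float32")       # 2D, as in training
sample_input = rng.normal(size=(1, 100, 5)).astype("float32")  # 3D, as in sampling

train_out = strict_2d_layer(train_input, kernel)

# Calling the rank-2-only layer directly on the 3D input fails, as in the report.
try:
    strict_2d_layer(sample_input, kernel)
    direct_call_failed = False
except ValueError:
    direct_call_failed = True

# Flattening the leading axes first lets the same layer handle the 3D input.
sample_out = apply_over_leading_dims(strict_2d_layer, sample_input, kernel)
```

Whether reshaping around the wrapped layer is an acceptable fix for the real subnet is an open question; the sketch only shows that the 2D and 3D cases can share one rank-2 code path.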

@paul-buerkner paul-buerkner added the bug Something isn't working label Dec 11, 2024
@stefanradev93
Contributor

It seems that the spectral normalization wrapper does not properly wrap a dense layer's tensordot capabilities, i.e. its ability to operate on inputs with more than two dimensions. This needs to be fixed either on the Keras side or via a custom implementation.
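As a rough idea of what a custom implementation could look like, here is a NumPy sketch of spectral normalization by power iteration, applied to the kernel itself so that the forward pass keeps working on inputs of any rank. This is not the Keras wrapper's code; `spectral_normalize` and `dense_forward` are hypothetical names, and the iteration count and fixed seed are illustrative choices.

```python
import numpy as np

def spectral_normalize(kernel, num_iterations=500, seed=0):
    """Divide the kernel by a power-iteration estimate of its largest singular value."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(kernel.shape[1],))
    for _ in range(num_iterations):
        v = kernel @ u
        v /= np.linalg.norm(v) + 1e-12
        u = kernel.T @ v
        u /= np.linalg.norm(u) + 1e-12
    sigma = v @ kernel @ u  # estimate of the largest singular value
    return kernel / sigma, sigma

def dense_forward(x, kernel):
    """Dense-style contraction over the last axis; works for 2D and 3D inputs alike."""
    return x @ kernel

rng = np.random.default_rng(1)
kernel = rng.normal(size=(5, 8))
normalized_kernel, sigma = spectral_normalize(kernel)

out_2d = dense_forward(rng.normal(size=(32, 5)), normalized_kernel)      # training-like input
out_3d = dense_forward(rng.normal(size=(1, 100, 5)), normalized_kernel)  # sampling-like input
```

Because the normalization acts only on the 2D kernel, the input rank never enters the computation, which is the property the wrapped layer appears to lose.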
